Skip to content
Snippets Groups Projects
Commit dcb896fd authored by Joseph K. Bradley's avatar Joseph K. Bradley
Browse files

[SPARK-11712][ML] Make spark.ml LDAModel be abstract

Per discussion in the initial Pipelines LDA PR [https://github.com/apache/spark/pull/9513], we should make LDAModel abstract and create a LocalLDAModel. This code simplification should be done before the 1.6 release to ensure API compatibility in future releases.

CC feynmanliang mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9678 from jkbradley/lda-pipelines-2.
parent bc092966
No related branches found
No related tags found
No related merge requests found
......@@ -314,31 +314,31 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
* Model fitted by [[LDA]].
*
* @param vocabSize Vocabulary size (number of terms or terms in the vocabulary)
* @param oldLocalModel Underlying spark.mllib model.
* If this model was produced by Online LDA, then this is the
* only model representation.
* If this model was produced by EM, then this local
* representation may be built lazily.
* @param sqlContext Used to construct local DataFrames for returning query results
*/
@Since("1.6.0")
@Experimental
class LDAModel private[ml] (
sealed abstract class LDAModel private[ml] (
@Since("1.6.0") override val uid: String,
@Since("1.6.0") val vocabSize: Int,
@Since("1.6.0") protected var oldLocalModel: Option[OldLocalLDAModel],
@Since("1.6.0") @transient protected val sqlContext: SQLContext)
extends Model[LDAModel] with LDAParams with Logging {
/** Returns underlying spark.mllib model */
// NOTE to developers:
// This abstraction should contain all important functionality for basic LDA usage.
// Specializations of this class can contain expert-only functionality.
/**
* Underlying spark.mllib model.
* If this model was produced by Online LDA, then this is the only model representation.
* If this model was produced by EM, then this local representation may be built lazily.
*/
@Since("1.6.0")
protected def getModel: OldLDAModel = oldLocalModel match {
case Some(m) => m
case None =>
// Should never happen.
throw new RuntimeException("LDAModel required local model format," +
" but the underlying model is missing.")
}
protected def oldLocalModel: OldLocalLDAModel
/** Returns underlying spark.mllib model, which may be local or distributed */
@Since("1.6.0")
protected def getModel: OldLDAModel
/**
* The features for LDA should be a [[Vector]] representing the word counts in a document.
......@@ -352,16 +352,17 @@ class LDAModel private[ml] (
@Since("1.6.0")
def setSeed(value: Long): this.type = set(seed, value)
@Since("1.6.0")
override def copy(extra: ParamMap): LDAModel = {
val copied = new LDAModel(uid, vocabSize, oldLocalModel, sqlContext)
copyValues(copied, extra).setParent(parent)
}
/**
* Transforms the input dataset.
*
* WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]]
* is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver.
* This implementation may be changed in the future.
*/
@Since("1.6.0")
override def transform(dataset: DataFrame): DataFrame = {
if ($(topicDistributionCol).nonEmpty) {
val t = udf(oldLocalModel.get.getTopicDistributionMethod(sqlContext.sparkContext))
val t = udf(oldLocalModel.getTopicDistributionMethod(sqlContext.sparkContext))
dataset.withColumn($(topicDistributionCol), t(col($(featuresCol))))
} else {
logWarning("LDAModel.transform was called without any output columns. Set an output column" +
......@@ -388,56 +389,50 @@ class LDAModel private[ml] (
* This is a matrix of size vocabSize x k, where each column is a topic.
* No guarantees are given about the ordering of the topics.
*
* WARNING: If this model is actually a [[DistributedLDAModel]] instance from EM,
* then this method could involve collecting a large amount of data to the driver
* (on the order of vocabSize x k).
* WARNING: If this model is actually a [[DistributedLDAModel]] instance produced by
* the Expectation-Maximization ("em") [[optimizer]], then this method could involve
* collecting a large amount of data to the driver (on the order of vocabSize x k).
*/
@Since("1.6.0")
def topicsMatrix: Matrix = getModel.topicsMatrix
def topicsMatrix: Matrix = oldLocalModel.topicsMatrix
/** Indicates whether this instance is of type [[DistributedLDAModel]] */
@Since("1.6.0")
def isDistributed: Boolean = false
def isDistributed: Boolean
/**
* Calculates a lower bound on the log likelihood of the entire corpus.
*
* See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
*
* WARNING: If this model was learned via a [[DistributedLDAModel]], this involves collecting
* a large [[topicsMatrix]] to the driver. This implementation may be changed in the
* future.
* WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]]
* is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver.
* This implementation may be changed in the future.
*
* @param dataset test corpus to use for calculating log likelihood
* @return variational lower bound on the log likelihood of the entire corpus
*/
@Since("1.6.0")
def logLikelihood(dataset: DataFrame): Double = oldLocalModel match {
case Some(m) =>
val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
m.logLikelihood(oldDataset)
case None =>
// Should never happen.
throw new RuntimeException("LocalLDAModel.logLikelihood was called," +
" but the underlying model is missing.")
def logLikelihood(dataset: DataFrame): Double = {
val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
oldLocalModel.logLikelihood(oldDataset)
}
/**
* Calculate an upper bound bound on perplexity. (Lower is better.)
* See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
*
* WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]]
* is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver.
* This implementation may be changed in the future.
*
* @param dataset test corpus to use for calculating perplexity
* @return Variational upper bound on log perplexity per token.
*/
@Since("1.6.0")
def logPerplexity(dataset: DataFrame): Double = oldLocalModel match {
case Some(m) =>
val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
m.logPerplexity(oldDataset)
case None =>
// Should never happen.
throw new RuntimeException("LocalLDAModel.logPerplexity was called," +
" but the underlying model is missing.")
def logPerplexity(dataset: DataFrame): Double = {
val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
oldLocalModel.logPerplexity(oldDataset)
}
/**
......@@ -468,10 +463,43 @@ class LDAModel private[ml] (
/**
* :: Experimental ::
*
* Distributed model fitted by [[LDA]] using Expectation-Maximization (EM).
* Local (non-distributed) model fitted by [[LDA]].
*
* This model stores the inferred topics only; it does not store info about the training dataset.
*/
@Since("1.6.0")
@Experimental
class LocalLDAModel private[ml] (
uid: String,
vocabSize: Int,
@Since("1.6.0") override protected val oldLocalModel: OldLocalLDAModel,
sqlContext: SQLContext)
extends LDAModel(uid, vocabSize, sqlContext) {
@Since("1.6.0")
override def copy(extra: ParamMap): LocalLDAModel = {
val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext)
copyValues(copied, extra).setParent(parent).asInstanceOf[LocalLDAModel]
}
override protected def getModel: OldLDAModel = oldLocalModel
@Since("1.6.0")
override def isDistributed: Boolean = false
}
/**
* :: Experimental ::
*
* Distributed model fitted by [[LDA]].
* This type of model is currently only produced by Expectation-Maximization (EM).
*
* This model stores the inferred topics, the full training dataset, and the topic distribution
* for each training document.
*
* @param oldLocalModelOption Used to implement [[oldLocalModel]] as a lazy val, but keeping
* [[copy()]] cheap.
*/
@Since("1.6.0")
@Experimental
......@@ -479,59 +507,39 @@ class DistributedLDAModel private[ml] (
uid: String,
vocabSize: Int,
private val oldDistributedModel: OldDistributedLDAModel,
sqlContext: SQLContext)
extends LDAModel(uid, vocabSize, None, sqlContext) {
sqlContext: SQLContext,
private var oldLocalModelOption: Option[OldLocalLDAModel])
extends LDAModel(uid, vocabSize, sqlContext) {
override protected def oldLocalModel: OldLocalLDAModel = {
if (oldLocalModelOption.isEmpty) {
oldLocalModelOption = Some(oldDistributedModel.toLocal)
}
oldLocalModelOption.get
}
override protected def getModel: OldLDAModel = oldDistributedModel
/**
* Convert this distributed model to a local representation. This discards info about the
* training dataset.
*
* WARNING: This involves collecting a large [[topicsMatrix]] to the driver.
*/
@Since("1.6.0")
def toLocal: LDAModel = {
if (oldLocalModel.isEmpty) {
oldLocalModel = Some(oldDistributedModel.toLocal)
}
new LDAModel(uid, vocabSize, oldLocalModel, sqlContext)
}
@Since("1.6.0")
override protected def getModel: OldLDAModel = oldDistributedModel
def toLocal: LocalLDAModel = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext)
@Since("1.6.0")
override def copy(extra: ParamMap): DistributedLDAModel = {
val copied = new DistributedLDAModel(uid, vocabSize, oldDistributedModel, sqlContext)
if (oldLocalModel.nonEmpty) copied.oldLocalModel = oldLocalModel
val copied =
new DistributedLDAModel(uid, vocabSize, oldDistributedModel, sqlContext, oldLocalModelOption)
copyValues(copied, extra).setParent(parent)
copied
}
@Since("1.6.0")
override def topicsMatrix: Matrix = {
if (oldLocalModel.isEmpty) {
oldLocalModel = Some(oldDistributedModel.toLocal)
}
super.topicsMatrix
}
@Since("1.6.0")
override def isDistributed: Boolean = true
@Since("1.6.0")
override def logLikelihood(dataset: DataFrame): Double = {
if (oldLocalModel.isEmpty) {
oldLocalModel = Some(oldDistributedModel.toLocal)
}
super.logLikelihood(dataset)
}
@Since("1.6.0")
override def logPerplexity(dataset: DataFrame): Double = {
if (oldLocalModel.isEmpty) {
oldLocalModel = Some(oldDistributedModel.toLocal)
}
super.logPerplexity(dataset)
}
/**
* Log likelihood of the observed tokens in the training set,
* given the current parameter estimates:
......@@ -673,9 +681,9 @@ class LDA @Since("1.6.0") (
val oldModel = oldLDA.run(oldData)
val newModel = oldModel match {
case m: OldLocalLDAModel =>
new LDAModel(uid, m.vocabSize, Some(m), dataset.sqlContext)
new LocalLDAModel(uid, m.vocabSize, m, dataset.sqlContext)
case m: OldDistributedLDAModel =>
new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext)
new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext, None)
}
copyValues(newModel).setParent(this)
}
......
......@@ -156,7 +156,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
MLTestingUtils.checkCopy(model)
assert(!model.isInstanceOf[DistributedLDAModel])
assert(model.isInstanceOf[LocalLDAModel])
assert(model.vocabSize === vocabSize)
assert(model.estimatedDocConcentration.size === k)
assert(model.topicsMatrix.numRows === vocabSize)
......@@ -210,7 +210,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
assert(model.isDistributed)
val localModel = model.toLocal
assert(!localModel.isInstanceOf[DistributedLDAModel])
assert(localModel.isInstanceOf[LocalLDAModel])
// training logLikelihood, logPrior
val ll = model.trainingLogLikelihood
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment