From dcb896fd8cec83483f700ee985c352be61cdf233 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley" <joseph@databricks.com>
Date: Thu, 12 Nov 2015 17:03:19 -0800
Subject: [PATCH] [SPARK-11712][ML] Make spark.ml LDAModel be abstract

Per discussion in the initial Pipelines LDA PR [https://github.com/apache/spark/pull/9513], we should make LDAModel abstract and create a LocalLDAModel.

This code simplification should be done before the 1.6 release to ensure API compatibility in future releases.

CC feynmanliang mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9678 from jkbradley/lda-pipelines-2.
---
 .../org/apache/spark/ml/clustering/LDA.scala  | 180 +++++++++---------
 .../apache/spark/ml/clustering/LDASuite.scala |   4 +-
 2 files changed, 96 insertions(+), 88 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index f66233ed3d..92e05815d6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -314,31 +314,31 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
  * Model fitted by [[LDA]].
  *
  * @param vocabSize  Vocabulary size (number of terms or words in the vocabulary)
- * @param oldLocalModel  Underlying spark.mllib model.
- *                       If this model was produced by Online LDA, then this is the
- *                       only model representation.
- *                       If this model was produced by EM, then this local
- *                       representation may be built lazily.
  * @param sqlContext  Used to construct local DataFrames for returning query results
  */
 @Since("1.6.0")
 @Experimental
-class LDAModel private[ml] (
+sealed abstract class LDAModel private[ml] (
     @Since("1.6.0") override val uid: String,
     @Since("1.6.0") val vocabSize: Int,
-    @Since("1.6.0") protected var oldLocalModel: Option[OldLocalLDAModel],
     @Since("1.6.0") @transient protected val sqlContext: SQLContext)
   extends Model[LDAModel] with LDAParams with Logging {
 
-  /** Returns underlying spark.mllib model */
+  // NOTE to developers:
+  //  This abstraction should contain all important functionality for basic LDA usage.
+  //  Specializations of this class can contain expert-only functionality.
+
+  /**
+   * Underlying spark.mllib model.
+   * If this model was produced by Online LDA, then this is the only model representation.
+   * If this model was produced by EM, then this local representation may be built lazily.
+   */
   @Since("1.6.0")
-  protected def getModel: OldLDAModel = oldLocalModel match {
-    case Some(m) => m
-    case None =>
-      // Should never happen.
-      throw new RuntimeException("LDAModel required local model format," +
-        " but the underlying model is missing.")
-  }
+  protected def oldLocalModel: OldLocalLDAModel
+
+  /** Returns underlying spark.mllib model, which may be local or distributed */
+  @Since("1.6.0")
+  protected def getModel: OldLDAModel
 
   /**
    * The features for LDA should be a [[Vector]] representing the word counts in a document.
@@ -352,16 +352,17 @@ class LDAModel private[ml] (
   @Since("1.6.0")
   def setSeed(value: Long): this.type = set(seed, value)
 
-  @Since("1.6.0")
-  override def copy(extra: ParamMap): LDAModel = {
-    val copied = new LDAModel(uid, vocabSize, oldLocalModel, sqlContext)
-    copyValues(copied, extra).setParent(parent)
-  }
-
+  /**
+   * Transforms the input dataset.
+   *
+   * WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]]
+   *          is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver.
+   *          This implementation may be changed in the future.
+   */
   @Since("1.6.0")
   override def transform(dataset: DataFrame): DataFrame = {
     if ($(topicDistributionCol).nonEmpty) {
-      val t = udf(oldLocalModel.get.getTopicDistributionMethod(sqlContext.sparkContext))
+      val t = udf(oldLocalModel.getTopicDistributionMethod(sqlContext.sparkContext))
       dataset.withColumn($(topicDistributionCol), t(col($(featuresCol))))
     } else {
       logWarning("LDAModel.transform was called without any output columns. Set an output column" +
@@ -388,56 +389,50 @@ class LDAModel private[ml] (
    * This is a matrix of size vocabSize x k, where each column is a topic.
    * No guarantees are given about the ordering of the topics.
    *
-   * WARNING: If this model is actually a [[DistributedLDAModel]] instance from EM,
-   *          then this method could involve collecting a large amount of data to the driver
-   *          (on the order of vocabSize x k).
+   * WARNING: If this model is actually a [[DistributedLDAModel]] instance produced by
+   *          the Expectation-Maximization ("em") [[optimizer]], then this method could involve
+   *          collecting a large amount of data to the driver (on the order of vocabSize x k).
    */
   @Since("1.6.0")
-  def topicsMatrix: Matrix = getModel.topicsMatrix
+  def topicsMatrix: Matrix = oldLocalModel.topicsMatrix
 
   /** Indicates whether this instance is of type [[DistributedLDAModel]] */
   @Since("1.6.0")
-  def isDistributed: Boolean = false
+  def isDistributed: Boolean
 
   /**
    * Calculates a lower bound on the log likelihood of the entire corpus.
    *
    * See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
    *
-   * WARNING: If this model was learned via a [[DistributedLDAModel]], this involves collecting
-   *          a large [[topicsMatrix]] to the driver. This implementation may be changed in the
-   *          future.
+   * WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]]
+   *          is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver.
+   *          This implementation may be changed in the future.
    *
    * @param dataset  test corpus to use for calculating log likelihood
    * @return variational lower bound on the log likelihood of the entire corpus
    */
   @Since("1.6.0")
-  def logLikelihood(dataset: DataFrame): Double = oldLocalModel match {
-    case Some(m) =>
-      val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
-      m.logLikelihood(oldDataset)
-    case None =>
-      // Should never happen.
-      throw new RuntimeException("LocalLDAModel.logLikelihood was called," +
-        " but the underlying model is missing.")
+  def logLikelihood(dataset: DataFrame): Double = {
+    val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
+    oldLocalModel.logLikelihood(oldDataset)
   }
 
   /**
    * Calculate an upper bound on perplexity. (Lower is better.)
    * See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
    *
+   * WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]]
+   *          is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver.
+   *          This implementation may be changed in the future.
+   *
    * @param dataset  test corpus to use for calculating perplexity
    * @return Variational upper bound on log perplexity per token.
    */
   @Since("1.6.0")
-  def logPerplexity(dataset: DataFrame): Double = oldLocalModel match {
-    case Some(m) =>
-      val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
-      m.logPerplexity(oldDataset)
-    case None =>
-      // Should never happen.
- throw new RuntimeException("LocalLDAModel.logPerplexity was called," + - " but the underlying model is missing.") + def logPerplexity(dataset: DataFrame): Double = { + val oldDataset = LDA.getOldDataset(dataset, $(featuresCol)) + oldLocalModel.logPerplexity(oldDataset) } /** @@ -468,10 +463,43 @@ class LDAModel private[ml] ( /** * :: Experimental :: * - * Distributed model fitted by [[LDA]] using Expectation-Maximization (EM). + * Local (non-distributed) model fitted by [[LDA]]. + * + * This model stores the inferred topics only; it does not store info about the training dataset. + */ +@Since("1.6.0") +@Experimental +class LocalLDAModel private[ml] ( + uid: String, + vocabSize: Int, + @Since("1.6.0") override protected val oldLocalModel: OldLocalLDAModel, + sqlContext: SQLContext) + extends LDAModel(uid, vocabSize, sqlContext) { + + @Since("1.6.0") + override def copy(extra: ParamMap): LocalLDAModel = { + val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext) + copyValues(copied, extra).setParent(parent).asInstanceOf[LocalLDAModel] + } + + override protected def getModel: OldLDAModel = oldLocalModel + + @Since("1.6.0") + override def isDistributed: Boolean = false +} + + +/** + * :: Experimental :: + * + * Distributed model fitted by [[LDA]]. + * This type of model is currently only produced by Expectation-Maximization (EM). * * This model stores the inferred topics, the full training dataset, and the topic distribution * for each training document. + * + * @param oldLocalModelOption Used to implement [[oldLocalModel]] as a lazy val, but keeping + * [[copy()]] cheap. */ @Since("1.6.0") @Experimental @@ -479,59 +507,39 @@ class DistributedLDAModel private[ml] ( uid: String, vocabSize: Int, private val oldDistributedModel: OldDistributedLDAModel, - sqlContext: SQLContext) - extends LDAModel(uid, vocabSize, None, sqlContext) { + sqlContext: SQLContext, + private var oldLocalModelOption: Option[OldLocalLDAModel]) + extends LDAModel(uid, vocabSize, sqlContext) { + + override protected def oldLocalModel: OldLocalLDAModel = { + if (oldLocalModelOption.isEmpty) { + oldLocalModelOption = Some(oldDistributedModel.toLocal) + } + oldLocalModelOption.get + } + + override protected def getModel: OldLDAModel = oldDistributedModel /** * Convert this distributed model to a local representation. This discards info about the * training dataset. + * + * WARNING: This involves collecting a large [[topicsMatrix]] to the driver. 
*/ @Since("1.6.0") - def toLocal: LDAModel = { - if (oldLocalModel.isEmpty) { - oldLocalModel = Some(oldDistributedModel.toLocal) - } - new LDAModel(uid, vocabSize, oldLocalModel, sqlContext) - } - - @Since("1.6.0") - override protected def getModel: OldLDAModel = oldDistributedModel + def toLocal: LocalLDAModel = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext) @Since("1.6.0") override def copy(extra: ParamMap): DistributedLDAModel = { - val copied = new DistributedLDAModel(uid, vocabSize, oldDistributedModel, sqlContext) - if (oldLocalModel.nonEmpty) copied.oldLocalModel = oldLocalModel + val copied = + new DistributedLDAModel(uid, vocabSize, oldDistributedModel, sqlContext, oldLocalModelOption) copyValues(copied, extra).setParent(parent) copied } - @Since("1.6.0") - override def topicsMatrix: Matrix = { - if (oldLocalModel.isEmpty) { - oldLocalModel = Some(oldDistributedModel.toLocal) - } - super.topicsMatrix - } - @Since("1.6.0") override def isDistributed: Boolean = true - @Since("1.6.0") - override def logLikelihood(dataset: DataFrame): Double = { - if (oldLocalModel.isEmpty) { - oldLocalModel = Some(oldDistributedModel.toLocal) - } - super.logLikelihood(dataset) - } - - @Since("1.6.0") - override def logPerplexity(dataset: DataFrame): Double = { - if (oldLocalModel.isEmpty) { - oldLocalModel = Some(oldDistributedModel.toLocal) - } - super.logPerplexity(dataset) - } - /** * Log likelihood of the observed tokens in the training set, * given the current parameter estimates: @@ -673,9 +681,9 @@ class LDA @Since("1.6.0") ( val oldModel = oldLDA.run(oldData) val newModel = oldModel match { case m: OldLocalLDAModel => - new LDAModel(uid, m.vocabSize, Some(m), dataset.sqlContext) + new LocalLDAModel(uid, m.vocabSize, m, dataset.sqlContext) case m: OldDistributedLDAModel => - new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext) + new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext, None) } copyValues(newModel).setParent(this) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index edb927495e..b634d31cc3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -156,7 +156,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { MLTestingUtils.checkCopy(model) - assert(!model.isInstanceOf[DistributedLDAModel]) + assert(model.isInstanceOf[LocalLDAModel]) assert(model.vocabSize === vocabSize) assert(model.estimatedDocConcentration.size === k) assert(model.topicsMatrix.numRows === vocabSize) @@ -210,7 +210,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.isDistributed) val localModel = model.toLocal - assert(!localModel.isInstanceOf[DistributedLDAModel]) + assert(localModel.isInstanceOf[LocalLDAModel]) // training logLikelihood, logPrior val ll = model.trainingLogLikelihood -- GitLab