Commit dcb896fd authored by Joseph K. Bradley

[SPARK-11712][ML] Make spark.ml LDAModel be abstract

Per discussion in the initial Pipelines LDA PR [https://github.com/apache/spark/pull/9513], we should make LDAModel abstract and create a LocalLDAModel. This code simplification should be done before the 1.6 release to ensure API compatibility in future releases.

CC feynmanliang mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9678 from jkbradley/lda-pipelines-2.
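
For orientation before the diff: the patch replaces the single concrete LDAModel with a sealed abstract base class and two concrete subclasses. Below is a compilable miniature of the resulting shape — not the real Spark source; the Mini*/Old* names are stand-ins, and the real classes also carry uid, vocabSize, Params, and a SQLContext.

// Stand-ins for the spark.mllib models the spark.ml wrappers delegate to.
trait OldModel
class OldLocalModel extends OldModel
class OldDistributedModel extends OldModel {
  def toLocal: OldLocalModel = new OldLocalModel
}

// `sealed` restricts subclasses to this file: exactly the local and distributed cases.
sealed abstract class MiniLDAModel {
  protected def oldLocalModel: OldLocalModel // may be built lazily by a subclass
  protected def getModel: OldModel           // local or distributed representation
  def isDistributed: Boolean
}

// Produced by the "online" optimizer, or by converting a distributed model.
class MiniLocalLDAModel(override protected val oldLocalModel: OldLocalModel)
  extends MiniLDAModel {
  override protected def getModel: OldModel = oldLocalModel
  override def isDistributed: Boolean = false
}

// Produced by the "em" optimizer; collapses to a local model on demand.
class MiniDistributedLDAModel(oldDistributedModel: OldDistributedModel)
  extends MiniLDAModel {
  override protected lazy val oldLocalModel: OldLocalModel = oldDistributedModel.toLocal
  override protected def getModel: OldModel = oldDistributedModel
  def toLocal: MiniLocalLDAModel = new MiniLocalLDAModel(oldLocalModel)
  override def isDistributed: Boolean = true
}

(The actual patch avoids a plain lazy val for the distributed case; see the note after the DistributedLDAModel hunk below.)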
Parent: bc092966
@@ -314,31 +314,31 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
  * Model fitted by [[LDA]].
  *
  * @param vocabSize Vocabulary size (number of terms or terms in the vocabulary)
- * @param oldLocalModel Underlying spark.mllib model.
- *                      If this model was produced by Online LDA, then this is the
- *                      only model representation.
- *                      If this model was produced by EM, then this local
- *                      representation may be built lazily.
  * @param sqlContext Used to construct local DataFrames for returning query results
  */
 @Since("1.6.0")
 @Experimental
-class LDAModel private[ml] (
+sealed abstract class LDAModel private[ml] (
     @Since("1.6.0") override val uid: String,
     @Since("1.6.0") val vocabSize: Int,
-    @Since("1.6.0") protected var oldLocalModel: Option[OldLocalLDAModel],
     @Since("1.6.0") @transient protected val sqlContext: SQLContext)
   extends Model[LDAModel] with LDAParams with Logging {
 
-  /** Returns underlying spark.mllib model */
+  // NOTE to developers:
+  // This abstraction should contain all important functionality for basic LDA usage.
+  // Specializations of this class can contain expert-only functionality.
+
+  /**
+   * Underlying spark.mllib model.
+   * If this model was produced by Online LDA, then this is the only model representation.
+   * If this model was produced by EM, then this local representation may be built lazily.
+   */
   @Since("1.6.0")
-  protected def getModel: OldLDAModel = oldLocalModel match {
-    case Some(m) => m
-    case None =>
-      // Should never happen.
-      throw new RuntimeException("LDAModel required local model format," +
-        " but the underlying model is missing.")
-  }
+  protected def oldLocalModel: OldLocalLDAModel
+
+  /** Returns underlying spark.mllib model, which may be local or distributed */
+  @Since("1.6.0")
+  protected def getModel: OldLDAModel
 
   /**
    * The features for LDA should be a [[Vector]] representing the word counts in a document.
@@ -352,16 +352,17 @@ class LDAModel private[ml] (
   @Since("1.6.0")
   def setSeed(value: Long): this.type = set(seed, value)
 
-  @Since("1.6.0")
-  override def copy(extra: ParamMap): LDAModel = {
-    val copied = new LDAModel(uid, vocabSize, oldLocalModel, sqlContext)
-    copyValues(copied, extra).setParent(parent)
-  }
-
+  /**
+   * Transforms the input dataset.
+   *
+   * WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]]
+   * is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver.
+   * This implementation may be changed in the future.
+   */
   @Since("1.6.0")
   override def transform(dataset: DataFrame): DataFrame = {
     if ($(topicDistributionCol).nonEmpty) {
-      val t = udf(oldLocalModel.get.getTopicDistributionMethod(sqlContext.sparkContext))
+      val t = udf(oldLocalModel.getTopicDistributionMethod(sqlContext.sparkContext))
       dataset.withColumn($(topicDistributionCol), t(col($(featuresCol))))
     } else {
       logWarning("LDAModel.transform was called without any output columns. Set an output column" +
@@ -388,56 +389,50 @@ class LDAModel private[ml] (
    * This is a matrix of size vocabSize x k, where each column is a topic.
    * No guarantees are given about the ordering of the topics.
    *
-   * WARNING: If this model is actually a [[DistributedLDAModel]] instance from EM,
-   * then this method could involve collecting a large amount of data to the driver
-   * (on the order of vocabSize x k).
+   * WARNING: If this model is actually a [[DistributedLDAModel]] instance produced by
+   * the Expectation-Maximization ("em") [[optimizer]], then this method could involve
+   * collecting a large amount of data to the driver (on the order of vocabSize x k).
    */
   @Since("1.6.0")
-  def topicsMatrix: Matrix = getModel.topicsMatrix
+  def topicsMatrix: Matrix = oldLocalModel.topicsMatrix
 
   /** Indicates whether this instance is of type [[DistributedLDAModel]] */
   @Since("1.6.0")
-  def isDistributed: Boolean = false
+  def isDistributed: Boolean
 
   /**
    * Calculates a lower bound on the log likelihood of the entire corpus.
    *
    * See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
    *
-   * WARNING: If this model was learned via a [[DistributedLDAModel]], this involves collecting
-   * a large [[topicsMatrix]] to the driver. This implementation may be changed in the
-   * future.
+   * WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]]
+   * is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver.
+   * This implementation may be changed in the future.
    *
    * @param dataset test corpus to use for calculating log likelihood
    * @return variational lower bound on the log likelihood of the entire corpus
    */
   @Since("1.6.0")
-  def logLikelihood(dataset: DataFrame): Double = oldLocalModel match {
-    case Some(m) =>
-      val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
-      m.logLikelihood(oldDataset)
-    case None =>
-      // Should never happen.
-      throw new RuntimeException("LocalLDAModel.logLikelihood was called," +
-        " but the underlying model is missing.")
+  def logLikelihood(dataset: DataFrame): Double = {
+    val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
+    oldLocalModel.logLikelihood(oldDataset)
   }
 
   /**
    * Calculate an upper bound bound on perplexity. (Lower is better.)
    * See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
    *
+   * WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]]
+   * is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver.
+   * This implementation may be changed in the future.
+   *
    * @param dataset test corpus to use for calculating perplexity
    * @return Variational upper bound on log perplexity per token.
    */
   @Since("1.6.0")
-  def logPerplexity(dataset: DataFrame): Double = oldLocalModel match {
-    case Some(m) =>
-      val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
-      m.logPerplexity(oldDataset)
-    case None =>
-      // Should never happen.
-      throw new RuntimeException("LocalLDAModel.logPerplexity was called," +
-        " but the underlying model is missing.")
+  def logPerplexity(dataset: DataFrame): Double = {
+    val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
+    oldLocalModel.logPerplexity(oldDataset)
   }
 
   /**
@@ -468,10 +463,43 @@ class LDAModel private[ml] (
 /**
  * :: Experimental ::
  *
- * Distributed model fitted by [[LDA]] using Expectation-Maximization (EM).
+ * Local (non-distributed) model fitted by [[LDA]].
+ *
+ * This model stores the inferred topics only; it does not store info about the training dataset.
+ */
+@Since("1.6.0")
+@Experimental
+class LocalLDAModel private[ml] (
+    uid: String,
+    vocabSize: Int,
+    @Since("1.6.0") override protected val oldLocalModel: OldLocalLDAModel,
+    sqlContext: SQLContext)
+  extends LDAModel(uid, vocabSize, sqlContext) {
+
+  @Since("1.6.0")
+  override def copy(extra: ParamMap): LocalLDAModel = {
+    val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext)
+    copyValues(copied, extra).setParent(parent).asInstanceOf[LocalLDAModel]
+  }
+
+  override protected def getModel: OldLDAModel = oldLocalModel
+
+  @Since("1.6.0")
+  override def isDistributed: Boolean = false
+}
+
+/**
+ * :: Experimental ::
+ *
+ * Distributed model fitted by [[LDA]].
+ * This type of model is currently only produced by Expectation-Maximization (EM).
  *
  * This model stores the inferred topics, the full training dataset, and the topic distribution
  * for each training document.
+ *
+ * @param oldLocalModelOption Used to implement [[oldLocalModel]] as a lazy val, but keeping
+ *                            [[copy()]] cheap.
  */
 @Since("1.6.0")
 @Experimental
@@ -479,59 +507,39 @@ class DistributedLDAModel private[ml] (
     uid: String,
     vocabSize: Int,
     private val oldDistributedModel: OldDistributedLDAModel,
-    sqlContext: SQLContext)
-  extends LDAModel(uid, vocabSize, None, sqlContext) {
+    sqlContext: SQLContext,
+    private var oldLocalModelOption: Option[OldLocalLDAModel])
+  extends LDAModel(uid, vocabSize, sqlContext) {
+
+  override protected def oldLocalModel: OldLocalLDAModel = {
+    if (oldLocalModelOption.isEmpty) {
+      oldLocalModelOption = Some(oldDistributedModel.toLocal)
+    }
+    oldLocalModelOption.get
+  }
+
+  override protected def getModel: OldLDAModel = oldDistributedModel
 
   /**
    * Convert this distributed model to a local representation. This discards info about the
    * training dataset.
+   *
+   * WARNING: This involves collecting a large [[topicsMatrix]] to the driver.
    */
   @Since("1.6.0")
-  def toLocal: LDAModel = {
-    if (oldLocalModel.isEmpty) {
-      oldLocalModel = Some(oldDistributedModel.toLocal)
-    }
-    new LDAModel(uid, vocabSize, oldLocalModel, sqlContext)
-  }
-
-  @Since("1.6.0")
-  override protected def getModel: OldLDAModel = oldDistributedModel
+  def toLocal: LocalLDAModel = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext)
 
   @Since("1.6.0")
   override def copy(extra: ParamMap): DistributedLDAModel = {
-    val copied = new DistributedLDAModel(uid, vocabSize, oldDistributedModel, sqlContext)
-    if (oldLocalModel.nonEmpty) copied.oldLocalModel = oldLocalModel
+    val copied =
+      new DistributedLDAModel(uid, vocabSize, oldDistributedModel, sqlContext, oldLocalModelOption)
     copyValues(copied, extra).setParent(parent)
     copied
   }
 
-  @Since("1.6.0")
-  override def topicsMatrix: Matrix = {
-    if (oldLocalModel.isEmpty) {
-      oldLocalModel = Some(oldDistributedModel.toLocal)
-    }
-    super.topicsMatrix
-  }
-
   @Since("1.6.0")
   override def isDistributed: Boolean = true
 
-  @Since("1.6.0")
-  override def logLikelihood(dataset: DataFrame): Double = {
-    if (oldLocalModel.isEmpty) {
-      oldLocalModel = Some(oldDistributedModel.toLocal)
-    }
-    super.logLikelihood(dataset)
-  }
-
-  @Since("1.6.0")
-  override def logPerplexity(dataset: DataFrame): Double = {
-    if (oldLocalModel.isEmpty) {
-      oldLocalModel = Some(oldDistributedModel.toLocal)
-    }
-    super.logPerplexity(dataset)
-  }
-
   /**
    * Log likelihood of the observed tokens in the training set,
    * given the current parameter estimates:
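
A note on the oldLocalModelOption parameter introduced above: a plain lazy val would also memoize the expensive oldDistributedModel.toLocal conversion, but the cached value would not be shared with copies. Passing the Option through the constructor lets copy() hand over whatever has already been materialized. A minimal self-contained sketch of that pattern, with hypothetical names (CachedConversion, expensiveConversion) — not the Spark code itself:

class CachedConversion(private var cached: Option[Vector[Int]]) {
  // Materialize on first access, then reuse the cached result.
  def converted: Vector[Int] = {
    if (cached.isEmpty) {
      cached = Some(expensiveConversion())
    }
    cached.get
  }

  // The copy shares whatever has been materialized so far, so copying stays cheap.
  def copy(): CachedConversion = new CachedConversion(cached)

  private def expensiveConversion(): Vector[Int] = Vector.fill(3)(42)
}

object CachedConversionDemo extends App {
  val a = new CachedConversion(None)
  println(a.converted)        // runs the expensive conversion once
  println(a.copy().converted) // the copy reuses the cached result
}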
@@ -673,9 +681,9 @@ class LDA @Since("1.6.0") (
     val oldModel = oldLDA.run(oldData)
     val newModel = oldModel match {
       case m: OldLocalLDAModel =>
-        new LDAModel(uid, m.vocabSize, Some(m), dataset.sqlContext)
+        new LocalLDAModel(uid, m.vocabSize, m, dataset.sqlContext)
       case m: OldDistributedLDAModel =>
-        new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext)
+        new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext, None)
     }
     copyValues(newModel).setParent(this)
   }
...
@@ -156,7 +156,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
     MLTestingUtils.checkCopy(model)
 
-    assert(!model.isInstanceOf[DistributedLDAModel])
+    assert(model.isInstanceOf[LocalLDAModel])
     assert(model.vocabSize === vocabSize)
     assert(model.estimatedDocConcentration.size === k)
     assert(model.topicsMatrix.numRows === vocabSize)
@@ -210,7 +210,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(model.isDistributed)
 
     val localModel = model.toLocal
-    assert(!localModel.isInstanceOf[DistributedLDAModel])
+    assert(localModel.isInstanceOf[LocalLDAModel])
 
     // training logLikelihood, logPrior
     val ll = model.trainingLogLikelihood
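
From the caller's side, the split means user code can match on the concrete model type rather than consulting isDistributed on a single class. A minimal usage sketch, assuming a DataFrame `dataset` with a word-count features column is in scope (the surrounding setup is not part of this patch):

import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA, LocalLDAModel}

val lda = new LDA().setK(10).setMaxIter(20).setOptimizer("em")
val model = lda.fit(dataset) // the "em" optimizer yields a DistributedLDAModel

model match {
  case dist: DistributedLDAModel =>
    // Collapsing to a local model collects a large topicsMatrix to the driver.
    val local: LocalLDAModel = dist.toLocal
    println(s"converted to local model, vocabSize = ${local.vocabSize}")
  case local: LocalLDAModel =>
    // The "online" optimizer would have produced a local model directly.
    println(s"already local, vocabSize = ${local.vocabSize}")
}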