From 1018a1c1eb33eefbfb9025fac7a1cdafc5cbf8f8 Mon Sep 17 00:00:00 2001
From: Timothy Hunter <timhunter@databricks.com>
Date: Wed, 13 Apr 2016 11:06:42 -0700
Subject: [PATCH] [SPARK-14568][ML] Instrumentation framework for logistic
 regression

## What changes were proposed in this pull request?

This adds extra logging information about a `LogisticRegression` estimator when being fit on a dataset. With this PR, you see the following extra lines when running the example in the documentation:

```
16/04/13 07:19:00 INFO Instrumentation: Instrumentation(LogisticRegression-logreg_55dd3c09f164-1230977381-1): training: numPartitions=1 storageLevel=StorageLevel(disk=true, memory=true, offheap=false, deserialized=true, replication=1)
16/04/13 07:19:00 INFO Instrumentation: Instrumentation(LogisticRegression-logreg_55dd3c09f164-1230977381-1): {"regParam":0.3,"elasticNetParam":0.8,"maxIter":10}
...
16/04/12 11:48:07 INFO Instrumentation: Instrumentation(LogisticRegression-logreg_a89eb23cb386-358781145):numClasses=2
16/04/12 11:48:07 INFO Instrumentation: Instrumentation(LogisticRegression-logreg_a89eb23cb386-358781145):numFeatures=692
...
16/04/13 07:19:01 INFO Instrumentation: Instrumentation(LogisticRegression-logreg_55dd3c09f164-1230977381-1): training finished
```

## How was this patch tested?

This PR was manually tested.

Author: Timothy Hunter <timhunter@databricks.com>

Closes #12331 from thunterdb/1604-instrumentation.
---
 .../classification/LogisticRegression.scala   |  11 +-
 .../spark/ml/util/Instrumentation.scala       | 117 ++++++++++++++++++
 2 files changed, 127 insertions(+), 1 deletion(-)
 create mode 100644 mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 4a3fe5c663..c2b440059b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -273,6 +273,10 @@ class LogisticRegression @Since("1.2.0") (
 
     if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
 
+    val instr = Instrumentation.create(this, instances)
+    instr.logParams(regParam, elasticNetParam, standardization, threshold,
+      maxIter, tol, fitIntercept)
+
     val (summarizer, labelSummarizer) = {
       val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer),
         instance: Instance) =>
@@ -291,6 +295,9 @@ class LogisticRegression @Since("1.2.0") (
     val numClasses = histogram.length
     val numFeatures = summarizer.mean.size
 
+    instr.logNumClasses(numClasses)
+    instr.logNumFeatures(numFeatures)
+
     val (coefficients, intercept, objectiveHistory) = {
       if (numInvalid != 0) {
         val msg = s"Classification labels should be in {0 to ${numClasses - 1} " +
@@ -444,7 +451,9 @@ class LogisticRegression @Since("1.2.0") (
       $(labelCol),
       $(featuresCol),
       objectiveHistory)
-    model.setSummary(logRegSummary)
+    val m = model.setSummary(logRegSummary)
+    instr.logSuccess(m)
+    m
   }
 
   @Since("1.4.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
new file mode 100644
index 0000000000..7e57cefc44
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import java.util.concurrent.atomic.AtomicLong
+
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param.Param
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Dataset
+
+/**
+ * A small wrapper that defines a training session for an estimator, and some methods to log
+ * useful information during this session.
+ *
+ * A new instance is expected to be created within fit().
+ *
+ * @param estimator the estimator that is being fit
+ * @param dataset the training dataset
+ * @tparam E the type of the estimator
+ */
+private[ml] class Instrumentation[E <: Estimator[_]] private (
+    estimator: E, dataset: RDD[_]) extends Logging {
+
+  private val id = Instrumentation.counter.incrementAndGet()
+  private val prefix = {
+    val className = estimator.getClass.getSimpleName
+    s"$className-${estimator.uid}-${dataset.hashCode()}-$id: "
+  }
+
+  init()
+
+  private def init(): Unit = {
+    log(s"training: numPartitions=${dataset.partitions.length}" +
+      s" storageLevel=${dataset.getStorageLevel}")
+  }
+
+  /**
+   * Logs a message with a prefix that uniquely identifies the training session.
+   */
+  def log(msg: String): Unit = {
+    logInfo(prefix + msg)
+  }
+
+  /**
+   * Logs the value of the given parameters for the estimator being used in this session.
+   */
+  def logParams(params: Param[_]*): Unit = {
+    val pairs: Seq[(String, JValue)] = for {
+      p <- params
+      value <- estimator.get(p)
+    } yield {
+      val cast = p.asInstanceOf[Param[Any]]
+      p.name -> parse(cast.jsonEncode(value))
+    }
+    log(compact(render(map2jvalue(pairs.toMap))))
+  }
+
+  def logNumFeatures(num: Long): Unit = {
+    log(compact(render("numFeatures" -> num)))
+  }
+
+  def logNumClasses(num: Long): Unit = {
+    log(compact(render("numClasses" -> num)))
+  }
+
+  /**
+   * Logs the successful completion of the training session and the value of the learned model.
+   */
+  def logSuccess(model: Model[_]): Unit = {
+    log(s"training finished")
+  }
+}
+
+/**
+ * Some common methods for logging information about a training session.
+ */
+private[ml] object Instrumentation {
+  private val counter = new AtomicLong(0)
+
+  /**
+   * Creates an instrumentation object for a training session.
+   */
+  def create[E <: Estimator[_]](
+      estimator: E, dataset: Dataset[_]): Instrumentation[E] = {
+    create[E](estimator, dataset.rdd)
+  }
+
+  /**
+   * Creates an instrumentation object for a training session.
+   */
+  def create[E <: Estimator[_]](
+      estimator: E, dataset: RDD[_]): Instrumentation[E] = {
+    new Instrumentation[E](estimator, dataset)
+  }
+
+}
-- 
GitLab