Skip to content
Snippets Groups Projects
Commit 0d17593b authored by Yanbo Liang's avatar Yanbo Liang Committed by Joseph K. Bradley
Browse files

[SPARK-14461][ML] GLM training summaries should provide solver

## What changes were proposed in this pull request?
GLM training summaries should provide solver.

## How was this patch tested?
Unit tests.

cc jkbradley

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #12253 from yanboliang/spark-14461.
parent b0adb9f5
No related branches found
No related tags found
No related merge requests found
...@@ -237,7 +237,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val ...@@ -237,7 +237,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
predictionColName, predictionColName,
model, model,
wlsModel.diagInvAtWA.toArray, wlsModel.diagInvAtWA.toArray,
1) 1,
getSolver)
return model.setSummary(trainingSummary) return model.setSummary(trainingSummary)
} }
...@@ -257,7 +258,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val ...@@ -257,7 +258,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
predictionColName, predictionColName,
model, model,
irlsModel.diagInvAtWA.toArray, irlsModel.diagInvAtWA.toArray,
irlsModel.numIterations) irlsModel.numIterations,
getSolver)
model.setSummary(trainingSummary) model.setSummary(trainingSummary)
} }
...@@ -781,6 +783,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr ...@@ -781,6 +783,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr
* @param model the model that should be summarized * @param model the model that should be summarized
* @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration
* @param numIterations number of iterations * @param numIterations number of iterations
* @param solver the solver algorithm used for model training
*/ */
@Since("2.0.0") @Since("2.0.0")
@Experimental @Experimental
...@@ -789,7 +792,8 @@ class GeneralizedLinearRegressionSummary private[regression] ( ...@@ -789,7 +792,8 @@ class GeneralizedLinearRegressionSummary private[regression] (
@Since("2.0.0") val predictionCol: String, @Since("2.0.0") val predictionCol: String,
@Since("2.0.0") val model: GeneralizedLinearRegressionModel, @Since("2.0.0") val model: GeneralizedLinearRegressionModel,
private val diagInvAtWA: Array[Double], private val diagInvAtWA: Array[Double],
@Since("2.0.0") val numIterations: Int) extends Serializable { @Since("2.0.0") val numIterations: Int,
@Since("2.0.0") val solver: String) extends Serializable {
import GeneralizedLinearRegression._ import GeneralizedLinearRegression._
......
...@@ -626,6 +626,7 @@ class GeneralizedLinearRegressionSuite ...@@ -626,6 +626,7 @@ class GeneralizedLinearRegressionSuite
assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
assert(summary.aic ~== aicR absTol 1E-3) assert(summary.aic ~== aicR absTol 1E-3)
assert(summary.solver === "irls")
} }
test("glm summary: binomial family with weight") { test("glm summary: binomial family with weight") {
...@@ -739,6 +740,7 @@ class GeneralizedLinearRegressionSuite ...@@ -739,6 +740,7 @@ class GeneralizedLinearRegressionSuite
assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
assert(summary.aic ~== aicR absTol 1E-3) assert(summary.aic ~== aicR absTol 1E-3)
assert(summary.solver === "irls")
} }
test("glm summary: poisson family with weight") { test("glm summary: poisson family with weight") {
...@@ -855,6 +857,7 @@ class GeneralizedLinearRegressionSuite ...@@ -855,6 +857,7 @@ class GeneralizedLinearRegressionSuite
assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
assert(summary.aic ~== aicR absTol 1E-3) assert(summary.aic ~== aicR absTol 1E-3)
assert(summary.solver === "irls")
} }
test("glm summary: gamma family with weight") { test("glm summary: gamma family with weight") {
...@@ -968,6 +971,7 @@ class GeneralizedLinearRegressionSuite ...@@ -968,6 +971,7 @@ class GeneralizedLinearRegressionSuite
assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
assert(summary.aic ~== aicR absTol 1E-3) assert(summary.aic ~== aicR absTol 1E-3)
assert(summary.solver === "irls")
} }
test("read/write") { test("read/write") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment