Skip to content
Snippets Groups Projects
Commit 0d17593b authored by Yanbo Liang's avatar Yanbo Liang Committed by Joseph K. Bradley
Browse files

[SPARK-14461][ML] GLM training summaries should provide solver

## What changes were proposed in this pull request?
GLM training summaries should provide solver.

## How was this patch tested?
Unit tests.

cc jkbradley

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #12253 from yanboliang/spark-14461.
parent b0adb9f5
No related branches found
No related tags found
No related merge requests found
......@@ -237,7 +237,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
predictionColName,
model,
wlsModel.diagInvAtWA.toArray,
1)
1,
getSolver)
return model.setSummary(trainingSummary)
}
......@@ -257,7 +258,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
predictionColName,
model,
irlsModel.diagInvAtWA.toArray,
irlsModel.numIterations)
irlsModel.numIterations,
getSolver)
model.setSummary(trainingSummary)
}
......@@ -781,6 +783,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr
* @param model the model that should be summarized
* @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration
* @param numIterations number of iterations
* @param solver the solver algorithm used for model training
*/
@Since("2.0.0")
@Experimental
......@@ -789,7 +792,8 @@ class GeneralizedLinearRegressionSummary private[regression] (
@Since("2.0.0") val predictionCol: String,
@Since("2.0.0") val model: GeneralizedLinearRegressionModel,
private val diagInvAtWA: Array[Double],
@Since("2.0.0") val numIterations: Int) extends Serializable {
@Since("2.0.0") val numIterations: Int,
@Since("2.0.0") val solver: String) extends Serializable {
import GeneralizedLinearRegression._
......
......@@ -626,6 +626,7 @@ class GeneralizedLinearRegressionSuite
assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
assert(summary.aic ~== aicR absTol 1E-3)
assert(summary.solver === "irls")
}
test("glm summary: binomial family with weight") {
......@@ -739,6 +740,7 @@ class GeneralizedLinearRegressionSuite
assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
assert(summary.aic ~== aicR absTol 1E-3)
assert(summary.solver === "irls")
}
test("glm summary: poisson family with weight") {
......@@ -855,6 +857,7 @@ class GeneralizedLinearRegressionSuite
assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
assert(summary.aic ~== aicR absTol 1E-3)
assert(summary.solver === "irls")
}
test("glm summary: gamma family with weight") {
......@@ -968,6 +971,7 @@ class GeneralizedLinearRegressionSuite
assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
assert(summary.aic ~== aicR absTol 1E-3)
assert(summary.solver === "irls")
}
test("read/write") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment