Skip to content
Snippets Groups Projects
Commit 37f0ab70 authored by hqzizania's avatar hqzizania Committed by Yanbo Liang
Browse files

[SPARK-17090][FOLLOW-UP][ML] Add expert param support to SharedParamsCodeGen

## What changes were proposed in this pull request?

Add expert param support to SharedParamsCodeGen where aggregationDepth a expert param is added.

Author: hqzizania <hqzizania@gmail.com>

Closes #14738 from hqzizania/SPARK-17090-minor.
parent 6d93f9e0
No related branches found
No related tags found
No related merge requests found
......@@ -80,7 +80,7 @@ private[shared] object SharedParamsCodeGen {
ParamDesc[String]("solver", "the solver algorithm for optimization. If this is not set or " +
"empty, default value is 'auto'", Some("\"auto\"")),
ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"),
isValid = "ParamValidators.gtEq(2)"))
isValid = "ParamValidators.gtEq(2)", isExpertParam = true))
val code = genSharedParams(params)
val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
......@@ -95,7 +95,8 @@ private[shared] object SharedParamsCodeGen {
doc: String,
defaultValueStr: Option[String] = None,
isValid: String = "",
finalMethods: Boolean = true) {
finalMethods: Boolean = true,
isExpertParam: Boolean = false) {
require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.")
require(doc.nonEmpty) // TODO: more rigorous on doc
......@@ -153,6 +154,11 @@ private[shared] object SharedParamsCodeGen {
} else {
""
}
val groupStr = if (param.isExpertParam) {
Array("expertParam", "expertGetParam")
} else {
Array("param", "getParam")
}
val methodStr = if (param.finalMethods) {
"final def"
} else {
......@@ -167,11 +173,11 @@ private[shared] object SharedParamsCodeGen {
|
| /**
| * Param for $doc.
| * @group param
| * @group ${groupStr(0)}
| */
| final val $name: $Param = new $Param(this, "$name", "$doc"$isValid)
|$setDefault
| /** @group getParam */
| /** @group ${groupStr(1)} */
| $methodStr get$Name: $T = $$($name)
|}
|""".stripMargin
......
......@@ -397,13 +397,13 @@ private[ml] trait HasAggregationDepth extends Params {
/**
* Param for suggested depth for treeAggregate (>= 2).
* @group param
* @group expertParam
*/
final val aggregationDepth: IntParam = new IntParam(this, "aggregationDepth", "suggested depth for treeAggregate (>= 2)", ParamValidators.gtEq(2))
setDefault(aggregationDepth, 2)
/** @group getParam */
/** @group expertGetParam */
final def getAggregationDepth: Int = $(aggregationDepth)
}
// scalastyle:on
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment