From 9434280cfd1db94dc9d52bb0ace8283e710e3124 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian <bago@databricks.com> Date: Tue, 23 May 2017 20:56:01 -0700 Subject: [PATCH] [SPARK-20861][ML][PYTHON] Delegate looping over paramMaps to estimators Changes: pyspark.ml Estimators can take either a list of param maps or a dict of params. This change allows the CrossValidator and TrainValidationSplit Estimators to pass through lists of param maps to the underlying estimators so that those estimators can handle parallelization when appropriate (eg distributed hyper parameter tuning). Testing: Existing unit tests. Author: Bago Amirbekian <bago@databricks.com> Closes #18077 from MrBago/delegate_params. --- python/pyspark/ml/tuning.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index ffeb4459e1..b64858214d 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -18,14 +18,11 @@ import itertools import numpy as np -from pyspark import SparkContext from pyspark import since, keyword_only from pyspark.ml import Estimator, Model from pyspark.ml.param import Params, Param, TypeConverters from pyspark.ml.param.shared import HasSeed -from pyspark.ml.wrapper import JavaParams from pyspark.sql.functions import rand -from pyspark.ml.common import inherit_doc, _py2java __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit', 'TrainValidationSplitModel'] @@ -232,8 +229,9 @@ class CrossValidator(Estimator, ValidatorParams): condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB) validation = df.filter(condition) train = df.filter(~condition) + models = est.fit(train, epm) for j in range(numModels): - model = est.fit(train, epm[j]) + model = models[j] # TODO: duplicate evaluator to take extra params from input metric = eva.evaluate(model.transform(validation, epm[j])) metrics[j] += metric/nFolds @@ -388,8 +386,9 @@ class TrainValidationSplit(Estimator, ValidatorParams): condition = (df[randCol] >= tRatio) validation = df.filter(condition) train = df.filter(~condition) + models = est.fit(train, epm) for j in range(numModels): - model = est.fit(train, epm[j]) + model = models[j] metric = eva.evaluate(model.transform(validation, epm[j])) metrics[j] += metric if eva.isLargerBetter(): -- GitLab