From 9434280cfd1db94dc9d52bb0ace8283e710e3124 Mon Sep 17 00:00:00 2001
From: Bago Amirbekian <bago@databricks.com>
Date: Tue, 23 May 2017 20:56:01 -0700
Subject: [PATCH] [SPARK-20861][ML][PYTHON] Delegate looping over paramMaps to
 estimators

Changes:

pyspark.ml Estimators can take either a list of param maps or a dict of params. This change allows the CrossValidator and TrainValidationSplit Estimators to pass lists of param maps through to the underlying estimators, so that those estimators can handle parallelization when appropriate (e.g. distributed hyperparameter tuning).

Testing:

Existing unit tests.

Author: Bago Amirbekian <bago@databricks.com>

Closes #18077 from MrBago/delegate_params.
---
 python/pyspark/ml/tuning.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index ffeb4459e1..b64858214d 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -18,14 +18,11 @@
 import itertools
 import numpy as np
 
-from pyspark import SparkContext
 from pyspark import since, keyword_only
 from pyspark.ml import Estimator, Model
 from pyspark.ml.param import Params, Param, TypeConverters
 from pyspark.ml.param.shared import HasSeed
-from pyspark.ml.wrapper import JavaParams
 from pyspark.sql.functions import rand
-from pyspark.ml.common import inherit_doc, _py2java
 
 __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
            'TrainValidationSplitModel']
@@ -232,8 +229,9 @@ class CrossValidator(Estimator, ValidatorParams):
             condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB)
             validation = df.filter(condition)
             train = df.filter(~condition)
+            models = est.fit(train, epm)
             for j in range(numModels):
-                model = est.fit(train, epm[j])
+                model = models[j]
                 # TODO: duplicate evaluator to take extra params from input
                 metric = eva.evaluate(model.transform(validation, epm[j]))
                 metrics[j] += metric/nFolds
@@ -388,8 +386,9 @@ class TrainValidationSplit(Estimator, ValidatorParams):
         condition = (df[randCol] >= tRatio)
         validation = df.filter(condition)
         train = df.filter(~condition)
+        models = est.fit(train, epm)
         for j in range(numModels):
-            model = est.fit(train, epm[j])
+            model = models[j]
             metric = eva.evaluate(model.transform(validation, epm[j]))
             metrics[j] += metric
         if eva.isLargerBetter():
-- 
GitLab