From 1816eb3bef930407dc9e083de08f5105725c55d1 Mon Sep 17 00:00:00 2001 From: zero323 <zero323@users.noreply.github.com> Date: Wed, 24 May 2017 19:57:44 +0800 Subject: [PATCH] [SPARK-20631][FOLLOW-UP] Fix incorrect tests. ## What changes were proposed in this pull request? - Fix incorrect tests for `_check_thresholds`. - Move test to `ParamTests`. ## How was this patch tested? Unit tests. Author: zero323 <zero323@users.noreply.github.com> Closes #18085 from zero323/SPARK-20631-FOLLOW-UP. --- python/pyspark/ml/tests.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index a3393c6248..0daf29d59c 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -404,6 +404,18 @@ class ParamTests(PySparkTestCase): self.assertEqual(tp._paramMap, copied_no_extra) self.assertEqual(tp._defaultParamMap, tp_copy._defaultParamMap) + def test_logistic_regression_check_thresholds(self): + self.assertIsInstance( + LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]), + LogisticRegression + ) + + self.assertRaisesRegexp( + ValueError, + "Logistic Regression getThreshold found inconsistent.*$", + LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5] + ) + class EvaluatorTests(SparkSessionTestCase): @@ -807,18 +819,6 @@ class PersistenceTest(SparkSessionTestCase): except OSError: pass - def logistic_regression_check_thresholds(self): - self.assertIsInstance( - LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]), - LogisticRegressionModel - ) - - self.assertRaisesRegexp( - ValueError, - "Logistic Regression getThreshold found inconsistent.*$", - LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5] - ) - def _compare_params(self, m1, m2, param): """ Compare 2 ML Params instances for the given param, and assert both have the same param value -- GitLab