diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index a3393c62486577efb3c9b6d945e9c68c7ca2724a..0daf29d59cb743f7e07035be7b87aa956f439104 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -404,6 +404,18 @@ class ParamTests(PySparkTestCase): self.assertEqual(tp._paramMap, copied_no_extra) self.assertEqual(tp._defaultParamMap, tp_copy._defaultParamMap) + def test_logistic_regression_check_thresholds(self): + self.assertIsInstance( + LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]), + LogisticRegression + ) + + self.assertRaisesRegexp( + ValueError, + "Logistic Regression getThreshold found inconsistent.*$", + LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5] + ) + class EvaluatorTests(SparkSessionTestCase): @@ -807,18 +819,6 @@ class PersistenceTest(SparkSessionTestCase): except OSError: pass - def logistic_regression_check_thresholds(self): - self.assertIsInstance( - LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]), - LogisticRegressionModel - ) - - self.assertRaisesRegexp( - ValueError, - "Logistic Regression getThreshold found inconsistent.*$", - LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5] - ) - def _compare_params(self, m1, m2, param): """ Compare 2 ML Params instances for the given param, and assert both have the same param value