diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 1c0c536c4fb3d5090902faf4ba80522c3ae0fb9d..9e28dfbb9145db68fee5e5cd1c07b67a148f844f 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -63,7 +63,10 @@ class LogisticRegressionModel(LinearModel): def predict(self, x): _linear_predictor_typecheck(x, self._coeff) margin = _dot(x, self._coeff) + self._intercept - prob = 1/(1 + exp(-margin)) + if margin > 0: + prob = 1 / (1 + exp(-margin)) + else: + prob = 1 - 1 / (1 + exp(margin)) return 1 if prob > 0.5 else 0