diff --git a/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala index d3fe58a382fbbfe78940ea0b7b145884ffa1c923..8664263935c7b6b98d7b4df60ca18191eb0bd377 100644 --- a/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala @@ -21,11 +21,12 @@ import scala.util.Random import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers import spark.SparkContext -class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { +class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers { val sc = new SparkContext("local", "test") override def afterAll() { @@ -64,8 +65,8 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { val numOffPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) => (prediction != expected) }.size - // At least 80% of the predictions should be on. - assert(numOffPredictions < input.length / 5) + // More than 83% of the predictions should be on. + ((input.length - numOffPredictions).toDouble / input.length) should be > 0.83 } // Test if we can correctly learn A, B where Y = logistic(A + B*X)