diff --git a/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 0bbc9424e6d3b06a41d9ff98edd800c3edf26735..7e8073777396b325da55f7420458b0e0980826e6 100644 --- a/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -66,7 +66,10 @@ abstract class GeneralizedLinearModel[T: ClassManifest]( * NOTE(shivaram): This is an abstract class rather than a trait as we use * a view bound to convert labels to Double. */ -abstract class GeneralizedLinearAlgorithm[T <% Double, M <: GeneralizedLinearModel[T]] +abstract class GeneralizedLinearAlgorithm[T, M](implicit + t: T => Double, + tManifest: Manifest[T], + methodEv: M <:< GeneralizedLinearModel[T]) extends Logging with Serializable { // We need an optimizer mixin to solve the GLM @@ -84,15 +87,15 @@ abstract class GeneralizedLinearAlgorithm[T <% Double, M <: GeneralizedLinearMod this } - def train(input: RDD[(T, Array[Double])])(implicit mt: Manifest[T]) : M = { - val nfeatures: Int = input.take(1)(0)._2.length + def train(input: RDD[(T, Array[Double])]) : M = { + val nfeatures: Int = input.first()._2.length val initialWeights = Array.fill(nfeatures)(1.0) train(input, initialWeights) } def train( input: RDD[(T, Array[Double])], - initialWeights: Array[Double])(implicit mt: Manifest[T]) + initialWeights: Array[Double]) : M = { // Add a extra variable consisting of all 1.0's for the intercept.