Skip to content
Snippets Groups Projects
Commit 3f67382e authored by Joseph K. Bradley's avatar Joseph K. Bradley Committed by Xiangrui Meng
Browse files

[SPARK-2478] [mllib] DecisionTree Python API

Added experimental Python API for Decision Trees.

API:
* class DecisionTreeModel
** predict() for single examples and RDDs, taking both feature vectors and LabeledPoints
** numNodes()
** depth()
** __str__()
* class DecisionTree
** trainClassifier()
** trainRegressor()
** train()

Examples and testing:
* Added example testing classification and regression with batch prediction: examples/src/main/python/mllib/tree.py
* Have also tested example usage in doc of python/pyspark/mllib/tree.py which tests single-example prediction with dense and sparse vectors

Also: Small bug fix in python/pyspark/mllib/_common.py: In _linear_predictor_typecheck, changed check for RDD to use isinstance() instead of type() in order to catch RDD subclasses.

CC mengxr manishamde

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes #1727 from jkbradley/decisiontree-python-new and squashes the following commits:

3744488 [Joseph K. Bradley] Renamed test tree.py to decision_tree_runner.py Small updates based on github review.
6b86a9d [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
affceb9 [Joseph K. Bradley] * Fixed bug in doc tests in pyspark/mllib/util.py caused by change in loadLibSVMFile behavior.  (It used to threshold labels at 0 to make them 0/1, but it now leaves them as they are.) * Fixed small bug in loadLibSVMFile: If a data file had no features, then loadLibSVMFile would create a single all-zero feature.
67a29bc [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
cf46ad7 [Joseph K. Bradley] Python DecisionTreeModel * predict(empty RDD) returns an empty RDD instead of an error. * Removed support for calling predict() on LabeledPoint and RDD[LabeledPoint] * predict() does not cache serialized RDD any more.
aa29873 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
bf21be4 [Joseph K. Bradley] removed old run() func from DecisionTree
fa10ea7 [Joseph K. Bradley] Small style update
7968692 [Joseph K. Bradley] small braces typo fix
e34c263 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
4801b40 [Joseph K. Bradley] Small style update to DecisionTreeSuite
db0eab2 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix2' into decisiontree-python-new
6873fa9 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
225822f [Joseph K. Bradley] Bug: In DecisionTree, the method sequentialBinSearchForOrderedCategoricalFeatureInClassification() indexed bins from 0 to (math.pow(2, featureCategories.toInt - 1) - 1). This upper bound is the bound for unordered categorical features, not ordered ones. The upper bound should be the arity (i.e., max value) of the feature.
93953f1 [Joseph K. Bradley] Likely done with Python API.
6df89a9 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
4562c08 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
665ba78 [Joseph K. Bradley] Small updates towards Python DecisionTree API
188cb0d [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new
6622247 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
b8fac57 [Joseph K. Bradley] Finished Python DecisionTree API and example but need to test a bit more.
2b20c61 [Joseph K. Bradley] Small doc and style updates
1b29c13 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new
584449a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
dab0b67 [Joseph K. Bradley] Added documentation for DecisionTree internals
8bb8aa0 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix
978cfcf [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix
6eed482 [Joseph K. Bradley] In DecisionTree: Changed from using procedural syntax for functions returning Unit to explicitly writing Unit return type.
376dca2 [Joseph K. Bradley] Updated meaning of maxDepth by 1 to fit scikit-learn and rpart. * In code, replaced usages of maxDepth <-- maxDepth + 1 * In params, replace settings of maxDepth <-- maxDepth - 1
e06e423 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new
bab3f19 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
59750f8 [Joseph K. Bradley] * Updated Strategy to check numClassesForClassification only if algo=Classification. * Updates based on comments: ** DecisionTreeRunner *** Made dataFormat arg default to libsvm ** Small cleanups ** tree.Node: Made recursive helper methods private, and renamed them.
52e17c5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix
f5a036c [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new
da50db7 [Joseph K. Bradley] Added one more test to DecisionTreeSuite: stump with 2 continuous variables for binary classification.  Caused problems in past, but fixed now.
8e227ea [Joseph K. Bradley] Changed Strategy so it only requires numClassesForClassification >= 2 for classification
cd1d933 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new
8ea8750 [Joseph K. Bradley] Bug fix: Off-by-1 when finding thresholds for splits for continuous features.
8a758db [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new
5fe44ed [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
2283df8 [Joseph K. Bradley] 2 bug fixes.
73fbea2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix
5f920a1 [Joseph K. Bradley] Demonstration of bug before submitting fix: Updated DecisionTreeSuite so that 3 tests fail.  Will describe bug in next commit.
f825352 [Joseph K. Bradley] Wrote Python API and example for DecisionTree.  Also added toString, depth, and numNodes methods to DecisionTreeModel.
parent e09e18b3
No related branches found
No related tags found
No related merge requests found
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Decision tree classification and regression using MLlib.
"""
import numpy, os, sys
from operator import add
from pyspark import SparkContext
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree
from pyspark.mllib.util import MLUtils
def getAccuracy(dtModel, data):
"""
Return accuracy of DecisionTreeModel on the given RDD[LabeledPoint].
"""
seqOp = (lambda acc, x: acc + (x[0] == x[1]))
predictions = dtModel.predict(data.map(lambda x: x.features))
truth = data.map(lambda p: p.label)
trainCorrect = predictions.zip(truth).aggregate(0, seqOp, add)
if data.count() == 0:
return 0
return trainCorrect / (0.0 + data.count())
def getMSE(dtModel, data):
"""
Return mean squared error (MSE) of DecisionTreeModel on the given
RDD[LabeledPoint].
"""
seqOp = (lambda acc, x: acc + numpy.square(x[0] - x[1]))
predictions = dtModel.predict(data.map(lambda x: x.features))
truth = data.map(lambda p: p.label)
trainMSE = predictions.zip(truth).aggregate(0, seqOp, add)
if data.count() == 0:
return 0
return trainMSE / (0.0 + data.count())
def reindexClassLabels(data):
"""
Re-index class labels in a dataset to the range {0,...,numClasses-1}.
If all labels in that range already appear at least once,
then the returned RDD is the same one (without a mapping).
Note: If a label simply does not appear in the data,
the index will not include it.
Be aware of this when reindexing subsampled data.
:param data: RDD of LabeledPoint where labels are integer values
denoting labels for a classification problem.
:return: Pair (reindexedData, origToNewLabels) where
reindexedData is an RDD of LabeledPoint with labels in
the range {0,...,numClasses-1}, and
origToNewLabels is a dictionary mapping original labels
to new labels.
"""
# classCounts: class --> # examples in class
classCounts = data.map(lambda x: x.label).countByValue()
numExamples = sum(classCounts.values())
sortedClasses = sorted(classCounts.keys())
numClasses = len(classCounts)
# origToNewLabels: class --> index in 0,...,numClasses-1
if (numClasses < 2):
print >> sys.stderr, \
"Dataset for classification should have at least 2 classes." + \
" The given dataset had only %d classes." % numClasses
exit(1)
origToNewLabels = dict([(sortedClasses[i], i) for i in range(0, numClasses)])
print "numClasses = %d" % numClasses
print "Per-class example fractions, counts:"
print "Class\tFrac\tCount"
for c in sortedClasses:
frac = classCounts[c] / (numExamples + 0.0)
print "%g\t%g\t%d" % (c, frac, classCounts[c])
if (sortedClasses[0] == 0 and sortedClasses[-1] == numClasses - 1):
return (data, origToNewLabels)
else:
reindexedData = \
data.map(lambda x: LabeledPoint(origToNewLabels[x.label], x.features))
return (reindexedData, origToNewLabels)
def usage():
print >> sys.stderr, \
"Usage: decision_tree_runner [libsvm format data filepath]\n" + \
" Note: This only supports binary classification."
exit(1)
if __name__ == "__main__":
if len(sys.argv) > 2:
usage()
sc = SparkContext(appName="PythonDT")
# Load data.
dataPath = 'data/mllib/sample_libsvm_data.txt'
if len(sys.argv) == 2:
dataPath = sys.argv[1]
if not os.path.isfile(dataPath):
usage()
points = MLUtils.loadLibSVMFile(sc, dataPath)
# Re-index class labels if needed.
(reindexedData, origToNewLabels) = reindexClassLabels(points)
# Train a classifier.
model = DecisionTree.trainClassifier(reindexedData, numClasses=2)
# Print learned tree and stats.
print "Trained DecisionTree for classification:"
print " Model numNodes: %d\n" % model.numNodes()
print " Model depth: %d\n" % model.depth()
print " Training accuracy: %g\n" % getAccuracy(model, reindexedData)
print model
......@@ -30,8 +30,10 @@ from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.classification import LogisticRegressionWithSGD
# Parse a line of text into an MLlib LabeledPoint object
def parsePoint(line):
"""
Parse a line of text into an MLlib LabeledPoint object.
"""
values = [float(s) for s in line.split(' ')]
if values[0] == -1: # Convert -1 labels to 0 for MLlib
values[0] = 0
......
......@@ -19,6 +19,8 @@ package org.apache.spark.mllib.api.python
import java.nio.{ByteBuffer, ByteOrder}
import scala.collection.JavaConverters._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.mllib.classification._
......@@ -29,6 +31,11 @@ import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.random.{RandomRDDGenerators => RG}
import org.apache.spark.mllib.recommendation._
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.mllib.stat.correlation.CorrelationNames
import org.apache.spark.mllib.util.MLUtils
......@@ -472,6 +479,76 @@ class PythonMLLibAPI extends Serializable {
ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha)
}
/**
* Java stub for Python mllib DecisionTree.train().
* This stub returns a handle to the Java object instead of the content of the Java object.
* Extra care needs to be taken in the Python code to ensure it gets freed on exit;
* see the Py4J documentation.
* @param dataBytesJRDD Training data
* @param categoricalFeaturesInfoJMap Categorical features info, as Java map
*/
def trainDecisionTreeModel(
dataBytesJRDD: JavaRDD[Array[Byte]],
algoStr: String,
numClasses: Int,
categoricalFeaturesInfoJMap: java.util.Map[Int, Int],
impurityStr: String,
maxDepth: Int,
maxBins: Int): DecisionTreeModel = {
val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint)
val algo: Algo = algoStr match {
case "classification" => Classification
case "regression" => Regression
case _ => throw new IllegalArgumentException(s"Bad algoStr parameter: $algoStr")
}
val impurity: Impurity = impurityStr match {
case "gini" => Gini
case "entropy" => Entropy
case "variance" => Variance
case _ => throw new IllegalArgumentException(s"Bad impurityStr parameter: $impurityStr")
}
val strategy = new Strategy(
algo = algo,
impurity = impurity,
maxDepth = maxDepth,
numClassesForClassification = numClasses,
maxBins = maxBins,
categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap)
DecisionTree.train(data, strategy)
}
/**
* Predict the label of the given data point.
* This is a Java stub for python DecisionTreeModel.predict()
*
* @param featuresBytes Serialized feature vector for data point
* @return predicted label
*/
def predictDecisionTreeModel(
model: DecisionTreeModel,
featuresBytes: Array[Byte]): Double = {
val features: Vector = deserializeDoubleVector(featuresBytes)
model.predict(features)
}
/**
* Predict the labels of the given data points.
* This is a Java stub for python DecisionTreeModel.predict()
*
* @param dataJRDD A JavaRDD with serialized feature vectors
* @return JavaRDD of serialized predictions
*/
def predictDecisionTreeModel(
model: DecisionTreeModel,
dataJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
val data = dataJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes))
model.predict(data).map(serializeDouble)
}
/**
* Java stub for mllib Statistics.corr(X: RDD[Vector], method: String).
* Returns the correlation matrix serialized into a byte array understood by deserializers in
......@@ -597,4 +674,5 @@ class PythonMLLibAPI extends Serializable {
val s = getSeedOrDefault(seed)
RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(serializeDoubleVector)
}
}
......@@ -56,7 +56,8 @@ class Strategy (
if (algo == Classification) {
require(numClassesForClassification >= 2)
}
val isMulticlassClassification = numClassesForClassification > 2
val isMulticlassClassification =
algo == Classification && numClassesForClassification > 2
val isMulticlassWithCategoricalFeatures
= isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
......
......@@ -48,7 +48,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
requiredMSE: Double) {
val predictions = input.map(x => model.predict(x.features))
val squaredError = predictions.zip(input).map { case (prediction, expected) =>
(prediction - expected.label) * (prediction - expected.label)
val err = prediction - expected.label
err * err
}.sum
val mse = squaredError / input.length
assert(mse <= requiredMSE)
......
......@@ -343,22 +343,35 @@ def _copyto(array, buffer, offset, shape, dtype):
temp_array[...] = array
def _get_unmangled_rdd(data, serializer):
def _get_unmangled_rdd(data, serializer, cache=True):
"""
:param cache: If True, the serialized RDD is cached. (default = True)
WARNING: Users should unpersist() this later!
"""
dataBytes = data.map(serializer)
dataBytes._bypass_serializer = True
dataBytes.cache() # TODO: users should unpersist() this later!
if cache:
dataBytes.cache()
return dataBytes
# Map a pickled Python RDD of Python dense or sparse vectors to a Java RDD of
# _serialized_double_vectors
def _get_unmangled_double_vector_rdd(data):
return _get_unmangled_rdd(data, _serialize_double_vector)
def _get_unmangled_double_vector_rdd(data, cache=True):
"""
Map a pickled Python RDD of Python dense or sparse vectors to a Java RDD of
_serialized_double_vectors.
:param cache: If True, the serialized RDD is cached. (default = True)
WARNING: Users should unpersist() this later!
"""
return _get_unmangled_rdd(data, _serialize_double_vector, cache)
# Map a pickled Python RDD of LabeledPoint to a Java RDD of _serialized_labeled_points
def _get_unmangled_labeled_point_rdd(data):
return _get_unmangled_rdd(data, _serialize_labeled_point)
def _get_unmangled_labeled_point_rdd(data, cache=True):
"""
Map a pickled Python RDD of LabeledPoint to a Java RDD of _serialized_labeled_points.
:param cache: If True, the serialized RDD is cached. (default = True)
WARNING: Users should unpersist() this later!
"""
return _get_unmangled_rdd(data, _serialize_labeled_point, cache)
# Common functions for dealing with and training linear models
......@@ -380,7 +393,7 @@ def _linear_predictor_typecheck(x, coeffs):
if x.size != coeffs.shape[0]:
raise RuntimeError("Got sparse vector of size %d; wanted %d" % (
x.size, coeffs.shape[0]))
elif (type(x) == RDD):
elif isinstance(x, RDD):
raise RuntimeError("Bulk predict not yet supported.")
else:
raise TypeError("Argument of type " + type(x).__name__ + " unsupported")
......
......@@ -100,6 +100,7 @@ class ListTests(PySparkTestCase):
def test_classification(self):
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
from pyspark.mllib.tree import DecisionTree
data = [
LabeledPoint(0.0, [1, 0, 0]),
LabeledPoint(1.0, [0, 1, 1]),
......@@ -127,9 +128,19 @@ class ListTests(PySparkTestCase):
self.assertTrue(nb_model.predict(features[2]) <= 0)
self.assertTrue(nb_model.predict(features[3]) > 0)
categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories
dt_model = \
DecisionTree.trainClassifier(rdd, numClasses=2,
categoricalFeaturesInfo=categoricalFeaturesInfo)
self.assertTrue(dt_model.predict(features[0]) <= 0)
self.assertTrue(dt_model.predict(features[1]) > 0)
self.assertTrue(dt_model.predict(features[2]) <= 0)
self.assertTrue(dt_model.predict(features[3]) > 0)
def test_regression(self):
from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \
RidgeRegressionWithSGD
from pyspark.mllib.tree import DecisionTree
data = [
LabeledPoint(-1.0, [0, -1]),
LabeledPoint(1.0, [0, 1]),
......@@ -157,6 +168,14 @@ class ListTests(PySparkTestCase):
self.assertTrue(rr_model.predict(features[2]) <= 0)
self.assertTrue(rr_model.predict(features[3]) > 0)
categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories
dt_model = \
DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
self.assertTrue(dt_model.predict(features[0]) <= 0)
self.assertTrue(dt_model.predict(features[1]) > 0)
self.assertTrue(dt_model.predict(features[2]) <= 0)
self.assertTrue(dt_model.predict(features[3]) > 0)
@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(PySparkTestCase):
......@@ -229,6 +248,7 @@ class SciPyTests(PySparkTestCase):
def test_classification(self):
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
from pyspark.mllib.tree import DecisionTree
data = [
LabeledPoint(0.0, self.scipy_matrix(2, {0: 1.0})),
LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})),
......@@ -256,9 +276,18 @@ class SciPyTests(PySparkTestCase):
self.assertTrue(nb_model.predict(features[2]) <= 0)
self.assertTrue(nb_model.predict(features[3]) > 0)
categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories
dt_model = DecisionTree.trainClassifier(rdd, numClasses=2,
categoricalFeaturesInfo=categoricalFeaturesInfo)
self.assertTrue(dt_model.predict(features[0]) <= 0)
self.assertTrue(dt_model.predict(features[1]) > 0)
self.assertTrue(dt_model.predict(features[2]) <= 0)
self.assertTrue(dt_model.predict(features[3]) > 0)
def test_regression(self):
from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \
RidgeRegressionWithSGD
from pyspark.mllib.tree import DecisionTree
data = [
LabeledPoint(-1.0, self.scipy_matrix(2, {1: -1.0})),
LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})),
......@@ -286,6 +315,13 @@ class SciPyTests(PySparkTestCase):
self.assertTrue(rr_model.predict(features[2]) <= 0)
self.assertTrue(rr_model.predict(features[3]) > 0)
categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories
dt_model = DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
self.assertTrue(dt_model.predict(features[0]) <= 0)
self.assertTrue(dt_model.predict(features[1]) > 0)
self.assertTrue(dt_model.predict(features[2]) <= 0)
self.assertTrue(dt_model.predict(features[3]) > 0)
if __name__ == "__main__":
if not _have_scipy:
......
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from py4j.java_collections import MapConverter
from pyspark import SparkContext, RDD
from pyspark.mllib._common import \
_get_unmangled_rdd, _get_unmangled_double_vector_rdd, _serialize_double_vector, \
_deserialize_labeled_point, _get_unmangled_labeled_point_rdd, \
_deserialize_double
from pyspark.mllib.regression import LabeledPoint
from pyspark.serializers import NoOpSerializer
class DecisionTreeModel(object):
"""
A decision tree model for classification or regression.
EXPERIMENTAL: This is an experimental API.
It will probably be modified for Spark v1.2.
"""
def __init__(self, sc, java_model):
"""
:param sc: Spark context
:param java_model: Handle to Java model object
"""
self._sc = sc
self._java_model = java_model
def __del__(self):
self._sc._gateway.detach(self._java_model)
def predict(self, x):
"""
Predict the label of one or more examples.
:param x: Data point (feature vector),
or an RDD of data points (feature vectors).
"""
pythonAPI = self._sc._jvm.PythonMLLibAPI()
if isinstance(x, RDD):
# Bulk prediction
if x.count() == 0:
return self._sc.parallelize([])
dataBytes = _get_unmangled_double_vector_rdd(x, cache=False)
jSerializedPreds = \
pythonAPI.predictDecisionTreeModel(self._java_model,
dataBytes._jrdd)
serializedPreds = RDD(jSerializedPreds, self._sc, NoOpSerializer())
return serializedPreds.map(lambda bytes: _deserialize_double(bytearray(bytes)))
else:
# Assume x is a single data point.
x_ = _serialize_double_vector(x)
return pythonAPI.predictDecisionTreeModel(self._java_model, x_)
def numNodes(self):
return self._java_model.numNodes()
def depth(self):
return self._java_model.depth()
def __str__(self):
return self._java_model.toString()
class DecisionTree(object):
"""
Learning algorithm for a decision tree model
for classification or regression.
EXPERIMENTAL: This is an experimental API.
It will probably be modified for Spark v1.2.
Example usage:
>>> from numpy import array, ndarray
>>> from pyspark.mllib.regression import LabeledPoint
>>> from pyspark.mllib.tree import DecisionTree
>>> from pyspark.mllib.linalg import SparseVector
>>>
>>> data = [
... LabeledPoint(0.0, [0.0]),
... LabeledPoint(1.0, [1.0]),
... LabeledPoint(1.0, [2.0]),
... LabeledPoint(1.0, [3.0])
... ]
>>>
>>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2)
>>> print(model)
DecisionTreeModel classifier
If (feature 0 <= 0.5)
Predict: 0.0
Else (feature 0 > 0.5)
Predict: 1.0
>>> model.predict(array([1.0])) > 0
True
>>> model.predict(array([0.0])) == 0
True
>>> sparse_data = [
... LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
... LabeledPoint(1.0, SparseVector(2, {1: 1.0})),
... LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
... ]
>>>
>>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data))
>>> model.predict(array([0.0, 1.0])) == 1
True
>>> model.predict(array([0.0, 0.0])) == 0
True
>>> model.predict(SparseVector(2, {1: 1.0})) == 1
True
>>> model.predict(SparseVector(2, {1: 0.0})) == 0
True
"""
@staticmethod
def trainClassifier(data, numClasses, categoricalFeaturesInfo={},
impurity="gini", maxDepth=4, maxBins=100):
"""
Train a DecisionTreeModel for classification.
:param data: Training data: RDD of LabeledPoint.
Labels are integers {0,1,...,numClasses}.
:param numClasses: Number of classes for classification.
:param categoricalFeaturesInfo: Map from categorical feature index
to number of categories.
Any feature not in this map
is treated as continuous.
:param impurity: Supported values: "entropy" or "gini"
:param maxDepth: Max depth of tree.
E.g., depth 0 means 1 leaf node.
Depth 1 means 1 internal node + 2 leaf nodes.
:param maxBins: Number of bins used for finding splits at each node.
:return: DecisionTreeModel
"""
return DecisionTree.train(data, "classification", numClasses,
categoricalFeaturesInfo,
impurity, maxDepth, maxBins)
@staticmethod
def trainRegressor(data, categoricalFeaturesInfo={},
impurity="variance", maxDepth=4, maxBins=100):
"""
Train a DecisionTreeModel for regression.
:param data: Training data: RDD of LabeledPoint.
Labels are real numbers.
:param categoricalFeaturesInfo: Map from categorical feature index
to number of categories.
Any feature not in this map
is treated as continuous.
:param impurity: Supported values: "variance"
:param maxDepth: Max depth of tree.
E.g., depth 0 means 1 leaf node.
Depth 1 means 1 internal node + 2 leaf nodes.
:param maxBins: Number of bins used for finding splits at each node.
:return: DecisionTreeModel
"""
return DecisionTree.train(data, "regression", 0,
categoricalFeaturesInfo,
impurity, maxDepth, maxBins)
@staticmethod
def train(data, algo, numClasses, categoricalFeaturesInfo,
impurity, maxDepth, maxBins=100):
"""
Train a DecisionTreeModel for classification or regression.
:param data: Training data: RDD of LabeledPoint.
For classification, labels are integers
{0,1,...,numClasses}.
For regression, labels are real numbers.
:param algo: "classification" or "regression"
:param numClasses: Number of classes for classification.
:param categoricalFeaturesInfo: Map from categorical feature index
to number of categories.
Any feature not in this map
is treated as continuous.
:param impurity: For classification: "entropy" or "gini".
For regression: "variance".
:param maxDepth: Max depth of tree.
E.g., depth 0 means 1 leaf node.
Depth 1 means 1 internal node + 2 leaf nodes.
:param maxBins: Number of bins used for finding splits at each node.
:return: DecisionTreeModel
"""
sc = data.context
dataBytes = _get_unmangled_labeled_point_rdd(data)
categoricalFeaturesInfoJMap = \
MapConverter().convert(categoricalFeaturesInfo,
sc._gateway._gateway_client)
model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
dataBytes._jrdd, algo,
numClasses, categoricalFeaturesInfoJMap,
impurity, maxDepth, maxBins)
dataBytes.unpersist()
return DecisionTreeModel(sc, model)
def _test():
import doctest
globs = globals().copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
exit(-1)
if __name__ == "__main__":
_test()
......@@ -16,6 +16,7 @@
#
import numpy as np
import warnings
from pyspark.mllib.linalg import Vectors, SparseVector
from pyspark.mllib.regression import LabeledPoint
......@@ -29,9 +30,9 @@ class MLUtils:
Helper methods to load, save and pre-process data used in MLlib.
"""
@deprecated
@staticmethod
def _parse_libsvm_line(line, multiclass):
warnings.warn("deprecated", DeprecationWarning)
return _parse_libsvm_line(line)
@staticmethod
......@@ -67,9 +68,9 @@ class MLUtils:
" but got " % type(v))
return " ".join(items)
@deprecated
@staticmethod
def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=None):
warnings.warn("deprecated", DeprecationWarning)
return loadLibSVMFile(sc, path, numFeatures, minPartitions)
@staticmethod
......@@ -106,7 +107,6 @@ class MLUtils:
>>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0")
>>> tempFile.flush()
>>> examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect()
>>> multiclass_examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect()
>>> tempFile.close()
>>> type(examples[0]) == LabeledPoint
True
......@@ -115,20 +115,18 @@ class MLUtils:
>>> type(examples[1]) == LabeledPoint
True
>>> print examples[1]
(0.0,(6,[],[]))
(-1.0,(6,[],[]))
>>> type(examples[2]) == LabeledPoint
True
>>> print examples[2]
(0.0,(6,[1,3,5],[4.0,5.0,6.0]))
>>> multiclass_examples[1].label
-1.0
(-1.0,(6,[1,3,5],[4.0,5.0,6.0]))
"""
lines = sc.textFile(path, minPartitions)
parsed = lines.map(lambda l: MLUtils._parse_libsvm_line(l))
if numFeatures <= 0:
parsed.cache()
numFeatures = parsed.map(lambda x: 0 if x[1].size == 0 else x[1][-1]).reduce(max) + 1
numFeatures = parsed.map(lambda x: -1 if x[1].size == 0 else x[1][-1]).reduce(max) + 1
return parsed.map(lambda x: LabeledPoint(x[0], Vectors.sparse(numFeatures, x[1], x[2])))
@staticmethod
......
......@@ -71,6 +71,7 @@ run_test "pyspark/mllib/random.py"
run_test "pyspark/mllib/recommendation.py"
run_test "pyspark/mllib/regression.py"
run_test "pyspark/mllib/tests.py"
run_test "pyspark/mllib/util.py"
if [[ $FAILED == 0 ]]; then
echo -en "\033[32m" # Green
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment