    [SPARK-7427] [PYSPARK] Make sharedParams match in Scala, Python · c5aca0c2
    Glenn Weidner authored
    Modified 2 files:
    python/pyspark/ml/param/_shared_params_code_gen.py
    python/pyspark/ml/param/shared.py
    
    Generated shared.py using Python 2.6.6 on Red Hat Enterprise Linux Server 6.6:
    python _shared_params_code_gen.py > shared.py
    
    Only maxIter, regParam, and rawPredictionCol were changed, based on the doc strings in SharedParamsCodeGen.scala.  Note that the following warning was displayed when committing shared.py:
    warning: LF will be replaced by CRLF in python/pyspark/ml/param/shared.py.
    
    Author: Glenn Weidner <gweidner@us.ibm.com>
    
    Closes #6023 from gweidner/br-7427 and squashes the following commits:
    
    db72e32 [Glenn Weidner] [SPARK-7427] [PySpark] Make sharedParams match in Scala, Python
    825e4a9 [Glenn Weidner] [SPARK-7427] [PySpark] Make sharedParams match in Scala, Python
    e6a865e [Glenn Weidner] [SPARK-7427] [PySpark] Make sharedParams match in Scala, Python
    1eee702 [Glenn Weidner] Merge remote-tracking branch 'upstream/master'
    1ac10e5 [Glenn Weidner] Merge remote-tracking branch 'upstream/master'
    cafd104 [Glenn Weidner] Merge remote-tracking branch 'upstream/master'
    9bea1eb [Glenn Weidner] Merge remote-tracking branch 'upstream/master'
    4a35c20 [Glenn Weidner] Merge remote-tracking branch 'upstream/master'
    9790cbe [Glenn Weidner] Merge remote-tracking branch 'upstream/master'
    d9c30f4 [Glenn Weidner] [SPARK-7275] [SQL] [WIP] Make LogicalRelation public
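For reference, each mixin that the code generator emits into shared.py follows roughly the pattern below. This is a minimal sketch reconstructed from the behavior exercised by the tests that follow (the param name, the doc string "max number of iterations (>= 0)", and the get/set/default accessors); it is not the verbatim generated file, and the internal _paramMap attribute is an assumption about the Params base class of this era.

from pyspark.ml.param import Param, Params


class HasMaxIter(Params):
    """
    Mixin for param maxIter: max number of iterations (>= 0).
    (Sketch of the code-gen output pattern, not the generated file itself.)
    """

    # class-level placeholder so the param appears in generated docs
    maxIter = Param(Params._dummy(), "maxIter", "max number of iterations (>= 0)")

    def __init__(self):
        super(HasMaxIter, self).__init__()
        #: param for max number of iterations (>= 0)
        self.maxIter = Param(self, "maxIter", "max number of iterations (>= 0)")

    def setMaxIter(self, value):
        """Sets the value of maxIter."""
        self._paramMap[self.maxIter] = value  # assumed param-map attribute
        return self

    def getMaxIter(self):
        """Gets the value of maxIter or its default value."""
        return self.getOrDefault(self.maxIter)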
tests.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Unit tests for Spark ML Python APIs.
"""

import sys

if sys.version_info[:2] <= (2, 6):
    try:
        import unittest2 as unittest
    except ImportError:
        sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier\n')
        sys.exit(1)
else:
    import unittest

from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
from pyspark.sql import DataFrame
from pyspark.ml.param import Param
from pyspark.ml.param.shared import HasMaxIter, HasInputCol
from pyspark.ml.pipeline import Estimator, Model, Pipeline, Transformer


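# The mocks below track how the pipeline touches the data: MockDataset carries
# a simple index counter, each mock transform() records the index it saw and
# bumps it by one, and fit()/transform() also record any value supplied for
# their "fake" param.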
class MockDataset(DataFrame):

    def __init__(self):
        self.index = 0


class MockTransformer(Transformer):

    def __init__(self):
        super(MockTransformer, self).__init__()
        self.fake = Param(self, "fake", "fake")
        self.dataset_index = None
        self.fake_param_value = None

    def transform(self, dataset, params=None):
        # Avoid a shared mutable default argument.
        params = params if params is not None else {}
        self.dataset_index = dataset.index
        if self.fake in params:
            self.fake_param_value = params[self.fake]
        dataset.index += 1
        return dataset


class MockEstimator(Estimator):

    def __init__(self):
        super(MockEstimator, self).__init__()
        self.fake = Param(self, "fake", "fake")
        self.dataset_index = None
        self.fake_param_value = None
        self.model = None

    def fit(self, dataset, params=None):
        # Avoid a shared mutable default argument.
        params = params if params is not None else {}
        self.dataset_index = dataset.index
        if self.fake in params:
            self.fake_param_value = params[self.fake]
        model = MockModel()
        self.model = model
        return model


class MockModel(MockTransformer, Model):

    def __init__(self):
        super(MockModel, self).__init__()


class PipelineTests(PySparkTestCase):

    def test_pipeline(self):
        dataset = MockDataset()
        estimator0 = MockEstimator()
        transformer1 = MockTransformer()
        estimator2 = MockEstimator()
        transformer3 = MockTransformer()
        pipeline = Pipeline() \
            .setStages([estimator0, transformer1, estimator2, transformer3])
        pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1})
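        # During fit, the dataset flows through every stage up to the last
        # estimator: estimator0 fits at index 0, its model transforms 0 -> 1,
        # transformer1 transforms 1 -> 2, and estimator2 fits at index 2.
        # The model produced by the last estimator is not invoked during fit.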
        self.assertEqual(0, estimator0.dataset_index)
        self.assertEqual(0, estimator0.fake_param_value)
        model0 = estimator0.model
        self.assertEqual(0, model0.dataset_index)
        self.assertEqual(1, transformer1.dataset_index)
        self.assertEqual(1, transformer1.fake_param_value)
        self.assertEqual(2, estimator2.dataset_index)
        model2 = estimator2.model
        self.assertIsNone(model2.dataset_index, "The model produced by the last estimator should "
                                                "not be called during fit.")
        dataset = pipeline_model.transform(dataset)
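        # transform() runs the dataset through all four fitted stages
        # (model0, transformer1, model2, transformer3), so the stages see
        # successive indices 2..5 and the final index is 6.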
        self.assertEqual(2, model0.dataset_index)
        self.assertEqual(3, transformer1.dataset_index)
        self.assertEqual(4, model2.dataset_index)
        self.assertEqual(5, transformer3.dataset_index)
        self.assertEqual(6, dataset.index)


class TestParams(HasMaxIter, HasInputCol):
    """
    A subclass of Params mixed with HasMaxIter and HasInputCol.
    """

    def __init__(self):
        super(TestParams, self).__init__()
        self._setDefault(maxIter=10)


class ParamTests(PySparkTestCase):

    def test_param(self):
        testParams = TestParams()
        maxIter = testParams.maxIter
        self.assertEqual(maxIter.name, "maxIter")
        self.assertEqual(maxIter.doc, "max number of iterations (>= 0)")
        self.assertTrue(maxIter.parent is testParams)

    def test_params(self):
        testParams = TestParams()
        maxIter = testParams.maxIter
        inputCol = testParams.inputCol

        params = testParams.params
        self.assertEqual(params, [inputCol, maxIter])

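        # maxIter got a default in TestParams.__init__, so it is "defined"
        # from the start but not "set" until setMaxIter is called explicitly.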
        self.assertTrue(testParams.hasDefault(maxIter))
        self.assertFalse(testParams.isSet(maxIter))
        self.assertTrue(testParams.isDefined(maxIter))
        self.assertEqual(testParams.getMaxIter(), 10)
        testParams.setMaxIter(100)
        self.assertTrue(testParams.isSet(maxIter))
        self.assertEqual(testParams.getMaxIter(), 100)

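        # inputCol has no default and is never set, so it is undefined and
        # reading it raises KeyError.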
        self.assertFalse(testParams.hasDefault(inputCol))
        self.assertFalse(testParams.isSet(inputCol))
        self.assertFalse(testParams.isDefined(inputCol))
        with self.assertRaises(KeyError):
            testParams.getInputCol()

        self.assertEqual(
            testParams.explainParams(),
            "\n".join(["inputCol: input column name (undefined)",
                       "maxIter: max number of iterations (>= 0) (default: 10, current: 100)"]))


if __name__ == "__main__":
    unittest.main()