From 5553198fe521fb38b600b7687f7780d89a6e1cb9 Mon Sep 17 00:00:00 2001
From: Burak Yavuz <brkyvz@gmail.com>
Date: Wed, 29 Apr 2015 19:13:47 -0700
Subject: [PATCH] [SPARK-7156][SQL] Addressed follow up comments for
 randomSplit

small fixes regarding comments in PR #5761

cc rxin

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #5795 from brkyvz/split-followup and squashes the following commits:

369c522 [Burak Yavuz] changed wording a little
1ea456f [Burak Yavuz] Addressed follow up comments
---
 python/pyspark/sql/dataframe.py                            | 7 ++++++-
 .../src/main/scala/org/apache/spark/sql/DataFrame.scala    | 2 +-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 3074af3ed2..5908ebc990 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -437,6 +437,10 @@ class DataFrame(object):
     def randomSplit(self, weights, seed=None):
         """Randomly splits this :class:`DataFrame` with the provided weights.
 
+        :param weights: list of doubles as weights with which to split the DataFrame. Weights will
+            be normalized if they don't sum up to 1.0.
+        :param seed: The seed for sampling.
+
         >>> splits = df4.randomSplit([1.0, 2.0], 24)
         >>> splits[0].count()
         1
@@ -445,7 +449,8 @@ class DataFrame(object):
         3
         """
         for w in weights:
-            assert w >= 0.0, "Negative weight value: %s" % w
+            if w < 0.0:
+                raise ValueError("Weights must be positive. Found weight value: %s" % w)
         seed = seed if seed is not None else random.randint(0, sys.maxsize)
         rdd_array = self._jdf.randomSplit(_to_seq(self.sql_ctx._sc, weights), long(seed))
         return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 0d02e14c21..2669300029 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -752,7 +752,7 @@ class DataFrame private[sql](
    * @param seed Seed for sampling.
    * @group dfops
    */
-  def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = {
+  private[spark] def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = {
     randomSplit(weights.toArray, seed)
   }
 
-- 
GitLab