From d41d6c48207159490c1e1d9cc54015725cfa41b2 Mon Sep 17 00:00:00 2001
From: Davies Liu <davies@databricks.com>
Date: Wed, 26 Aug 2015 16:04:44 -0700
Subject: [PATCH] [SPARK-10305] [SQL] Fix creating a DataFrame from Python class instances

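Allow createDataFrame() to accept instances of plain Python classes.
When converting a row, StructType now falls back to reading the field
values from the object's __dict__ if the row is not a tuple or list
(or one of the other types already handled). Attributes missing from
__dict__ are converted as None. Objects without a __dict__ still raise
ValueError("Unexpected tuple ... with StructType").

A regression test (test_create_dataframe_from_objects) is added.

cc jkbradley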

Author: Davies Liu <davies@databricks.com>

Closes #8470 from davies/fix_create_df.
---
 python/pyspark/sql/tests.py | 12 ++++++++++++
 python/pyspark/sql/types.py |  6 ++++++
 2 files changed, 18 insertions(+)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index aacfb34c77..cd32e26c64 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -145,6 +145,12 @@ class PythonOnlyPoint(ExamplePoint):
     __UDT__ = PythonOnlyUDT()
 
 
+class MyObject(object):
+    def __init__(self, key, value):
+        self.key = key
+        self.value = value
+
+
 class DataTypeTests(unittest.TestCase):
     # regression test for SPARK-6055
     def test_data_type_eq(self):
@@ -383,6 +389,12 @@ class SQLTests(ReusedPySparkTestCase):
         df = self.sqlCtx.inferSchema(rdd)
         self.assertEquals(Row(field1=1, field2=u'row1'), df.first())
 
+    def test_create_dataframe_from_objects(self):
+        data = [MyObject(1, "1"), MyObject(2, "2")]
+        df = self.sqlCtx.createDataFrame(data)
+        self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")])
+        self.assertEqual(df.first(), Row(key=1, value="1"))
+
     def test_select_null_literal(self):
         df = self.sqlCtx.sql("select null as col")
         self.assertEquals(Row(col=None), df.first())
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index ed4e5b594b..94e581a783 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -537,6 +537,9 @@ class StructType(DataType):
                 return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields))
             elif isinstance(obj, (tuple, list)):
                 return tuple(f.toInternal(v) for f, v in zip(self.fields, obj))
+            elif hasattr(obj, "__dict__"):
+                d = obj.__dict__
+                return tuple(f.toInternal(d.get(n)) for n, f in zip(self.names, self.fields))
             else:
                 raise ValueError("Unexpected tuple %r with StructType" % obj)
         else:
@@ -544,6 +547,9 @@ class StructType(DataType):
                 return tuple(obj.get(n) for n in self.names)
             elif isinstance(obj, (list, tuple)):
                 return tuple(obj)
+            elif hasattr(obj, "__dict__"):
+                d = obj.__dict__
+                return tuple(d.get(n) for n in self.names)
             else:
                 raise ValueError("Unexpected tuple %r with StructType" % obj)
 
-- 
GitLab