diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index aacfb34c77618eefa9fc657342c9df19edbec4b2..cd32e26c64f222772c69bb7eebb1eecee35c8234 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -145,6 +145,12 @@ class PythonOnlyPoint(ExamplePoint): __UDT__ = PythonOnlyUDT() +class MyObject(object): + def __init__(self, key, value): + self.key = key + self.value = value + + class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 def test_data_type_eq(self): @@ -383,6 +389,12 @@ class SQLTests(ReusedPySparkTestCase): df = self.sqlCtx.inferSchema(rdd) self.assertEquals(Row(field1=1, field2=u'row1'), df.first()) + def test_create_dataframe_from_objects(self): + data = [MyObject(1, "1"), MyObject(2, "2")] + df = self.sqlCtx.createDataFrame(data) + self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")]) + self.assertEqual(df.first(), Row(key=1, value="1")) + def test_select_null_literal(self): df = self.sqlCtx.sql("select null as col") self.assertEquals(Row(col=None), df.first()) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index ed4e5b594bd617bd9cc3fa7b616fd8e7b067eb31..94e581a78364cdf32deae0ba5fb178ebc4a1e328 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -537,6 +537,9 @@ class StructType(DataType): return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields)) elif isinstance(obj, (tuple, list)): return tuple(f.toInternal(v) for f, v in zip(self.fields, obj)) + elif hasattr(obj, "__dict__"): + d = obj.__dict__ + return tuple(f.toInternal(d.get(n)) for n, f in zip(self.names, self.fields)) else: raise ValueError("Unexpected tuple %r with StructType" % obj) else: @@ -544,6 +547,9 @@ class StructType(DataType): return tuple(obj.get(n) for n in self.names) elif isinstance(obj, (list, tuple)): return tuple(obj) + elif hasattr(obj, "__dict__"): + d = obj.__dict__ + return tuple(d.get(n) for n in self.names) else: raise ValueError("Unexpected tuple %r with StructType" % obj)