diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 54756edd9345dd13b2990446950e98b291eaef4a..cfd9c558ff67eddd59f5edd684d02eddbb63d620 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1241,26 +1241,29 @@ class SQLTests(ReusedPySparkTestCase):
         struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
         struct2 = StructType([StructField("f1", StringType(), True),
                               StructField("f2", StringType(), True, None)])
+        self.assertEqual(struct1.fieldNames(), struct2.names)
         self.assertEqual(struct1, struct2)
 
         struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
         struct2 = StructType([StructField("f1", StringType(), True)])
+        self.assertNotEqual(struct1.fieldNames(), struct2.names)
         self.assertNotEqual(struct1, struct2)
 
         struct1 = (StructType().add(StructField("f1", StringType(), True))
                    .add(StructField("f2", StringType(), True, None)))
         struct2 = StructType([StructField("f1", StringType(), True),
                               StructField("f2", StringType(), True, None)])
+        self.assertEqual(struct1.fieldNames(), struct2.names)
         self.assertEqual(struct1, struct2)
 
         struct1 = (StructType().add(StructField("f1", StringType(), True))
                    .add(StructField("f2", StringType(), True, None)))
         struct2 = StructType([StructField("f1", StringType(), True)])
+        self.assertNotEqual(struct1.fieldNames(), struct2.names)
         self.assertNotEqual(struct1, struct2)
 
         # Catch exception raised during improper construction
-        with self.assertRaises(ValueError):
-            struct1 = StructType().add("name")
+        self.assertRaises(ValueError, lambda: StructType().add("name"))
 
         struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
         for field in struct1:
@@ -1273,12 +1276,9 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertIs(struct1["f1"], struct1.fields[0])
         self.assertIs(struct1[0], struct1.fields[0])
         self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1]))
-        with self.assertRaises(KeyError):
-            not_a_field = struct1["f9"]
-        with self.assertRaises(IndexError):
-            not_a_field = struct1[9]
-        with self.assertRaises(TypeError):
-            not_a_field = struct1[9.9]
+        self.assertRaises(KeyError, lambda: struct1["f9"])
+        self.assertRaises(IndexError, lambda: struct1[9])
+        self.assertRaises(TypeError, lambda: struct1[9.9])
 
     def test_parse_datatype_string(self):
         from pyspark.sql.types import _all_atomic_types, _parse_datatype_string
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index c376805c32738aeab738f6f92cf61caa9c9665d8..ecb8eb9a2f2facf27301d2f67c8d783abaf8c2a7 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -446,9 +446,12 @@ class StructType(DataType):
 
     This is the data type representing a :class:`Row`.
 
-    Iterating a :class:`StructType` will iterate its :class:`StructField`s.
+    Iterating a :class:`StructType` will iterate its :class:`StructField`\\s.
     A contained :class:`StructField` can be accessed by name or position.
 
+    .. note:: `names` attribute is deprecated in 2.3. Use `fieldNames` method instead
+        to get a list of field names.
+
     >>> struct1 = StructType([StructField("f1", StringType(), True)])
     >>> struct1["f1"]
     StructField(f1,StringType,true)
@@ -563,6 +566,16 @@ class StructType(DataType):
     def fromJson(cls, json):
         return StructType([StructField.fromJson(f) for f in json["fields"]])
 
+    def fieldNames(self):
+        """
+        Returns all field names in a list.
+
+        >>> struct = StructType([StructField("f1", StringType(), True)])
+        >>> struct.fieldNames()
+        ['f1']
+        """
+        return list(self.names)
+
     def needConversion(self):
         # We need convert Row()/namedtuple into tuple()
         return True