diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 2941984e19eb69b1a19f6f5a5cfcc92adbb9437c..eb79135b9d0c542f3e77c2029b73c9930c9958d7 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -62,6 +62,6 @@ def launch_gateway():
     # Import the classes used by PySpark
     java_import(gateway.jvm, "org.apache.spark.api.java.*")
     java_import(gateway.jvm, "org.apache.spark.api.python.*")
-    java_import(gateway.jvm, "org.apache.spark.mllib.api.*")
+    java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
     java_import(gateway.jvm, "scala.Tuple2")
     return gateway