diff --git a/examples/src/main/java/spark/examples/JavaKMeans.java b/examples/src/main/java/spark/examples/JavaKMeans.java
new file mode 100644
index 0000000000000000000000000000000000000000..c76930b8c40ae4b16da1c0b2df1a7e3d2c838046
--- /dev/null
+++ b/examples/src/main/java/spark/examples/JavaKMeans.java
@@ -0,0 +1,120 @@
+package spark.examples;
+
+import scala.Tuple2;
+import spark.api.java.JavaPairRDD;
+import spark.api.java.JavaRDD;
+import spark.api.java.JavaSparkContext;
+import spark.api.java.function.Function;
+import spark.api.java.function.PairFunction;
+import spark.util.Vector;
+
+import java.util.List;
+import java.util.Map;
+
+/**
+ * K-means clustering example for the Java API. Reads whitespace-separated
+ * numeric vectors from a text file and iterates until the centroids move
+ * less than the given convergence distance.
+ */
+public class JavaKMeans {
+
+    /** Parses numbers split by whitespace to a vector */
+    static Vector parseVector(String line) {
+        String[] splits = line.split(" ");
+        double[] data = new double[splits.length];
+        for (int i = 0; i < splits.length; i++) {
+            data[i] = Double.parseDouble(splits[i]);
+        }
+        return new Vector(data);
+    }
+
+    /** Computes the index of the centroid closest to p, using squared distance */
+    static int closestPoint(Vector p, List<Vector> centers) {
+        int bestIndex = 0;
+        double closest = Double.POSITIVE_INFINITY;
+        for (int i = 0; i < centers.size(); i++) {
+            double tempDist = p.squaredDist(centers.get(i));
+            if (tempDist < closest) {
+                closest = tempDist;
+                bestIndex = i;
+            }
+        }
+        return bestIndex;
+    }
+
+    /** Computes the mean across all vectors in the input set of vectors */
+    static Vector average(List<Vector> ps) {
+        int numVectors = ps.size();
+        // Copy the first vector so we do not mutate an input element in place.
+        Vector out = new Vector(ps.get(0).elements());
+        // Start from i = 1 since index 0 was already copied above.
+        for (int i = 1; i < numVectors; i++) {
+            out.addInPlace(ps.get(i));
+        }
+        return out.divide(numVectors);
+    }
+
+    public static void main(String[] args) throws Exception {
+        if (args.length < 4) {
+            // Keep the usage string in sync with this class's name.
+            System.err.println("Usage: JavaKMeans <master> <file> <k> <convergeDist>");
+            System.exit(1);
+        }
+        JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans",
+                System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR"));
+        String path = args[1];
+        int K = Integer.parseInt(args[2]);
+        double convergeDist = Double.parseDouble(args[3]);
+
+        JavaRDD<Vector> data = sc.textFile(path).map(
+                new Function<String, Vector>() {
+                    @Override
+                    public Vector call(String line) throws Exception {
+                        return parseVector(line);
+                    }
+                }
+        ).cache();
+
+        // Initial centroids: a fixed-seed sample so runs are reproducible.
+        final List<Vector> centroids = data.takeSample(false, K, 42);
+
+        double tempDist;
+        do {
+            // allocate each vector to closest centroid
+            JavaPairRDD<Integer, Vector> closest = data.map(
+                    new PairFunction<Vector, Integer, Vector>() {
+                        @Override
+                        public Tuple2<Integer, Vector> call(Vector vector) throws Exception {
+                            return new Tuple2<Integer, Vector>(
+                                    closestPoint(vector, centroids), vector);
+                        }
+                    }
+            );
+
+            // group by cluster id and average the vectors within each cluster to compute centroids
+            JavaPairRDD<Integer, List<Vector>> pointsGroup = closest.groupByKey();
+            Map<Integer, Vector> newCentroids = pointsGroup.mapValues(
+                    new Function<List<Vector>, Vector>() {
+                        public Vector call(List<Vector> ps) throws Exception {
+                            return average(ps);
+                        }
+                    }).collectAsMap();
+            // Iterate over the clusters that actually received points: a centroid
+            // with no assigned points has no entry in newCentroids, and indexing
+            // 0..K-1 directly would throw a NullPointerException for it.
+            tempDist = 0.0;
+            for (Map.Entry<Integer, Vector> t : newCentroids.entrySet()) {
+                tempDist += centroids.get(t.getKey()).squaredDist(t.getValue());
+                centroids.set(t.getKey(), t.getValue());
+            }
+            System.out.println("Finished iteration (delta = " + tempDist + ")");
+        } while (tempDist > convergeDist);
+
+        System.out.println("Final centers:");
+        for (Vector c : centroids) {
+            System.out.println(c);
+        }
+
+        System.exit(0);
+    }
+}
diff --git a/examples/src/main/scala/spark/examples/SparkKMeans.scala b/examples/src/main/scala/spark/examples/SparkKMeans.scala
index 7c21ea12fb72430089d0a4166c8b74fac7677277..4161c59fead2046851428f799f1ecbc07b1eedf8 100644
--- a/examples/src/main/scala/spark/examples/SparkKMeans.scala
+++ b/examples/src/main/scala/spark/examples/SparkKMeans.scala
@@ -64,6 +64,8 @@ object SparkKMeans {
       for (newP <- newPoints) {
         kPoints(newP._1) = newP._2
       }
+      // Log per-iteration convergence progress, mirroring the Java example.
+      println("Finished iteration (delta = " + tempDist + ")")
     }
 
     println("Final centers:")