From 0a984aa155fb7f532fe87620dcf1a2814c5b8b49 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng <meng@databricks.com>
Date: Tue, 19 Aug 2014 22:16:22 -0700
Subject: [PATCH] [SPARK-3142][MLLIB] output shuffle data directly in Word2Vec

Sorry I didn't realize this in #2043. Ishiihara

Author: Xiangrui Meng <meng@databricks.com>

Closes #2049 from mengxr/more-w2v and squashes the following commits:

050b1c5 [Xiangrui Meng] output shuffle data directly
---
 .../apache/spark/mllib/feature/Word2Vec.scala | 23 ++++++++++---------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index c3375ed44f..fc14447053 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -347,19 +347,20 @@ class Word2Vec extends Serializable with Logging {
         }
         val syn0Local = model._1
         val syn1Local = model._2
-        val synOut = mutable.ListBuffer.empty[(Int, Array[Float])]
-        var index = 0
-        while(index < vocabSize) {
-          if (syn0Modify(index) != 0) {
-            synOut += ((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
+        // Only output modified vectors.
+        Iterator.tabulate(vocabSize) { index =>
+          if (syn0Modify(index) > 0) {
+            Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
+          } else {
+            None
           }
-          if (syn1Modify(index) != 0) {
-            synOut += ((index + vocabSize,
-              syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
+        }.flatten ++ Iterator.tabulate(vocabSize) { index =>
+          if (syn1Modify(index) > 0) {
+            Some((index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
+          } else {
+            None
           }
-          index += 1
-        }
-        synOut.toIterator
+        }.flatten
       }
       val synAgg = partial.reduceByKey { case (v1, v2) =>
           blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
-- 
GitLab