diff --git a/docs/ml-features.md b/docs/ml-features.md index e86f9edc4f68bfb8febaf3a5e5e7642d2ad63491..63ea3e5db7ac9746cc44169631a17ee06c5d4708 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -106,6 +106,95 @@ for features_label in featurized.select("features", "label").take(3): </div> </div> +## Word2Vec + +`Word2Vec` is an `Estimator` which takes sequences of words representing documents and trains a `Word2VecModel`. The model is essentially a `Map(String, Vector)`, which maps each word to a unique fixed-size vector. The `Word2VecModel` transforms each document into a vector using the average of the vectors of all words in the document; this vector can then be used in further computations on documents, such as similarity calculation. Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#Word2Vec) for more details on Word2Vec. + +Word2Vec is implemented in [Word2Vec](api/scala/index.html#org.apache.spark.ml.feature.Word2Vec). In the following code segment, we start with a set of documents, each of which is represented as a sequence of words. For each document, we transform it into a feature vector. This feature vector could then be passed to a learning algorithm. + +<div class="codetabs"> +<div data-lang="scala" markdown="1"> +{% highlight scala %} +import org.apache.spark.ml.feature.Word2Vec + +// Input data: Each row is a bag of words from a sentence or document. +val documentDF = sqlContext.createDataFrame(Seq( + "Hi I heard about Spark".split(" "), + "I wish Java could use case classes".split(" "), + "Logistic regression models are neat".split(" ") +).map(Tuple1.apply)).toDF("text") + +// Learn a mapping from words to Vectors. 
+val word2Vec = new Word2Vec() + .setInputCol("text") + .setOutputCol("result") + .setVectorSize(3) + .setMinCount(0) +val model = word2Vec.fit(documentDF) +val result = model.transform(documentDF) +result.select("result").take(3).foreach(println) +{% endhighlight %} +</div> + +<div data-lang="java" markdown="1"> +{% highlight java %} +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; + +JavaSparkContext jsc = ... +SQLContext sqlContext = ... + +// Input data: Each row is a bag of words from a sentence or document. +JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Lists.newArrayList("Hi I heard about Spark".split(" "))), + RowFactory.create(Lists.newArrayList("I wish Java could use case classes".split(" "))), + RowFactory.create(Lists.newArrayList("Logistic regression models are neat".split(" "))) +)); +StructType schema = new StructType(new StructField[]{ + new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) +}); +DataFrame documentDF = sqlContext.createDataFrame(jrdd, schema); + +// Learn a mapping from words to Vectors. +Word2Vec word2Vec = new Word2Vec() + .setInputCol("text") + .setOutputCol("result") + .setVectorSize(3) + .setMinCount(0); +Word2VecModel model = word2Vec.fit(documentDF); +DataFrame result = model.transform(documentDF); +for (Row r: result.select("result").take(3)) { + System.out.println(r); +} +{% endhighlight %} +</div> + +<div data-lang="python" markdown="1"> +{% highlight python %} +from pyspark.ml.feature import Word2Vec + +# Input data: Each row is a bag of words from a sentence or document. 
+documentDF = sqlContext.createDataFrame([ + ("Hi I heard about Spark".split(" "), ), + ("I wish Java could use case classes".split(" "), ), + ("Logistic regression models are neat".split(" "), ) +], ["text"]) +# Learn a mapping from words to Vectors. +word2Vec = Word2Vec(vectorSize=3, minCount=0, inputCol="text", outputCol="result") +model = word2Vec.fit(documentDF) +result = model.transform(documentDF) +for feature in result.select("result").take(3): + print(feature) +{% endhighlight %} +</div> +</div> # Feature Transformers diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java new file mode 100644 index 0000000000000000000000000000000000000000..39c70157f83c0e20a465ad3c6a91ed8f459db702 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; + +public class JavaWord2VecSuite { + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaWord2VecSuite"); + sqlContext = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void testJavaWord2Vec() { + JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Lists.newArrayList("Hi I heard about Spark".split(" "))), + RowFactory.create(Lists.newArrayList("I wish Java could use case classes".split(" "))), + RowFactory.create(Lists.newArrayList("Logistic regression models are neat".split(" "))) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) + }); + DataFrame documentDF = sqlContext.createDataFrame(jrdd, schema); + + Word2Vec word2Vec = new Word2Vec() + .setInputCol("text") + .setOutputCol("result") + .setVectorSize(3) + .setMinCount(0); + Word2VecModel model = word2Vec.fit(documentDF); + DataFrame result = model.transform(documentDF); + + for (Row r: result.select("result").collect()) { + double[] polyFeatures = ((Vector)r.get(0)).toArray(); + Assert.assertEquals(polyFeatures.length, 3); + } + } +}