Skip to content
Snippets Groups Projects
Commit 630a994e authored by Xusen Yin's avatar Xusen Yin Committed by Xiangrui Meng
Browse files

[SPARK-9893] User guide with Java test suite for VectorSlicer

Add user guide for `VectorSlicer`, with Java test suite and Python version VectorSlicer.

Note that Python version does not support selecting by names now.

Author: Xusen Yin <yinxusen@gmail.com>

Closes #8267 from yinxusen/SPARK-9893.
parent f01c4220
No related branches found
No related tags found
No related merge requests found
......@@ -1477,6 +1477,139 @@ print(output.select("features", "clicked").first())
</div>
</div>
# Feature Selectors
## VectorSlicer
`VectorSlicer` is a transformer that takes a feature vector and outputs a new feature vector with a
sub-array of the original features. It is useful for extracting features from a vector column.
`VectorSlicer` accepts a vector column with a specified indices, then outputs a new vector column
whose values are selected via those indices. There are two types of indices,
1. Integer indices that represents the indices into the vector, `setIndices()`;
2. String indices that represents the names of features into the vector, `setNames()`.
*This requires the vector column to have an `AttributeGroup` since the implementation matches on
the name field of an `Attribute`.*
Specification by integer and string are both acceptable. Moreover, you can use integer index and
string name simultaneously. At least one feature must be selected. Duplicate features are not
allowed, so there can be no overlap between selected indices and names. Note that if names of
features are selected, an exception will be threw out when encountering with empty input attributes.
The output vector will order features with the selected indices first (in the order given),
followed by the selected names (in the order given).
**Examples**
Suppose that we have a DataFrame with the column `userFeatures`:
~~~
userFeatures
------------------
[0.0, 10.0, 0.5]
~~~
`userFeatures` is a vector column that contains three user features. Assuming that the first column
of `userFeatures` are all zeros, so we want to remove it and only the last two columns are selected.
The `VectorSlicer` selects the last two elements with `setIndices(1, 2)` then produces a new vector
column named `features`:
~~~
userFeatures | features
------------------|-----------------------------
[0.0, 10.0, 0.5] | [10.0, 0.5]
~~~
Suppose also that we have a potential input attributes for the `userFeatures`, i.e.
`["f1", "f2", "f3"]`, then we can use `setNames("f2", "f3")` to select them.
~~~
userFeatures | features
------------------|-----------------------------
[0.0, 10.0, 0.5] | [10.0, 0.5]
["f1", "f2", "f3"] | ["f2", "f3"]
~~~
<div class="codetabs">
<div data-lang="scala" markdown="1">
[`VectorSlicer`](api/scala/index.html#org.apache.spark.ml.feature.VectorSlicer) takes an input
column name with specified indices or names and an output column name.
{% highlight scala %}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
import org.apache.spark.ml.feature.VectorSlicer
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
val data = Array(
Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
Vectors.dense(-2.0, 2.3, 0.0)
)
val defaultAttr = NumericAttribute.defaultAttr
val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName)
val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]])
val dataRDD = sc.parallelize(data).map(Row.apply)
val dataset = sqlContext.createDataFrame(dataRDD, StructType(attrGroup.toStructField()))
val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features")
slicer.setIndices(1).setNames("f3")
// or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3"))
val output = slicer.transform(dataset)
println(output.select("userFeatures", "features").first())
{% endhighlight %}
</div>
<div data-lang="java" markdown="1">
[`VectorSlicer`](api/java/org/apache/spark/ml/feature/VectorSlicer.html) takes an input column name
with specified indices or names and an output column name.
{% highlight java %}
import java.util.Arrays;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.*;
import static org.apache.spark.sql.types.DataTypes.*;
Attribute[] attrs = new Attribute[]{
NumericAttribute.defaultAttr().withName("f1"),
NumericAttribute.defaultAttr().withName("f2"),
NumericAttribute.defaultAttr().withName("f3")
};
AttributeGroup group = new AttributeGroup("userFeatures", attrs);
JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList(
RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})),
RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0))
));
DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField()));
VectorSlicer vectorSlicer = new VectorSlicer()
.setInputCol("userFeatures").setOutputCol("features");
vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"});
// or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"})
DataFrame output = vectorSlicer.transform(dataset);
System.out.println(output.select("userFeatures", "features").first());
{% endhighlight %}
</div>
</div>
## RFormula
`RFormula` selects columns specified by an [R model formula](https://stat.ethz.ch/R-manual/R-devel/library/stats/html/formula.html). It produces a vector column of features and a double column of labels. Like when formulas are used in R for linear regression, string input columns will be one-hot encoded, and numeric columns will be cast to doubles. If not already present in the DataFrame, the output label column will be created from the specified response variable in the formula.
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.feature;
import com.google.common.collect.Lists;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.NumericAttribute;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.StructType;
public class JavaVectorSlicerSuite {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaVectorSlicerSuite");
jsql = new SQLContext(jsc);
}
@After
public void tearDown() {
jsc.stop();
jsc = null;
}
@Test
public void vectorSlice() {
Attribute[] attrs = new Attribute[]{
NumericAttribute.defaultAttr().withName("f1"),
NumericAttribute.defaultAttr().withName("f2"),
NumericAttribute.defaultAttr().withName("f3")
};
AttributeGroup group = new AttributeGroup("userFeatures", attrs);
JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList(
RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})),
RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0))
));
DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField()));
VectorSlicer vectorSlicer = new VectorSlicer()
.setInputCol("userFeatures").setOutputCol("features");
vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"});
DataFrame output = vectorSlicer.transform(dataset);
for (Row r : output.select("userFeatures", "features").take(2)) {
Vector features = r.getAs(1);
Assert.assertEquals(features.size(), 2);
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment