Skip to content
Snippets Groups Projects
Commit 14bb465b authored by Matei Zaharia's avatar Matei Zaharia
Browse files

Merge pull request #201 from rxin/mappartitions

Use the proper partition index in mapPartitionsWIthIndex

mapPartitionsWithIndex uses TaskContext.partitionId as the partition index. TaskContext.partitionId used to be identical to the partition index in a RDD. However, pull request #186 introduced a scenario (with partition pruning) that the two can be different. This pull request uses the right partition index in all mapPartitionsWithIndex related calls.

Also removed the extra MapPartitionsWIthContextRDD and put all the mapPartitions related functionality in MapPartitionsRDD.
parents eb4296c8 e9ff13ec
No related branches found
No related tags found
No related merge requests found
......@@ -20,18 +20,16 @@ package org.apache.spark.rdd
import org.apache.spark.{Partition, TaskContext}
private[spark]
class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
private[spark] class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: Iterator[T] => Iterator[U],
f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator)
preservesPartitioning: Boolean = false)
extends RDD[U](prev) {
override val partitioner =
if (preservesPartitioning) firstParent[T].partitioner else None
override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None
override def getPartitions: Array[Partition] = firstParent[T].partitions
override def compute(split: Partition, context: TaskContext) =
f(firstParent[T].iterator(split, context))
f(context, split.index, firstParent[T].iterator(split, context))
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.rdd
import org.apache.spark.{Partition, TaskContext}
/**
* A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the
* TaskContext, the closure can either get access to the interruptible flag or get the index
* of the partition in the RDD.
*/
private[spark]
class MapPartitionsWithContextRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: (TaskContext, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean
) extends RDD[U](prev) {
override def getPartitions: Array[Partition] = firstParent[T].partitions
override val partitioner = if (preservesPartitioning) prev.partitioner else None
override def compute(split: Partition, context: TaskContext) =
f(context, firstParent[T].iterator(split, context))
}
......@@ -408,7 +408,6 @@ abstract class RDD[T: ClassManifest](
def pipe(command: String, env: Map[String, String]): RDD[String] =
new PipedRDD(this, command, env)
/**
* Return an RDD created by piping elements to a forked external process.
* The print behavior can be customized by providing two functions.
......@@ -442,7 +441,8 @@ abstract class RDD[T: ClassManifest](
*/
def mapPartitions[U: ClassManifest](
f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter)
new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
}
/**
......@@ -451,8 +451,8 @@ abstract class RDD[T: ClassManifest](
*/
def mapPartitionsWithIndex[U: ClassManifest](
f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter)
new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning)
val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter)
new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
}
/**
......@@ -462,7 +462,8 @@ abstract class RDD[T: ClassManifest](
def mapPartitionsWithContext[U: ClassManifest](
f: (TaskContext, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = {
new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning)
val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(context, iter)
new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
}
/**
......@@ -483,11 +484,10 @@ abstract class RDD[T: ClassManifest](
def mapWith[A: ClassManifest, U: ClassManifest]
(constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => U): RDD[U] = {
def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
val a = constructA(context.partitionId)
mapPartitionsWithIndex((index, iter) => {
val a = constructA(index)
iter.map(t => f(t, a))
}
new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
}, preservesPartitioning)
}
/**
......@@ -498,11 +498,10 @@ abstract class RDD[T: ClassManifest](
def flatMapWith[A: ClassManifest, U: ClassManifest]
(constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => Seq[U]): RDD[U] = {
def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
val a = constructA(context.partitionId)
mapPartitionsWithIndex((index, iter) => {
val a = constructA(index)
iter.flatMap(t => f(t, a))
}
new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
}, preservesPartitioning)
}
/**
......@@ -511,11 +510,10 @@ abstract class RDD[T: ClassManifest](
* partition with the index of that partition.
*/
def foreachWith[A: ClassManifest](constructA: Int => A)(f: (T, A) => Unit) {
def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
val a = constructA(context.partitionId)
mapPartitionsWithIndex { (index, iter) =>
val a = constructA(index)
iter.map(t => {f(t, a); t})
}
new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {})
}.foreach(_ => {})
}
/**
......@@ -524,11 +522,10 @@ abstract class RDD[T: ClassManifest](
* partition with the index of that partition.
*/
def filterWith[A: ClassManifest](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
val a = constructA(context.partitionId)
mapPartitionsWithIndex((index, iter) => {
val a = constructA(index)
iter.filter(t => p(t, a))
}
new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true)
}, preservesPartitioning = true)
}
/**
......
......@@ -62,8 +62,6 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
testCheckpointing(_.sample(false, 0.5, 0))
testCheckpointing(_.glom())
testCheckpointing(_.mapPartitions(_.map(_.toString)))
testCheckpointing(r => new MapPartitionsWithContextRDD(r,
(context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString), false ))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
testCheckpointing(_.pipe(Seq("cat")))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment