Commit 74bbfa91 authored by Matei Zaharia

Added support for generic Hadoop InputFormats and refactored textFile to use this. Closes #12.
parent 03238cb7
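The diff below adds two new SparkContext methods, hadoopFile() and sequenceFile(), and reimplements textFile() on top of a new generic HadoopFile RDD of key-value pairs. As a rough usage sketch of the new API, assuming an already-constructed SparkContext named sc and a hypothetical HDFS path:

  import org.apache.hadoop.io.{LongWritable, Text}
  import org.apache.hadoop.mapred.TextInputFormat

  // Passing the InputFormat, key and value classes explicitly:
  val pairs: RDD[(LongWritable, Text)] =
    sc.hadoopFile("hdfs://namenode/logs/events.log", classOf[TextInputFormat],
                  classOf[LongWritable], classOf[Text])

  // Equivalent call that lets ClassManifests supply the three classes:
  val samePairs =
    sc.hadoopFile[LongWritable, Text, TextInputFormat]("hdfs://namenode/logs/events.log")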
@@ -5,12 +5,15 @@ import mesos.SlaveOffer
 import org.apache.hadoop.io.LongWritable
 import org.apache.hadoop.io.Text
 import org.apache.hadoop.mapred.FileInputFormat
+import org.apache.hadoop.mapred.InputFormat
 import org.apache.hadoop.mapred.InputSplit
 import org.apache.hadoop.mapred.JobConf
 import org.apache.hadoop.mapred.TextInputFormat
 import org.apache.hadoop.mapred.RecordReader
 import org.apache.hadoop.mapred.Reporter
+import org.apache.hadoop.util.ReflectionUtils

+/** A Spark split class that wraps around a Hadoop InputSplit */
 @serializable class HadoopSplit(@transient s: InputSplit)
 extends Split {
   val inputSplit = new SerializableWritable[InputSplit](s)
@@ -19,39 +22,54 @@ extends Split {
   override def getId() = "HadoopSplit(" + inputSplit.toString + ")"
 }

-class HadoopTextFile(sc: SparkContext, path: String)
-extends RDD[String](sc) {
-  @transient val conf = new JobConf()
-  @transient val inputFormat = new TextInputFormat()
-  FileInputFormat.setInputPaths(conf, path)
-  ConfigureLock.synchronized { inputFormat.configure(conf) }
+/**
+ * An RDD that reads a Hadoop file (from HDFS, S3, the local filesystem, etc)
+ * and represents it as a set of key-value pairs using a given InputFormat.
+ */
+class HadoopFile[K, V](
+  sc: SparkContext,
+  path: String,
+  inputFormatClass: Class[_ <: InputFormat[K, V]],
+  keyClass: Class[K],
+  valueClass: Class[V])
+extends RDD[(K, V)](sc) {
+  @transient val splits_ : Array[Split] = ConfigureLock.synchronized {
+    val conf = new JobConf()
+    FileInputFormat.setInputPaths(conf, path)
+    val inputFormat = createInputFormat(conf)
+    val inputSplits = inputFormat.getSplits(conf, sc.scheduler.numCores)
+    inputSplits.map(x => new HadoopSplit(x): Split).toArray
+  }

-  @transient val splits_ =
-    inputFormat.getSplits(conf, sc.scheduler.numCores).map(new HadoopSplit(_)).toArray
+  def createInputFormat(conf: JobConf): InputFormat[K, V] = {
+    ReflectionUtils.newInstance(inputFormatClass.asInstanceOf[Class[_]], conf)
+      .asInstanceOf[InputFormat[K, V]]
+  }

-  override def splits = splits_.asInstanceOf[Array[Split]]
+  override def splits = splits_

-  override def iterator(split_in: Split) = new Iterator[String] {
-    val split = split_in.asInstanceOf[HadoopSplit]
-    var reader: RecordReader[LongWritable, Text] = null
+  override def iterator(theSplit: Split) = new Iterator[(K, V)] {
+    val split = theSplit.asInstanceOf[HadoopSplit]
+    var reader: RecordReader[K, V] = null
     ConfigureLock.synchronized {
       val conf = new JobConf()
-      conf.set("io.file.buffer.size",
-        System.getProperty("spark.buffer.size", "65536"))
-      val tif = new TextInputFormat()
-      tif.configure(conf)
-      reader = tif.getRecordReader(split.inputSplit.value, conf, Reporter.NULL)
+      val bufferSize = System.getProperty("spark.buffer.size", "65536")
+      conf.set("io.file.buffer.size", bufferSize)
+      val fmt = createInputFormat(conf)
+      reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL)
     }

-    val lineNum = new LongWritable()
-    val text = new Text()
+    val key: K = keyClass.newInstance()
+    val value: V = valueClass.newInstance()
     var gotNext = false
     var finished = false

     override def hasNext: Boolean = {
       if (!gotNext) {
         try {
-          finished = !reader.next(lineNum, text)
+          finished = !reader.next(key, value)
         } catch {
           case eofe: java.io.EOFException =>
             finished = true
@@ -61,13 +79,15 @@ extends RDD[String](sc) {
       !finished
     }

-    override def next: String = {
-      if (!gotNext)
-        finished = !reader.next(lineNum, text)
-      if (finished)
-        throw new java.util.NoSuchElementException("end of stream")
+    override def next: (K, V) = {
+      if (!gotNext) {
+        finished = !reader.next(key, value)
+      }
+      if (finished) {
+        throw new java.util.NoSuchElementException("End of stream")
+      }
       gotNext = false
-      text.toString
+      (key, value)
     }
   }
@@ -78,4 +98,21 @@ extends RDD[String](sc) {
   }
 }

+/**
+ * Convenience class for Hadoop files read using TextInputFormat that
+ * represents the file as an RDD of Strings.
+ */
+class HadoopTextFile(sc: SparkContext, path: String)
+extends MappedRDD[String, (LongWritable, Text)](
+  new HadoopFile(sc, path, classOf[TextInputFormat],
+                 classOf[LongWritable], classOf[Text]),
+  { pair: (LongWritable, Text) => pair._2.toString }
+)
+
+/**
+ * Object used to ensure that only one thread at a time is configuring Hadoop
+ * InputFormat classes. Apparently configuring them is not thread safe!
+ */
 object ConfigureLock {}
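In the old HadoopTextFile, the TextInputFormat was constructed with new and configured explicitly via inputFormat.configure(conf). The new createInputFormat() goes through Hadoop's ReflectionUtils.newInstance instead, which instantiates the class and, when the object implements JobConfigurable (as TextInputFormat does), also calls configure(conf) on it; that is why no explicit configure call remains and why instantiation still happens under ConfigureLock. A minimal standalone sketch of that pattern, using only stock Hadoop classes (the helper name below is just for illustration):

  import org.apache.hadoop.mapred.{InputFormat, JobConf, TextInputFormat}
  import org.apache.hadoop.util.ReflectionUtils

  // Hypothetical helper mirroring HadoopFile.createInputFormat:
  // newInstance() also configures JobConfigurable objects against the JobConf.
  def instantiateInputFormat[K, V](
      cls: Class[_ <: InputFormat[K, V]], conf: JobConf): InputFormat[K, V] =
    ReflectionUtils.newInstance(cls.asInstanceOf[Class[_]], conf)
      .asInstanceOf[InputFormat[K, V]]

  val conf = new JobConf()
  val fmt = instantiateInputFormat(classOf[TextInputFormat], conf)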
@@ -4,6 +4,9 @@ import java.io._

 import scala.collection.mutable.ArrayBuffer

+import org.apache.hadoop.mapred.InputFormat
+import org.apache.hadoop.mapred.SequenceFileInputFormat
+
 class SparkContext(
   master: String,
@@ -42,6 +45,49 @@ extends Logging {
   def textFile(path: String): RDD[String] =
     new HadoopTextFile(this, path)

+  /** Get an RDD for a Hadoop file with an arbitrary InputFormat */
+  def hadoopFile[K, V](path: String,
+                       inputFormatClass: Class[_ <: InputFormat[K, V]],
+                       keyClass: Class[K],
+                       valueClass: Class[V])
+      : RDD[(K, V)] = {
+    new HadoopFile(this, path, inputFormatClass, keyClass, valueClass)
+  }
+
+  /**
+   * Smarter version of hadoopFile() that uses class manifests to figure out
+   * the classes of keys, values and the InputFormat so that users don't need
+   * to pass them directly.
+   */
+  def hadoopFile[K, V, F <: InputFormat[K, V]](path: String)
+      (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F])
+      : RDD[(K, V)] = {
+    hadoopFile(path,
+               fm.erasure.asInstanceOf[Class[F]],
+               km.erasure.asInstanceOf[Class[K]],
+               vm.erasure.asInstanceOf[Class[V]])
+  }
+
+  /** Get an RDD for a Hadoop SequenceFile with given key and value types */
+  def sequenceFile[K, V](path: String,
+                         keyClass: Class[K],
+                         valueClass: Class[V]): RDD[(K, V)] = {
+    val inputFormatClass = classOf[SequenceFileInputFormat[K, V]]
+    hadoopFile(path, inputFormatClass, keyClass, valueClass)
+  }
+
+  /**
+   * Smarter version of sequenceFile() that obtains the key and value classes
+   * from ClassManifests instead of requiring the user to pass them directly.
+   */
+  def sequenceFile[K, V](path: String)
+      (implicit km: ClassManifest[K], vm: ClassManifest[V]): RDD[(K, V)] = {
+    sequenceFile(path,
+                 km.erasure.asInstanceOf[Class[K]],
+                 vm.erasure.asInstanceOf[Class[V]])
+  }
+
+  /** Build the union of a list of RDDs. */
   def union[T: ClassManifest](rdds: RDD[T]*): RDD[T] =
     new UnionRDD(this, rdds)
@@ -59,7 +105,7 @@ extends Logging {
     scheduler.stop()
     scheduler = null
   }

-  // Wait for the scheduler to be registered
+  // Wait for the scheduler to be registered
   def waitForRegister() {
     scheduler.waitForRegister()
@@ -93,7 +139,7 @@ extends Logging {
     logInfo("Tasks finished in " + (System.nanoTime - start) / 1e9 + " s")
     return result
   }

-  // Clean a closure to make it ready to serialized and send to tasks
+  // Clean a closure to make it ready to serialized and send to tasks
   // (removes unreferenced variables in $outer's, updates REPL variables)
   private[spark] def clean[F <: AnyRef](f: F): F = {
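The new sequenceFile() helpers simply call hadoopFile() with SequenceFileInputFormat, so a SequenceFile of Writable pairs can be read either by passing the key and value classes or by letting ClassManifests supply them. A small usage sketch, assuming an existing SparkContext named sc and a hypothetical SequenceFile of (Text, IntWritable) records:

  import org.apache.hadoop.io.{IntWritable, Text}

  // Passing the key and value classes explicitly (path is hypothetical):
  val counts = sc.sequenceFile("hdfs://namenode/data/counts.seq",
                               classOf[Text], classOf[IntWritable])

  // Equivalent call using ClassManifests:
  val counts2 = sc.sequenceFile[Text, IntWritable]("hdfs://namenode/data/counts.seq")

Note that, as the HadoopFile iterator above shows, a single key object and a single value object are created with newInstance() and reused across next() calls, so records that need to outlive one iteration should be copied (for example, by converting Text to String).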