From bd167f83b0adbff43b6cd910893103001679317b Mon Sep 17 00:00:00 2001
From: Andrey Kouznetsov <bearonsails@gmail.com>
Date: Tue, 19 Mar 2013 17:15:15 +0400
Subject: [PATCH] call setConf from input format if it is Configurable

---
 core/src/main/scala/spark/rdd/HadoopRDD.scala    | 7 +++++++
 core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 8 +++++++-
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala
index a6322dc58d..cbf5512e24 100644
--- a/core/src/main/scala/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala
@@ -17,6 +17,7 @@ import org.apache.hadoop.util.ReflectionUtils
 
 import spark.{Dependency, Logging, Partition, RDD, SerializableWritable, SparkContext, TaskContext}
 import spark.util.NextIterator
+import org.apache.hadoop.conf.Configurable
 
 
 /**
@@ -50,6 +51,9 @@ class HadoopRDD[K, V](
 
   override def getPartitions: Array[Partition] = {
     val inputFormat = createInputFormat(conf)
+    if (inputFormat.isInstanceOf[Configurable]) {
+      inputFormat.asInstanceOf[Configurable].setConf(conf)
+    }
     val inputSplits = inputFormat.getSplits(conf, minSplits)
     val array = new Array[Partition](inputSplits.size)
     for (i <- 0 until inputSplits.size) {
@@ -69,6 +73,9 @@ class HadoopRDD[K, V](
 
     val conf = confBroadcast.value.value
     val fmt = createInputFormat(conf)
+    if (fmt.isInstanceOf[Configurable]) {
+      fmt.asInstanceOf[Configurable].setConf(conf)
+    }
     reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL)
 
     // Register an on-task-completion callback to close the input stream.
diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
index df2361025c..bdd974590a 100644
--- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
@@ -3,7 +3,7 @@ package spark.rdd
 import java.text.SimpleDateFormat
 import java.util.Date
 
-import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.conf.{Configurable, Configuration}
 import org.apache.hadoop.io.Writable
 import org.apache.hadoop.mapreduce._
 
@@ -42,6 +42,9 @@ class NewHadoopRDD[K, V](
 
   override def getPartitions: Array[Partition] = {
     val inputFormat = inputFormatClass.newInstance
+    if (inputFormat.isInstanceOf[Configurable]) {
+      inputFormat.asInstanceOf[Configurable].setConf(conf)
+    }
     val jobContext = newJobContext(conf, jobId)
     val rawSplits = inputFormat.getSplits(jobContext).toArray
     val result = new Array[Partition](rawSplits.size)
@@ -57,6 +60,9 @@ class NewHadoopRDD[K, V](
     val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0)
     val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
     val format = inputFormatClass.newInstance
+    if (format.isInstanceOf[Configurable]) {
+      format.asInstanceOf[Configurable].setConf(conf)
+    }
     val reader = format.createRecordReader(
       split.serializableHadoopSplit.value, hadoopAttemptContext)
     reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
-- 
GitLab
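
Note (reviewer sketch, not part of the patch above): every hunk repeats the same
isInstanceOf/asInstanceOf pair before using a freshly instantiated InputFormat.
Below is a minimal self-contained Scala sketch of that check, assuming only
Hadoop's Configurable and Configuration types; the helper name
configureIfPossible is hypothetical and does not appear in the patch. A pattern
match is the idiomatic Scala spelling of the same test and would let the four
duplicated blocks share one definition.

    import org.apache.hadoop.conf.{Configurable, Configuration}

    object ConfigurableCheck {
      // Hypothetical helper capturing the check each hunk adds: if the input
      // format implements Configurable, hand it the job Configuration before
      // asking it for splits or a record reader; otherwise leave it alone.
      def configureIfPossible(obj: AnyRef, conf: Configuration): Unit = obj match {
        case c: Configurable => c.setConf(conf) // inject the conf the format expects
        case _               => ()              // not Configurable; nothing to do
      }
    }

    // At each call site in the patch this would read, e.g.:
    //   val inputFormat = inputFormatClass.newInstance
    //   ConfigurableCheck.configureIfPossible(inputFormat, conf)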