diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
index 1f36eced3d08f51ce520d4f051506c3f78819e0b..4663f16b5f5dccad16427e4c6a37003a7210ad91 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -223,20 +223,20 @@ class ImputerModel private[ml] (
 
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    var outputDF = dataset
     val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq
 
-    $(inputCols).zip($(outputCols)).zip(surrogates).foreach {
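+    // Build all of the imputed output columns up front and add them with a single
+    // withColumns call, instead of rebuilding the Dataset once per output column.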
+    val newCols = $(inputCols).zip($(outputCols)).zip(surrogates).map {
       case ((inputCol, outputCol), surrogate) =>
         val inputType = dataset.schema(inputCol).dataType
         val ic = col(inputCol)
-        outputDF = outputDF.withColumn(outputCol,
-          when(ic.isNull, surrogate)
+        when(ic.isNull, surrogate)
           .when(ic === $(missingValue), surrogate)
           .otherwise(ic)
-          .cast(inputType))
+          .cast(inputType)
     }
-    outputDF.toDF()
+    dataset.withColumns($(outputCols), newCols).toDF()
   }
 
   override def transformSchema(schema: StructType): StructType = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index ab0c4126bcbddffa83d4353a437f2811b3083e7d..f2a76a506eb6fd1cb8c830a4fdbbf05941b664c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2083,22 +2083,53 @@ class Dataset[T] private[sql](
    * @group untypedrel
    * @since 2.0.0
    */
-  def withColumn(colName: String, col: Column): DataFrame = {
+  def withColumn(colName: String, col: Column): DataFrame = withColumns(Seq(colName), Seq(col))
+
+  /**
+   * Returns a new Dataset by adding columns or replacing existing columns that have
+   * the same names.
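+   *
+   * For example (an illustrative sketch; `df` and the column names are hypothetical),
+   * a name matching an existing column replaces it in place, while an unmatched name
+   * is appended after the existing columns:
+   * {{{
+   *   // df has columns "key" and "value"
+   *   df.withColumns(Seq("key", "doubled"), Seq(df("key") + 1, df("key") * 2))
+   *   // resulting columns: key (replaced), value, doubled (appended)
+   * }}}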
+   */
+  private[spark] def withColumns(colNames: Seq[String], cols: Seq[Column]): DataFrame = {
+    require(colNames.size == cols.size,
+      s"The size of column names: ${colNames.size} isn't equal to " +
+        s"the size of columns: ${cols.size}")
+    SchemaUtils.checkColumnNameDuplication(
+      colNames,
+      "in given column names",
+      sparkSession.sessionState.conf.caseSensitiveAnalysis)
+
     val resolver = sparkSession.sessionState.analyzer.resolver
     val output = queryExecution.analyzed.output
-    val shouldReplace = output.exists(f => resolver(f.name, colName))
-    if (shouldReplace) {
-      val columns = output.map { field =>
-        if (resolver(field.name, colName)) {
-          col.as(colName)
-        } else {
-          Column(field)
-        }
+
+    val columnMap = colNames.zip(cols).toMap
+
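+    // Walk the existing output attributes once; any attribute whose name resolves to
+    // a requested name is replaced by its new Column, keeping the original position.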
+    val replacedAndExistingColumns = output.map { field =>
+      columnMap.find { case (colName, _) =>
+        resolver(field.name, colName)
+      } match {
+        case Some((colName, col)) => col.as(colName)
+        case _ => Column(field)
       }
-      select(columns : _*)
-    } else {
-      select(Column("*"), col.as(colName))
     }
+
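+    // Requested names that did not resolve to any existing attribute become new
+    // columns, appended after all the existing ones.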
+    val newColumns = columnMap.filter { case (colName, _) =>
+      !output.exists(f => resolver(f.name, colName))
+    }.map { case (colName, col) => col.as(colName) }
+
+    select(replacedAndExistingColumns ++ newColumns : _*)
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 0e2f2e5a193e1642ef7b7ea90ea72963f480730a..672deeac597f103372d32359b7cb3ff816a8b040 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -641,6 +641,51 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     assert(df.schema.map(_.name) === Seq("key", "value", "newCol"))
   }
 
+  test("withColumns") {
+    val df = testData.toDF().withColumns(Seq("newCol1", "newCol2"),
+      Seq(col("key") + 1, col("key") + 2))
+    checkAnswer(
+      df,
+      testData.collect().map { case Row(key: Int, value: String) =>
+        Row(key, value, key + 1, key + 2)
+      }.toSeq)
+    assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCol2"))
+
+    val err = intercept[IllegalArgumentException] {
+      testData.toDF().withColumns(Seq("newCol1"),
+        Seq(col("key") + 1, col("key") + 2))
+    }
+    assert(
+      err.getMessage.contains("The size of column names: 1 isn't equal to the size of columns: 2"))
+
+    val err2 = intercept[AnalysisException] {
+      testData.toDF().withColumns(Seq("newCol1", "newCOL1"),
+        Seq(col("key") + 1, col("key") + 2))
+    }
+    assert(err2.getMessage.contains("Found duplicate column(s)"))
+  }
+
+  test("withColumns: case sensitive") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
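+      // Under spark.sql.caseSensitive=true, "newCol1" and "newCOL1" are distinct
+      // names, so the duplicate-name check passes and both columns are added.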
+      val df = testData.toDF().withColumns(Seq("newCol1", "newCOL1"),
+        Seq(col("key") + 1, col("key") + 2))
+      checkAnswer(
+        df,
+        testData.collect().map { case Row(key: Int, value: String) =>
+          Row(key, value, key + 1, key + 2)
+        }.toSeq)
+      assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCOL1"))
+
+      val err = intercept[AnalysisException] {
+        testData.toDF().withColumns(Seq("newCol1", "newCol1"),
+          Seq(col("key") + 1, col("key") + 2))
+      }
+      assert(err.getMessage.contains("Found duplicate column(s)"))
+    }
+  }
+
   test("replace column using withColumn") {
     val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x")
     val df3 = df2.withColumn("x", df2("x") + 1)
@@ -649,6 +694,16 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Row(2) :: Row(3) :: Row(4) :: Nil)
   }
 
+  test("replace column using withColumns") {
+    val df2 = sparkContext.parallelize(Array((1, 2), (2, 3), (3, 4))).toDF("x", "y")
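+    // "x" resolves to an existing column and is replaced; "newCol1" and "newCol2" are appended.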
+    val df3 = df2.withColumns(Seq("x", "newCol1", "newCol2"),
+      Seq(df2("x") + 1, df2("y"), df2("y") + 1))
+    checkAnswer(
+      df3.select("x", "newCol1", "newCol2"),
+      Row(2, 2, 3) :: Row(3, 3, 4) :: Row(4, 4, 5) :: Nil)
+  }
+
   test("drop column using drop") {
     val df = testData.drop("key")
     checkAnswer(