diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index fe57d17f1ec14b5764a51d227294695297a64819..1f18a6e9ff8a5b1204f12aee7263cae1db8d23ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -526,7 +526,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
 /**
- * Returns the input formatted according do printf-style format strings
+ * Returns the input formatted according to printf-style format strings
  */
-case class StringFormat(children: Expression*) extends Expression with CodegenFallback {
+case class StringFormat(children: Expression*) extends Expression with ImplicitCastInputTypes {
 
   require(children.nonEmpty, "printf() should take at least 1 argument")
 
@@ -536,6 +536,10 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa
   private def format: Expression = children(0)
   private def args: Seq[Expression] = children.tail
 
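+  // The format itself must be a string; the remaining arguments may be of any type.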
+  override def inputTypes: Seq[AbstractDataType] =
+    StringType :: List.fill(children.size - 1)(AnyDataType)
+
   override def eval(input: InternalRow): Any = {
     val pattern = format.eval(input)
     if (pattern == null) {
@@ -551,6 +555,44 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa
     }
   }
 
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val pattern = children.head.gen(ctx)
+
+    val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx)))
+    val argListCode = argListGen.map(_._2.code + "\n")
+
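+    // Build ", arg1, arg2, ...": one null-safe Java expression per format argument.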
+    val argListString = argListGen.foldLeft("")((s, v) => {
+      val nullSafeString =
+        if (ctx.boxedType(v._1) != ctx.javaType(v._1)) {
+          // Java primitives get boxed in order to allow null values.
+          s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " +
+            s"new ${ctx.boxedType(v._1)}(${v._2.primitive})"
+        } else {
+          s"(${v._2.isNull}) ? null : ${v._2.primitive}"
+        }
+      s + "," + nullSafeString
+    })
+
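+    // Format into a StringBuffer via java.util.Formatter (US locale) and return a UTF8String.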
+    val form = ctx.freshName("formatter")
+    val formatter = classOf[java.util.Formatter].getName
+    val sb = ctx.freshName("sb")
+    val stringBuffer = classOf[StringBuffer].getName
+    s"""
+      ${pattern.code}
+      boolean ${ev.isNull} = ${pattern.isNull};
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      if (!${ev.isNull}) {
+        ${argListCode.mkString}
+        $stringBuffer $sb = new $stringBuffer();
+        $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US);
+        $form.format(${pattern.primitive}.toString() $argListString);
+        ${ev.primitive} = UTF8String.fromString($sb.toString());
+      }
+     """
+  }
+
   override def prettyName: String = "printf"
 }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 96c540ab36f083b79cb02b9c3ed0c49167765473..3c2d88731beb4121708cf026baa3739e11dcb6b6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -351,18 +351,16 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("FORMAT") {
-    val f = 'f.string.at(0)
-    val d1 = 'd.int.at(1)
-    val s1 = 's.int.at(2)
-
-    val row1 = create_row("aa%d%s", 12, "cc")
-    val row2 = create_row(null, 12, "cc")
-    checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
+    checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
     checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null))
-    checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
+    checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
+    checkEvaluation(StringFormat(Literal("aa%d%s"), 12, "cc"), "aa12cc")
 
-    checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1)
-    checkEvaluation(StringFormat(f, d1, s1), null, row2)
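+    // A null format string yields null; null arguments are formatted as "null".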
+    checkEvaluation(StringFormat(Literal.create(null, StringType), 12, "cc"), null)
+    checkEvaluation(
+      StringFormat(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc")
+    checkEvaluation(
+      StringFormat(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null")
   }
 
   test("INSTR") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index d94d7335828c55540798e0abd76987f78f28c131..e5ff8ae7e3179e51c4975197e7426efe6f71f486 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1741,6 +1741,17 @@ object functions {
    */
   def rtrim(e: Column): Column = StringTrimRight(e.expr)
 
+  /**
+   * Format strings in printf-style, where `format` is a column containing the format string.
+   *
+   * @group string_funcs
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def formatString(format: Column, arguments: Column*): Column = {
+    StringFormat((format +: arguments).map(_.expr): _*)
+  }
+
   /**
    * Format strings in printf-style.
    * NOTE: `format` is the string value of the formatter, not column name.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index d1f855903ca4b7049cdfa4a4cd797f3eb38a05c7..3702e73b4e74f035f2f58ad76270cf0540b7394d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -132,6 +132,17 @@ class StringFunctionsSuite extends QueryTest {
     checkAnswer(
       df.selectExpr("printf(a, b, c)"),
       Row("aa123cc"))
+
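+    // Column "a" is binary here; it should be implicitly cast to a string before formatting.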
+    val df2 = Seq(("aa%d%s".getBytes, 123, "cc")).toDF("a", "b", "c")
+
+    checkAnswer(
+      df2.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
+      Row("aa123cc", "aa123cc"))
+
+    checkAnswer(
+      df2.selectExpr("printf(a, b, c)"),
+      Row("aa123cc"))
   }
 
   test("string instr function") {