diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index aadc1d31bd4b254999fa8b37705ac41b265fa123..0e08bf013c8d98bfae4457737301405317ed6a7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -55,10 +55,20 @@ object UnsupportedOperationChecker { case _: InsertIntoTable => throwError("InsertIntoTable is not supported with streaming DataFrames/Datasets") - case Aggregate(_, _, child) if child.isStreaming && outputMode == Append => - throwError( - "Aggregations are not supported on streaming DataFrames/Datasets in " + - "Append output mode. Consider changing output mode to Update.") + case Aggregate(_, _, child) if child.isStreaming => + if (outputMode == Append) { + throwError( + "Aggregations are not supported on streaming DataFrames/Datasets in " + + "Append output mode. Consider changing output mode to Update.") + } + val moreStreamingAggregates = child.find { + case Aggregate(_, _, grandchild) if grandchild.isStreaming => true + case _ => false + } + if (moreStreamingAggregates.nonEmpty) { + throwError("Multiple streaming aggregations are not supported with " + + "streaming DataFrames/Datasets") + } case Join(left, right, joinType, _) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 50baebe8bf4de7c57a96aea26a807aaa62491afc..674277bdbe15d24639fb1df0987c626a4b863c3c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -23,7 +23,8 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.IntegerType @@ -95,6 +96,26 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Append, Seq("aggregation", "append output mode")) + // Multiple streaming aggregations not supported + def aggExprs(name: String): Seq[NamedExpression] = Seq(Count("*").as(name)) + + assertSupportedInStreamingPlan( + "aggregate - multiple batch aggregations", + Aggregate(Nil, aggExprs("c"), Aggregate(Nil, aggExprs("d"), batchRelation)), + Update) + + assertSupportedInStreamingPlan( + "aggregate - multiple aggregations but only one streaming aggregation", + Aggregate(Nil, aggExprs("c"), batchRelation).join( + Aggregate(Nil, aggExprs("d"), streamRelation), joinType = Inner), + Update) + + assertNotSupportedInStreamingPlan( + "aggregate - multiple streaming aggregations", + Aggregate(Nil, aggExprs("c"), Aggregate(Nil, aggExprs("d"), streamRelation)), + outputMode = Update, + expectedMsgs = Seq("multiple streaming aggregations")) + // Inner joins: Stream-stream not supported testBinaryOperationInStreamingPlan( "inner join", @@ -354,17 +375,11 @@ class UnsupportedOperationsSuite extends SparkFunSuite { val e = intercept[AnalysisException] { testBody } - - if (!expectedMsgs.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains)) { - fail( - s"""Exception message should contain the following substrings: - | - | ${expectedMsgs.mkString("\n ")} - | - |Actual exception message: - | - | ${e.getMessage} - """.stripMargin) + expectedMsgs.foreach { m => + if (!e.getMessage.toLowerCase.contains(m.toLowerCase)) { + fail(s"Exception message should contain: '$m', " + + s"actual exception message:\n\t'${e.getMessage}'") + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 0f5fc9ca72d98dac89fdcec32651da28c0193f09..7104d01c4a2a119ffa739c5d4b1b0788b2db1ea8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -84,25 +84,6 @@ class StreamingAggregationSuite extends StreamTest with SharedSQLContext with Be ) } - test("multiple aggregations") { - val inputData = MemoryStream[Int] - - val aggregated = - inputData.toDF() - .groupBy($"value") - .agg(count("*") as 'count) - .groupBy($"value" % 2) - .agg(sum($"count")) - .as[(Int, Long)] - - testStream(aggregated)( - AddData(inputData, 1, 2, 3, 4), - CheckLastBatch((0, 2), (1, 2)), - AddData(inputData, 1, 3, 5), - CheckLastBatch((1, 5)) - ) - } - testQuietly("midbatch failure") { val inputData = MemoryStream[Int] FailureSinglton.firstTime = true