diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index eaecd1b49a..79df0625d8 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -120,22 +120,13 @@ object CometSortArray extends CometExpressionSerde[SortArray] with CodegenDispat "When `" + CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key + "=true`, sorting on" + " floating-point types is not 100% compatible with Spark") - override def getUnsupportedReasons(): Seq[String] = Seq( - "Nested arrays with `Struct` or `Null` child values are not supported natively and will" + - " fall back to Spark.") - - private def supportedSortArrayElementType( - dt: DataType, - nestedInArray: Boolean = false): Boolean = { + private def supportedSortArrayElementType(dt: DataType): Boolean = { dt match { - // DataFusion's array_sort compares nested arrays through Arrow's rank kernel. - // That kernel does not support Struct or Null child values, - // so array<array<struct<...>>> and array<array<null>> would fail at runtime. 
- case _: NullType if !nestedInArray => + case _: NullType => true case ArrayType(elementType, _) => - supportedSortArrayElementType(elementType, nestedInArray = true) - case StructType(fields) if !nestedInArray => + supportedSortArrayElementType(elementType) + case StructType(fields) => fields.forall(f => supportedSortArrayElementType(f.dataType)) case _ => supportedScalarSortElementType(dt) diff --git a/spark/src/test/resources/sql-tests/expressions/array/sort_array.sql b/spark/src/test/resources/sql-tests/expressions/array/sort_array.sql index 1ced53394d..6bc7b07f0f 100644 --- a/spark/src/test/resources/sql-tests/expressions/array/sort_array.sql +++ b/spark/src/test/resources/sql-tests/expressions/array/sort_array.sql @@ -295,10 +295,10 @@ INSERT INTO test_sort_array_nested_struct VALUES (array()), (NULL) -query expect_fallback(Sort on array element type ArrayType(StructType(StructField(a,IntegerType) +query SELECT sort_array(arr) FROM test_sort_array_nested_struct -query expect_fallback(Sort on array element type ArrayType(StructType(StructField(a,IntegerType) +query SELECT sort_array(arr, false) FROM test_sort_array_nested_struct -- literal arguments @@ -391,7 +391,7 @@ SELECT sort_array(array(NULL, NULL)), sort_array(cast(NULL as array)) -query expect_fallback(Sort on array element type ArrayType(StructType(StructField(a,IntegerType) +query SELECT sort_array( array( array(named_struct('a', 2)),