diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 3b92b92004324..476855163f9f8 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -39,6 +39,7 @@ use datafusion_functions_aggregate::expr_fn::{ array_agg, avg, avg_distinct, count, count_distinct, max, median, min, sum, sum_distinct, }; +use datafusion_functions_nested::expr_fn::{array_filter, array_transform, make_array}; use datafusion_functions_nested::make_array::make_array_udf; use datafusion_functions_window::expr_fn::{first_value, lead, row_number}; use insta::assert_snapshot; @@ -78,8 +79,8 @@ use datafusion_expr::{ CreateMemoryTable, CreateView, DdlStatement, Expr, ExprFunctionExt, ExprSchemable, LogicalPlan, LogicalPlanBuilder, ScalarFunctionImplementation, SortExpr, TableType, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, cast, col, - create_udf, exists, in_subquery, lit, out_ref_col, placeholder, scalar_subquery, - when, wildcard, + create_udf, exists, in_subquery, lambda, lambda_var, lit, out_ref_col, placeholder, + scalar_subquery, when, wildcard, }; use datafusion_physical_expr::Partitioning; use datafusion_physical_expr::aggregate::AggregateExprBuilder; @@ -7242,3 +7243,22 @@ async fn test_grouping_with_alias() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_unresolved_lambda_variable() -> Result<()> { + table_with_mixed_lists().await?.with_column( + "c", + array_transform( + make_array(vec![col("list")]), + lambda( + ["x"], + array_filter( + lambda_var("x"), + lambda(["y"], lambda_var("y").gt_eq(lit(1))), + ), + ), + ), + )?; + + Ok(()) +} diff --git a/datafusion/functions-nested/src/array_any_match.rs b/datafusion/functions-nested/src/array_any_match.rs index c8ba978881394..8e6e67ed2a17e 100644 --- a/datafusion/functions-nested/src/array_any_match.rs +++ b/datafusion/functions-nested/src/array_any_match.rs @@ -37,6 +37,8 @@ use datafusion_expr::{ use datafusion_macros::user_doc; use std::{fmt::Debug, sync::Arc}; +use crate::lambda_utils::coerce_single_list_arg; + make_higher_order_function_expr_and_func!( ArrayAnyMatch, array_any_match, @@ -120,30 +122,7 @@ impl HigherOrderUDFImpl for ArrayAnyMatch { } fn coerce_value_types(&self, arg_types: &[DataType]) -> Result> { - let [list] = arg_types else { - return plan_err!( - "{} function requires 1 value argument, got {}", - self.name(), - arg_types.len() - ); - }; - - let coerced = match list { - DataType::List(_) | DataType::LargeList(_) => list.clone(), - DataType::ListView(field) | DataType::FixedSizeList(field, _) => { - DataType::List(Arc::clone(field)) - } - DataType::LargeListView(field) => DataType::LargeList(Arc::clone(field)), - _ => { - return plan_err!( - "{} expected a list as first argument, got {}", - self.name(), - list - ); - } - }; - - Ok(vec![coerced]) + coerce_single_list_arg(self.name(), arg_types) } fn lambda_parameters( diff --git a/datafusion/functions-nested/src/lambda_utils.rs b/datafusion/functions-nested/src/lambda_utils.rs index 0f208ce5d26b2..ce596d0f1b38a 100644 --- a/datafusion/functions-nested/src/lambda_utils.rs +++ b/datafusion/functions-nested/src/lambda_utils.rs @@ -65,6 +65,7 @@ pub(crate) fn coerce_single_list_arg( DataType::List(Arc::clone(field)) } DataType::LargeListView(field) => DataType::LargeList(Arc::clone(field)), + DataType::Null => DataType::new_list(DataType::Null, true), _ => return plan_err!("{name} expected a list as first argument, got {list}"), }; diff --git a/datafusion/functions-nested/src/map_extract.rs b/datafusion/functions-nested/src/map_extract.rs index aab0d013a4152..45dc4b6f10a0b 100644 --- a/datafusion/functions-nested/src/map_extract.rs +++ b/datafusion/functions-nested/src/map_extract.rs @@ -105,6 +105,11 @@ impl ScalarUDFImpl for MapExtract { fn return_type(&self, arg_types: &[DataType]) -> Result { let [map_type, _] = take_function_args(self.name(), arg_types)?; + + if map_type.is_null() { + return Ok(DataType::Null); + } + let map_fields = get_map_entry_field(map_type)?; Ok(DataType::List(Arc::new(Field::new_list_field( map_fields.last().unwrap().data_type().clone(), @@ -123,6 +128,10 @@ impl ScalarUDFImpl for MapExtract { fn coerce_types(&self, arg_types: &[DataType]) -> Result> { let [map_type, _] = take_function_args(self.name(), arg_types)?; + if map_type.is_null() { + return Ok(arg_types.to_vec()); + } + let field = get_map_entry_field(map_type)?; Ok(vec![ map_type.clone(), @@ -185,6 +194,7 @@ fn map_extract_inner(args: &[ArrayRef]) -> Result { let map_array = match map_arg.data_type() { DataType::Map(_, _) => as_map_array(&map_arg)?, + DataType::Null => return Ok(Arc::clone(map_arg)), _ => return exec_err!("The first argument in map_extract must be a map"), }; diff --git a/datafusion/spark/src/function/array/repeat.rs b/datafusion/spark/src/function/array/repeat.rs index da9b19a768680..6effdf9a50f9a 100644 --- a/datafusion/spark/src/function/array/repeat.rs +++ b/datafusion/spark/src/function/array/repeat.rs @@ -74,9 +74,11 @@ impl ScalarUDFImpl for SparkArrayRepeat { // Coerce the second argument to Int64/UInt64 if it's a numeric type let second = match second_type { - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { - DataType::Int64 - } + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Null => DataType::Int64, DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { DataType::UInt64 } diff --git a/datafusion/sqllogictest/test_files/array/array_any_match.slt b/datafusion/sqllogictest/test_files/array/array_any_match.slt index 27f2a5339ef68..37aa47c55adcf 100644 --- a/datafusion/sqllogictest/test_files/array/array_any_match.slt +++ b/datafusion/sqllogictest/test_files/array/array_any_match.slt @@ -103,6 +103,12 @@ SELECT list_any_match([1, 2, 3], x -> x > 2); ---- true +# null arg +query B +SELECT array_any_match(NULL, x -> x > 2); +---- +NULL + statement ok drop table t; diff --git a/datafusion/sqllogictest/test_files/array/array_filter.slt b/datafusion/sqllogictest/test_files/array/array_filter.slt index f22cfb219830c..9b564c5061205 100644 --- a/datafusion/sqllogictest/test_files/array/array_filter.slt +++ b/datafusion/sqllogictest/test_files/array/array_filter.slt @@ -204,6 +204,12 @@ SELECT array_transform(array_filter(list, v -> v > 1), v -> v * 3) FROM with_nul [6] NULL +# null arg +query ? +SELECT array_filter(NULL, x -> x > 2); +---- +NULL + statement ok drop table t; diff --git a/datafusion/sqllogictest/test_files/array/array_transform.slt b/datafusion/sqllogictest/test_files/array/array_transform.slt index c8c43588c882c..5439d7441155b 100644 --- a/datafusion/sqllogictest/test_files/array/array_transform.slt +++ b/datafusion/sqllogictest/test_files/array/array_transform.slt @@ -393,6 +393,12 @@ physical_plan 02)--ProjectionExec: expr=[text@0 as text, list@1 as list, number@2 as number, CASE WHEN number@2 > 30 THEN array_transform(make_array(make_array(list@1)), (list) -> array_transform(list@3, (list) -> array_transform(list@4, (v) -> number@2 + v@5 + array_element(list@4, 1)))) ELSE array_transform(make_array(make_array(list@1)), (list) -> array_transform(list@3, (list) -> array_transform(list@4, (v) -> number@2 + array_element(list@4, 1)))) END as CASE WHEN t.number > Int64(30) THEN array_transform(make_array(make_array(t.list)),(list) -> array_transform(list,(list) -> array_transform(list,(v) -> t.number + v + list[Int64(1)]))) ELSE array_transform(make_array(make_array(t.list)),(list) -> array_transform(list,(list) -> array_transform(list,(v) -> t.number + list[Int64(1)]))) END] 03)----DataSourceExec: partitions=1, partition_sizes=[1] +# null arg +query ? +SELECT array_transform(NULL, x -> x * 2); +---- +NULL + query error select array_transform(); ---- diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index 2b390c3748e35..6b9957c33a325 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -642,6 +642,12 @@ select map_extract(MAP {1: 1, 2: 2, 3:3}, '1'), map_extract(MAP {1: 1, 2: 2, 3:3 ---- [1] [1] [1] [NULL] [1] +# null arg +query ? +select map_extract(NULL, 'a'); +---- +NULL + # map_extract with columns query ??? select map_extract(column1, 1), map_extract(column1, 5), map_extract(column1, 7) from map_array_table_1; diff --git a/datafusion/sqllogictest/test_files/spark/string/repeat.slt b/datafusion/sqllogictest/test_files/spark/string/repeat.slt index 5ca4166f9f4e0..d4dcce190ddda 100644 --- a/datafusion/sqllogictest/test_files/spark/string/repeat.slt +++ b/datafusion/sqllogictest/test_files/spark/string/repeat.slt @@ -25,3 +25,9 @@ ## PySpark 3.5.5 Result: {'repeat(123, 2)': '123123', 'typeof(repeat(123, 2))': 'string', 'typeof(123)': 'string', 'typeof(2)': 'int'} #query #SELECT repeat('123'::string, 2::int); + +# null count +query T +select repeat('a', NULL); +---- +NULL