Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ use datafusion_functions_aggregate::expr_fn::{
array_agg, avg, avg_distinct, count, count_distinct, max, median, min, sum,
sum_distinct,
};
use datafusion_functions_nested::expr_fn::{array_filter, array_transform, make_array};
use datafusion_functions_nested::make_array::make_array_udf;
use datafusion_functions_window::expr_fn::{first_value, lead, row_number};
use insta::assert_snapshot;
Expand Down Expand Up @@ -78,8 +79,8 @@ use datafusion_expr::{
CreateMemoryTable, CreateView, DdlStatement, Expr, ExprFunctionExt, ExprSchemable,
LogicalPlan, LogicalPlanBuilder, ScalarFunctionImplementation, SortExpr, TableType,
WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, cast, col,
create_udf, exists, in_subquery, lit, out_ref_col, placeholder, scalar_subquery,
when, wildcard,
create_udf, exists, in_subquery, lambda, lambda_var, lit, out_ref_col, placeholder,
scalar_subquery, when, wildcard,
};
use datafusion_physical_expr::Partitioning;
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
Expand Down Expand Up @@ -7242,3 +7243,22 @@ async fn test_grouping_with_alias() -> Result<()> {

Ok(())
}

#[tokio::test]
async fn test_unresolved_lambda_variable() -> Result<()> {
table_with_mixed_lists().await?.with_column(
"c",
array_transform(
make_array(vec![col("list")]),
lambda(
["x"],
array_filter(
lambda_var("x"),
lambda(["y"], lambda_var("y").gt_eq(lit(1))),
),
),
),
)?;

Ok(())
}
27 changes: 3 additions & 24 deletions datafusion/functions-nested/src/array_any_match.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ use datafusion_expr::{
use datafusion_macros::user_doc;
use std::{fmt::Debug, sync::Arc};

use crate::lambda_utils::coerce_single_list_arg;

make_higher_order_function_expr_and_func!(
ArrayAnyMatch,
array_any_match,
Expand Down Expand Up @@ -120,30 +122,7 @@ impl HigherOrderUDFImpl for ArrayAnyMatch {
}

fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
let [list] = arg_types else {
return plan_err!(
"{} function requires 1 value argument, got {}",
self.name(),
arg_types.len()
);
};

let coerced = match list {
DataType::List(_) | DataType::LargeList(_) => list.clone(),
DataType::ListView(field) | DataType::FixedSizeList(field, _) => {
DataType::List(Arc::clone(field))
}
DataType::LargeListView(field) => DataType::LargeList(Arc::clone(field)),
_ => {
return plan_err!(
"{} expected a list as first argument, got {}",
self.name(),
list
);
}
};

Ok(vec![coerced])
coerce_single_list_arg(self.name(), arg_types)
}

fn lambda_parameters(
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions-nested/src/lambda_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ pub(crate) fn coerce_single_list_arg(
DataType::List(Arc::clone(field))
}
DataType::LargeListView(field) => DataType::LargeList(Arc::clone(field)),
DataType::Null => DataType::new_list(DataType::Null, true),
_ => return plan_err!("{name} expected a list as first argument, got {list}"),
};

Expand Down
10 changes: 10 additions & 0 deletions datafusion/functions-nested/src/map_extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ impl ScalarUDFImpl for MapExtract {

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
let [map_type, _] = take_function_args(self.name(), arg_types)?;

if map_type.is_null() {
return Ok(DataType::Null);
}

let map_fields = get_map_entry_field(map_type)?;
Ok(DataType::List(Arc::new(Field::new_list_field(
map_fields.last().unwrap().data_type().clone(),
Expand All @@ -123,6 +128,10 @@ impl ScalarUDFImpl for MapExtract {
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
let [map_type, _] = take_function_args(self.name(), arg_types)?;

if map_type.is_null() {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Null is accepted in coercion, but return_type still calls get_map_entry_field on the same Null type. So
map_extract(NULL, 'a') still fails with an internal error.

We should also add a test for this.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right, thank you. The only tested function array_filter return_type returned the first arg type without any checks, so it obviously passed 63ab393 Note that I also implemented null handling during execution

return Ok(arg_types.to_vec());
}

let field = get_map_entry_field(map_type)?;
Ok(vec![
map_type.clone(),
Expand Down Expand Up @@ -185,6 +194,7 @@ fn map_extract_inner(args: &[ArrayRef]) -> Result<ArrayRef> {

let map_array = match map_arg.data_type() {
DataType::Map(_, _) => as_map_array(&map_arg)?,
DataType::Null => return Ok(Arc::clone(map_arg)),
_ => return exec_err!("The first argument in map_extract must be a map"),
};

Expand Down
8 changes: 5 additions & 3 deletions datafusion/spark/src/function/array/repeat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,11 @@ impl ScalarUDFImpl for SparkArrayRepeat {

// Coerce the second argument to Int64/UInt64 if it's a numeric type
let second = match second_type {
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
DataType::Int64
}
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Null => DataType::Int64,
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
DataType::UInt64
}
Expand Down
6 changes: 6 additions & 0 deletions datafusion/sqllogictest/test_files/array/array_any_match.slt
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ SELECT list_any_match([1, 2, 3], x -> x > 2);
----
true

# null arg
query B
SELECT array_any_match(NULL, x -> x > 2);
----
NULL

statement ok
drop table t;

Expand Down
6 changes: 6 additions & 0 deletions datafusion/sqllogictest/test_files/array/array_filter.slt
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,12 @@ SELECT array_transform(array_filter(list, v -> v > 1), v -> v * 3) FROM with_nul
[6]
NULL

# null arg
query ?
SELECT array_filter(NULL, x -> x > 2);
----
NULL

statement ok
drop table t;

Expand Down
6 changes: 6 additions & 0 deletions datafusion/sqllogictest/test_files/array/array_transform.slt
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,12 @@ physical_plan
02)--ProjectionExec: expr=[text@0 as text, list@1 as list, number@2 as number, CASE WHEN number@2 > 30 THEN array_transform(make_array(make_array(list@1)), (list) -> array_transform(list@3, (list) -> array_transform(list@4, (v) -> number@2 + v@5 + array_element(list@4, 1)))) ELSE array_transform(make_array(make_array(list@1)), (list) -> array_transform(list@3, (list) -> array_transform(list@4, (v) -> number@2 + array_element(list@4, 1)))) END as CASE WHEN t.number > Int64(30) THEN array_transform(make_array(make_array(t.list)),(list) -> array_transform(list,(list) -> array_transform(list,(v) -> t.number + v + list[Int64(1)]))) ELSE array_transform(make_array(make_array(t.list)),(list) -> array_transform(list,(list) -> array_transform(list,(v) -> t.number + list[Int64(1)]))) END]
03)----DataSourceExec: partitions=1, partition_sizes=[1]

# null arg
query ?
SELECT array_transform(NULL, x -> x * 2);
----
NULL

query error
select array_transform();
----
Expand Down
6 changes: 6 additions & 0 deletions datafusion/sqllogictest/test_files/map.slt
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,12 @@ select map_extract(MAP {1: 1, 2: 2, 3:3}, '1'), map_extract(MAP {1: 1, 2: 2, 3:3
----
[1] [1] [1] [NULL] [1]

# null arg
query ?
select map_extract(NULL, 'a');
----
NULL

# map_extract with columns
query ???
select map_extract(column1, 1), map_extract(column1, 5), map_extract(column1, 7) from map_array_table_1;
Expand Down
6 changes: 6 additions & 0 deletions datafusion/sqllogictest/test_files/spark/string/repeat.slt
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,9 @@
## PySpark 3.5.5 Result: {'repeat(123, 2)': '123123', 'typeof(repeat(123, 2))': 'string', 'typeof(123)': 'string', 'typeof(2)': 'int'}
#query
#SELECT repeat('123'::string, 2::int);

# null count
query T
select repeat('a', NULL);
----
NULL
Loading