diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index bc4d0a9f2..bac6bf5d5 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -1,5 +1,6 @@ package io.substrait.isthmus; +import com.google.common.collect.Iterables; import io.substrait.expression.AggregateFunctionInvocation; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; @@ -57,6 +58,7 @@ import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.util.ImmutableBitSet; import org.immutables.value.Value; @@ -348,10 +350,17 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { .filter(c -> c.getAggregation().equals(SqlStdOperatorTable.GROUP_ID)) .collect(Collectors.toList()); + // get LITERAL_AGG() function calls — injected by SubQueryRemoveRule (CALCITE-6945) as a + // null-presence indicator; they carry a RexLiteral in rexList and have no Substrait binding. 
+ List literalAggCalls = + aggregate.getAggCallList().stream() + .filter(c -> c.getAggregation().getKind() == SqlKind.LITERAL_AGG) + .collect(Collectors.toList()); + List filteredAggCalls = aggregate.getAggCallList().stream() - // remove GROUP_ID() function calls - .filter(c -> !groupIdCalls.contains(c)) + // remove GROUP_ID() and LITERAL_AGG() function calls + .filter(c -> !groupIdCalls.contains(c) && !literalAggCalls.contains(c)) .collect(Collectors.toList()); List aggCalls = @@ -388,6 +397,8 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { i + groupingFieldCount, filteredAggCalls.indexOf(aggCall) + groupingFieldCount); } else if (groupIdCalls.contains(aggCall)) { remap.add(i + groupingFieldCount, groupingSetIndex); + } else if (literalAggCalls.contains(aggCall)) { + // LITERAL_AGG handled below via Project wrapper — skip remap slot for now } else { // this should never get triggered throw new IllegalStateException( @@ -400,7 +411,58 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { } } - return builder.build(); + Rel aggRel = builder.build(); + + if (literalAggCalls.isEmpty()) { + return aggRel; + } + + if (groupings.size() > 1) { + throw new UnsupportedOperationException( + "LITERAL_AGG combined with GROUPING SETS / CUBE / ROLLUP is not supported"); + } + + // Wrap the aggregate in a Project that replaces LITERAL_AGG output positions with their + // literal values and passes through all other fields via FieldReference. + // + // The aggregate output schema is: [grouping fields..., real agg measures...] 
+ // The full output schema requested is: [grouping fields..., all agg calls (in original order)] + // For each position in the original agg call list: + // - real measure → FieldReference into the aggregate output + // - LITERAL_AGG → the literal value from aggCall.rexList + final int groupingFieldCount = + Math.toIntExact( + groupings.stream().flatMap(g -> g.getExpressions().stream()).distinct().count()); + final int realAggCount = aggCalls.size(); + final int totalAggOutputFields = groupingFieldCount + realAggCount; + + // Build the project expression list: grouping fields first, then one expression per original + // agg call in declaration order. + List projectExprs = new ArrayList<>(); + for (int i = 0; i < groupingFieldCount; i++) { + projectExprs.add(FieldReference.newRootStructReference(i, aggRel.getRecordType())); + } + int realAggIndex = groupingFieldCount; // tracks next real-measure field index in aggRel output + for (AggregateCall aggCall : aggregate.getAggCallList()) { + if (literalAggCalls.contains(aggCall)) { + // Convert the RexLiteral stored in rexList to a Substrait literal expression + RexNode rexLiteral = Iterables.getOnlyElement(aggCall.rexList); + projectExprs.add(toExpression(rexLiteral)); + } else if (!groupIdCalls.contains(aggCall)) { + // real measure: pass through by reference + projectExprs.add( + FieldReference.newRootStructReference(realAggIndex, aggRel.getRecordType())); + realAggIndex++; + } + // GROUP_ID calls are not present in the outer schema here (groupings.size() <= 1 branch); + // if groupings.size() > 1 they are handled by the remap above and should not appear here + } + + return Project.builder() + .remap(Remap.offset(totalAggOutputFields, projectExprs.size())) + .expressions(projectExprs) + .input(aggRel) + .build(); } Aggregate.Grouping fromGroupSet(ImmutableBitSet bitSet, Rel input) { diff --git a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java 
b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java index 8f392aae2..9c14780d2 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java @@ -1,18 +1,24 @@ package io.substrait.isthmus; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.sql.SubstraitSqlToCalcite; +import io.substrait.relation.Project; +import io.substrait.relation.Rel; import java.io.IOException; import org.apache.calcite.plan.hep.HepPlanner; import org.apache.calcite.plan.hep.HepProgram; import org.apache.calcite.plan.hep.HepProgramBuilder; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelRoot; +import org.apache.calcite.rel.core.RelFactories; import org.apache.calcite.rel.rules.CoreRules; import org.apache.calcite.sql.parser.SqlParseException; +import org.apache.calcite.sql2rel.RelDecorrelator; import org.junit.jupiter.api.Test; class OptimizerIntegrationTest extends PlanTestBase { @@ -48,4 +54,62 @@ void conversionHandlesBuiltInSum0CallAddedByRule() throws SqlParseException, IOE // Conversion of the new plan should succeed SubstraitRelVisitor.convert(RelRoot.of(newPlan, relRoot.kind), EXTENSION_COLLECTION)); } + + /** + * Regression test for LITERAL_AGG handling in SubstraitRelVisitor. + * + *

<p>Calcite's SubQueryRemoveRule (CALCITE-6945, landed in 1.38.0) rewrites correlated quantified + * comparisons (e.g. {@code <> SOME}) using {@code LITERAL_AGG(true)} as a null-presence + * indicator. SubstraitRelVisitor has no Substrait binding for {@code LITERAL_AGG}, so the + * conversion previously crashed with "UnsupportedOperationException: Unable to find binding for + * call LITERAL_AGG(true)". + * + * @see CALCITE-6945 PR + */ + @Test + void conversionHandlesLiteralAggInsertedBySubQueryRemoveRule() + throws SqlParseException, IOException { + // <> SOME with a correlated nullable column triggers SubQueryRemoveRule's + // quantified-comparison path, which inserts LITERAL_AGG(true) into the aggregate. + String query = + "select e1.l_orderkey from lineitem e1 " + + "where e1.l_quantity <> some (" + + " select l_quantity from lineitem e2 where e2.l_partkey = e1.l_partkey" + + ")"; + + RelRoot relRoot = SubstraitSqlToCalcite.convertQuery(query, TPCH_CATALOG); + + // Step 1 — SubQueryRemoveRule: rewrites RexSubQuery → LogicalCorrelate + LITERAL_AGG. + HepProgram subQueryProgram = + new HepProgramBuilder().addRuleInstance(CoreRules.FILTER_SUB_QUERY_TO_CORRELATE).build(); + HepPlanner hepPlanner = new HepPlanner(subQueryProgram); + hepPlanner.setRoot(relRoot.rel); + RelNode afterSubQueryRemove = hepPlanner.findBestExp(); + + // Step 2 — RelDecorrelator: rewrites LogicalCorrelate → LEFT JOIN; LITERAL_AGG survives in + // the aggregate as a synthetic null-presence indicator column. + RelNode decorrelated = + RelDecorrelator.decorrelateQuery( + afterSubQueryRemove, + RelFactories.LOGICAL_BUILDER.create(relRoot.rel.getCluster(), null)); + + // Conversion must succeed and produce the correct output schema. + // The query selects a single column (l_orderkey), so the plan root must expose 1 field. + // The LITERAL_AGG wrapper emits a Project on top of the aggregate; verify that structure. 
+ io.substrait.plan.Plan.Root planRoot = + assertDoesNotThrow( + () -> + SubstraitRelVisitor.convert( + RelRoot.of(decorrelated, relRoot.kind), EXTENSION_COLLECTION)); + Rel result = planRoot.getInput(); + + // The outermost Rel visible to the caller is a Project that re-inserts the LITERAL_AGG + // literal and passes real measures through — it must expose exactly 1 output field + // (l_orderkey) matching the SELECT list. + assertInstanceOf(Project.class, result, "expected LITERAL_AGG wrapper Project at plan root"); + assertEquals( + 1, + result.getRecordType().fields().size(), + "output schema should have exactly 1 field (l_orderkey)"); + } }