Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.substrait.isthmus;

import com.google.common.collect.Iterables;
import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
Expand Down Expand Up @@ -57,6 +58,7 @@
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.util.ImmutableBitSet;
import org.immutables.value.Value;
Expand Down Expand Up @@ -348,10 +350,17 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) {
.filter(c -> c.getAggregation().equals(SqlStdOperatorTable.GROUP_ID))
.collect(Collectors.toList());

// get LITERAL_AGG() function calls — injected by SubQueryRemoveRule (CALCITE-6945) as a
// null-presence indicator; they carry a RexLiteral in rexList and have no Substrait binding.
List<AggregateCall> literalAggCalls =
aggregate.getAggCallList().stream()
.filter(c -> c.getAggregation().getKind() == SqlKind.LITERAL_AGG)
.collect(Collectors.toList());

List<AggregateCall> filteredAggCalls =
aggregate.getAggCallList().stream()
// remove GROUP_ID() function calls
.filter(c -> !groupIdCalls.contains(c))
// remove GROUP_ID() and LITERAL_AGG() function calls
.filter(c -> !groupIdCalls.contains(c) && !literalAggCalls.contains(c))
.collect(Collectors.toList());

List<Measure> aggCalls =
Expand Down Expand Up @@ -388,6 +397,8 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) {
i + groupingFieldCount, filteredAggCalls.indexOf(aggCall) + groupingFieldCount);
} else if (groupIdCalls.contains(aggCall)) {
remap.add(i + groupingFieldCount, groupingSetIndex);
} else if (literalAggCalls.contains(aggCall)) {
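// No remap entry is added for LITERAL_AGG here: its output column is re-inserted by the
// Project wrapper built after the aggregate (multiple grouping sets are rejected there).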
} else {
// this should never get triggered
throw new IllegalStateException(
Expand All @@ -400,7 +411,58 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) {
}
}

return builder.build();
Rel aggRel = builder.build();

if (literalAggCalls.isEmpty()) {
return aggRel;
}

if (groupings.size() > 1) {
throw new UnsupportedOperationException(
"LITERAL_AGG combined with GROUPING SETS / CUBE / ROLLUP is not supported");
}

// Wrap the aggregate in a Project that replaces LITERAL_AGG output positions with their
// literal values and passes through all other fields via FieldReference.
//
// The aggregate output schema is: [grouping fields..., real agg measures...]
// The full output schema requested is: [grouping fields..., all agg calls (in original order)]
// For each position in the original agg call list:
// - real measure → FieldReference into the aggregate output
// - LITERAL_AGG → the literal value from aggCall.rexList
final int groupingFieldCount =
Math.toIntExact(
groupings.stream().flatMap(g -> g.getExpressions().stream()).distinct().count());
final int realAggCount = aggCalls.size();
final int totalAggOutputFields = groupingFieldCount + realAggCount;

// Build the project expression list: grouping fields first, then one expression per original
// agg call in declaration order.
List<Expression> projectExprs = new ArrayList<>();
for (int i = 0; i < groupingFieldCount; i++) {
projectExprs.add(FieldReference.newRootStructReference(i, aggRel.getRecordType()));
}
int realAggIndex = groupingFieldCount; // tracks next real-measure field index in aggRel output
for (AggregateCall aggCall : aggregate.getAggCallList()) {
if (literalAggCalls.contains(aggCall)) {
// Convert the RexLiteral stored in rexList to a Substrait literal expression
RexNode rexLiteral = Iterables.getOnlyElement(aggCall.rexList);
projectExprs.add(toExpression(rexLiteral));
} else if (!groupIdCalls.contains(aggCall)) {
// real measure: pass through by reference
projectExprs.add(
FieldReference.newRootStructReference(realAggIndex, aggRel.getRecordType()));
realAggIndex++;
}
// GROUP_ID calls are not present in the outer schema here (groupings.size() <= 1 branch);
// if groupings.size() > 1 they are handled by the remap above and should not appear here
}

return Project.builder()
.remap(Remap.offset(totalAggOutputFields, projectExprs.size()))
.expressions(projectExprs)
.input(aggRel)
.build();
}

Aggregate.Grouping fromGroupSet(ImmutableBitSet bitSet, Rel input) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
package io.substrait.isthmus;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;

import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.sql.SubstraitSqlToCalcite;
import io.substrait.relation.Project;
import io.substrait.relation.Rel;
import java.io.IOException;
import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgram;
import org.apache.calcite.plan.hep.HepProgramBuilder;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql2rel.RelDecorrelator;
import org.junit.jupiter.api.Test;

class OptimizerIntegrationTest extends PlanTestBase {
Expand Down Expand Up @@ -48,4 +54,62 @@ void conversionHandlesBuiltInSum0CallAddedByRule() throws SqlParseException, IOE
// Conversion of the new plan should succeed
SubstraitRelVisitor.convert(RelRoot.of(newPlan, relRoot.kind), EXTENSION_COLLECTION));
}

/**
 * Regression test for LITERAL_AGG handling in SubstraitRelVisitor.
 *
 * <p>Calcite's SubQueryRemoveRule (CALCITE-6945) rewrites correlated quantified comparisons (e.g.
 * {@code <> SOME}) using {@code LITERAL_AGG(true)} as a null-presence indicator.
 * SubstraitRelVisitor has no Substrait binding for {@code LITERAL_AGG}, so the conversion
 * previously crashed with "UnsupportedOperationException: Unable to find binding for call
 * LITERAL_AGG(true)".
 *
 * <p>The primary check is that the conversion no longer throws. The structural assertions at the
 * end only validate the shape of the final output (a single-column projection); they do not
 * inspect the Project that SubstraitRelVisitor wraps around the sub-query's aggregate.
 *
 * @throws SqlParseException if the test query cannot be parsed
 * @throws IOException declared for parity with the other conversion tests in this class
 * @see <a href="https://github.com/apache/calcite/pull/4296">CALCITE-6945 PR</a>
 */
@Test
void conversionHandlesLiteralAggInsertedBySubQueryRemoveRule()
    throws SqlParseException, IOException {
  // A correlated <> SOME sub-query is expected to take SubQueryRemoveRule's
  // quantified-comparison path, which inserts LITERAL_AGG(true) into the sub-query's aggregate.
  String query =
      "select e1.l_orderkey from lineitem e1 "
          + "where e1.l_quantity <> some ("
          + " select l_quantity from lineitem e2 where e2.l_partkey = e1.l_partkey"
          + ")";

  RelRoot relRoot = SubstraitSqlToCalcite.convertQuery(query, TPCH_CATALOG);

  // Step 1 — SubQueryRemoveRule: rewrites the RexSubQuery in the filter into a correlate whose
  // right-hand side contains the aggregate carrying LITERAL_AGG.
  HepProgram subQueryProgram =
      new HepProgramBuilder().addRuleInstance(CoreRules.FILTER_SUB_QUERY_TO_CORRELATE).build();
  HepPlanner hepPlanner = new HepPlanner(subQueryProgram);
  hepPlanner.setRoot(relRoot.rel);
  RelNode afterSubQueryRemove = hepPlanner.findBestExp();

  // Step 2 — RelDecorrelator: removes the correlate so the plan is expressed with joins; the
  // LITERAL_AGG call is expected to survive inside the aggregate.
  RelNode decorrelated =
      RelDecorrelator.decorrelateQuery(
          afterSubQueryRemove,
          RelFactories.LOGICAL_BUILDER.create(relRoot.rel.getCluster(), null));

  // Core regression check: converting the rewritten plan must not throw.
  io.substrait.plan.Plan.Root planRoot =
      assertDoesNotThrow(
          () ->
              SubstraitRelVisitor.convert(
                  RelRoot.of(decorrelated, relRoot.kind), EXTENSION_COLLECTION));
  Rel result = planRoot.getInput();

  // Sanity-check the final output shape. The root Project is the outer SELECT-list projection
  // (l_orderkey comes from e1); the LITERAL_AGG wrapper Project takes the sub-query's aggregate
  // as its input and therefore sits deeper in the plan, so it is not what is asserted here.
  assertInstanceOf(Project.class, result, "expected a Project at plan root");
  assertEquals(
      1,
      result.getRecordType().fields().size(),
      "output schema should have exactly 1 field (l_orderkey)");
}
}
Loading