Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions docs/source/user-guide/latest/datatypes.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,14 @@ the tables below and may be reconsidered based on demand:

## Interval

Interval types fall back to Spark today. Native acceleration is tracked by
Interval type support is incremental and tracked by
[#4540](https://github.com/apache/datafusion-comet/issues/4540).

| Type | Status | Notes |
| ----------------------- | ------ | ----------------- |
| `YearMonthIntervalType` | 🔜 | Tracked by #4540. |
| `DayTimeIntervalType` | 🔜 | Tracked by #4540. |
| `CalendarIntervalType` | 🔜 | Tracked by #4540. |
| Type | Status | Notes |
| ----------------------- | ------ | ----------------------------------------------------------------------- |
| `YearMonthIntervalType` | | Supported for `make_ym_interval` and YearMonth interval multiplication. |
| `DayTimeIntervalType` | 🔜 | Tracked by #4540. |
| `CalendarIntervalType` | 🔜 | Tracked by #4540. |

## Complex

Expand Down
2 changes: 1 addition & 1 deletion docs/source/user-guide/latest/expressions.md
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ The type-name conversion functions (`bigint`, `binary`, `boolean`, `date`, `deci
| `make_timestamp` | ✅ | |
| `make_timestamp_ltz` | ✅ | 2-arg TIME form falls back |
| `make_timestamp_ntz` | ✅ | 2-arg TIME form falls back |
| `make_ym_interval` | 🔜 | [#4541](https://github.com/apache/datafusion-comet/issues/4541) |
| `make_ym_interval` | | Routes through the JVM codegen dispatcher |
| `minute` | ✅ | |
| `month` | ✅ | |
| `monthname` | ✅ | Abbreviated month name (Spark 4.0+) |
Expand Down
20 changes: 16 additions & 4 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ use crate::execution::{
};
use crate::jvm_bridge::{jni_call, JVMClasses};
use arrow::compute::CastOptions;
use arrow::datatypes::{DataType, Field, FieldRef, Schema, TimeUnit, DECIMAL128_MAX_PRECISION};
use arrow::datatypes::{
DataType, Field, FieldRef, IntervalUnit, Schema, TimeUnit, DECIMAL128_MAX_PRECISION,
};
use arrow::ffi_stream::FFI_ArrowArrayStream;
use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, bit_xor_udaf};
use datafusion::functions_aggregate::count::count_udaf;
Expand Down Expand Up @@ -101,8 +103,8 @@ use datafusion::physical_expr::LexOrdering;
use crate::parquet::parquet_exec::init_datasource_exec;
use arrow::array::{
new_empty_array, Array, ArrayRef, BinaryBuilder, BooleanArray, Date32Array, Decimal128Array,
Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, ListArray,
NullArray, StringBuilder, TimestampMicrosecondArray,
Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
IntervalYearMonthArray, ListArray, NullArray, StringBuilder, TimestampMicrosecondArray,
};
use arrow::buffer::{BooleanBuffer, NullBuffer, OffsetBuffer};
use arrow::row::{OwnedRow, RowConverter, SortField};
Expand Down Expand Up @@ -362,6 +364,9 @@ impl PhysicalPlanner {
DataType::Time64(TimeUnit::Nanosecond) => {
ScalarValue::Time64Nanosecond(None)
}
DataType::Interval(IntervalUnit::YearMonth) => {
ScalarValue::IntervalYearMonth(None)
}
dt => {
return Err(GeneralError(format!("{dt:?} is not supported in Comet")))
}
Expand All @@ -374,9 +379,12 @@ impl PhysicalPlanner {
Value::IntVal(value) => match data_type {
DataType::Int32 => ScalarValue::Int32(Some(*value)),
DataType::Date32 => ScalarValue::Date32(Some(*value)),
DataType::Interval(IntervalUnit::YearMonth) => {
ScalarValue::IntervalYearMonth(Some(*value))
}
dt => {
return Err(GeneralError(format!(
"Expected either 'Int32' or 'Date32' for IntVal, but found {dt:?}"
"Expected either 'Int32', 'Date32', or 'Interval(YearMonth)' for IntVal, but found {dt:?}"
)))
}
},
Expand Down Expand Up @@ -3892,6 +3900,10 @@ fn literal_to_array_ref(
list_literal.int_values.into(),
Some(nulls.clone().into()),
))),
DataType::Interval(IntervalUnit::YearMonth) => Ok(Arc::new(IntervalYearMonthArray::new(
list_literal.int_values.into(),
Some(nulls.clone().into()),
))),
DataType::Timestamp(TimeUnit::Microsecond, None) => {
Ok(Arc::new(TimestampMicrosecondArray::new(
list_literal.long_values.into(),
Expand Down
3 changes: 2 additions & 1 deletion native/core/src/execution/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

use super::operators::ExecutionError;
use crate::errors::ExpressionError;
use arrow::datatypes::{DataType as ArrowDataType, TimeUnit};
use arrow::datatypes::{DataType as ArrowDataType, IntervalUnit, TimeUnit};
use arrow::datatypes::{Field, Fields};
use datafusion_comet_proto::{
spark_config, spark_expression,
Expand Down Expand Up @@ -97,6 +97,7 @@ pub fn to_arrow_datatype(dt_value: &DataType) -> ArrowDataType {
DataTypeId::TimestampNtz => ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
DataTypeId::Date => ArrowDataType::Date32,
DataTypeId::Time => ArrowDataType::Time64(TimeUnit::Nanosecond),
DataTypeId::YearMonthInterval => ArrowDataType::Interval(IntervalUnit::YearMonth),
DataTypeId::Null => ArrowDataType::Null,
DataTypeId::List => match dt_value
.type_info
Expand Down
1 change: 1 addition & 0 deletions native/proto/src/proto/types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ message DataType {
MAP = 15;
STRUCT = 16;
TIME = 17;
YEAR_MONTH_INTERVAL = 18;
}
DataTypeId type_id = 1;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
case "TinyIntVector" => classOf[TinyIntVector]
case "SmallIntVector" => classOf[SmallIntVector]
case "IntVector" => classOf[IntVector]
case "IntervalYearVector" => classOf[IntervalYearVector]
case "BigIntVector" => classOf[BigIntVector]
case "Float4Vector" => classOf[Float4Vector]
case "Float8Vector" => classOf[Float8Vector]
Expand All @@ -82,6 +83,7 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
case BooleanType | ByteType | ShortType | IntegerType | LongType => true
case FloatType | DoubleType => true
case _: DecimalType => true
case _: YearMonthIntervalType => true
case _: StringType | _: BinaryType => true
case DateType | TimestampType | TimestampNTZType => true
case ArrayType(inner, _) => isSupportedDataType(inner)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ private[codegen] object CometBatchKernelCodegenInput {
classOf[TinyIntVector],
classOf[SmallIntVector],
classOf[IntVector],
classOf[IntervalYearVector],
classOf[BigIntVector],
classOf[Float4Vector],
classOf[Float8Vector],
Expand Down Expand Up @@ -127,7 +128,9 @@ private[codegen] object CometBatchKernelCodegenInput {
}
val intCases = withOrd.collect {
case (ArrowColumnSpec(cls, _), ord)
if cls == classOf[IntVector] || cls == classOf[DateDayVector] =>
if cls == classOf[IntVector] ||
cls == classOf[DateDayVector] ||
cls == classOf[IntervalYearVector] =>
s" case $ord: return this.col$ord.getInt(this.rowIdx);"
}
val longCases = withOrd.collect {
Expand Down Expand Up @@ -590,7 +593,7 @@ private[codegen] object CometBatchKernelCodegenInput {
case BooleanType => s"getBoolean($idx)"
case ByteType => s"getByte($idx)"
case ShortType => s"getShort($idx)"
case IntegerType | DateType => s"getInt($idx)"
case IntegerType | DateType | _: YearMonthIntervalType => s"getInt($idx)"
case LongType | TimestampType | TimestampNTZType => s"getLong($idx)"
case FloatType => s"getFloat($idx)"
case DoubleType => s"getDouble($idx)"
Expand Down Expand Up @@ -687,7 +690,7 @@ private[codegen] object CometBatchKernelCodegenInput {
| public short getShort(int i) {
| return $childField.getShort(startIndex + i);
| }""".stripMargin
case IntegerType | DateType =>
case IntegerType | DateType | _: YearMonthIntervalType =>
s""" @Override
| public int getInt(int i) {
| return $childField.getInt(startIndex + i);
Expand Down Expand Up @@ -843,7 +846,7 @@ private[codegen] object CometBatchKernelCodegenInput {
s" case $fi: return ${path}_f$fi.getByte(this.rowIdx);"
case ShortType =>
s" case $fi: return ${path}_f$fi.getShort(this.rowIdx);"
case IntegerType | DateType =>
case IntegerType | DateType | _: YearMonthIntervalType =>
s" case $fi: return ${path}_f$fi.getInt(this.rowIdx);"
case LongType | TimestampType | TimestampNTZType =>
s" case $fi: return ${path}_f$fi.getLong(this.rowIdx);"
Expand Down Expand Up @@ -891,8 +894,11 @@ private[codegen] object CometBatchKernelCodegenInput {
fieldReadScalar(fi, ShortType, f.nullable)
}
val intCases = scalarOrd.collect {
case (f, fi) if f.sparkType == IntegerType || f.sparkType == DateType =>
fieldReadScalar(fi, IntegerType, f.nullable)
case (f, fi)
if f.sparkType == IntegerType ||
f.sparkType == DateType ||
f.sparkType.isInstanceOf[YearMonthIntervalType] =>
fieldReadScalar(fi, f.sparkType, f.nullable)
}
val longCases = scalarOrd.collect {
case (f, fi)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ private[codegen] object CometBatchKernelCodegenOutput {
case ByteType => classOf[TinyIntVector].getName
case ShortType => classOf[SmallIntVector].getName
case IntegerType => classOf[IntVector].getName
case _: YearMonthIntervalType => classOf[IntervalYearVector].getName
case LongType => classOf[BigIntVector].getName
case FloatType => classOf[Float4Vector].getName
case DoubleType => classOf[Float8Vector].getName
Expand Down Expand Up @@ -208,7 +209,7 @@ private[codegen] object CometBatchKernelCodegenOutput {
val set = if (nested) "setSafe" else "set"
OutputEmit("", s"$targetVec.$set($idx, $source ? 1 : 0);")
case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | DateType |
TimestampType | TimestampNTZType =>
TimestampType | TimestampNTZType | _: YearMonthIntervalType =>
// Spark codegen emits the matching primitive Java type; Arrow `set` overloads accept it.
val set = if (nested) "setSafe" else "set"
OutputEmit("", s"$targetVec.$set($idx, $source);")
Expand Down Expand Up @@ -392,7 +393,7 @@ private[codegen] object CometBatchKernelCodegenOutput {
case BooleanType => s"$target.getBoolean($idx)"
case ByteType => s"$target.getByte($idx)"
case ShortType => s"$target.getShort($idx)"
case IntegerType | DateType => s"$target.getInt($idx)"
case IntegerType | DateType | _: YearMonthIntervalType => s"$target.getInt($idx)"
case LongType | TimestampType | TimestampNTZType => s"$target.getLong($idx)"
case FloatType => s"$target.getFloat($idx)"
case DoubleType => s"$target.getDouble($idx)"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,11 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
classOf[Hour] -> CometHour,
classOf[MakeDate] -> CometMakeDate,
classOf[MakeTimestamp] -> CometMakeTimestamp,
classOf[MakeYMInterval] -> CometMakeYMInterval,
classOf[MicrosToTimestamp] -> CometMicrosToTimestamp,
classOf[MillisToTimestamp] -> CometMillisToTimestamp,
classOf[MonthsBetween] -> CometMonthsBetween,
classOf[MultiplyYMInterval] -> CometMultiplyYMInterval,
classOf[Minute] -> CometMinute,
classOf[NextDay] -> CometNextDay,
classOf[Second] -> CometSecond,
Expand Down Expand Up @@ -478,7 +480,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
def supportedDataType(dt: DataType, allowComplex: Boolean = false): Boolean = dt match {
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
_: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: TimestampNTZType |
_: DecimalType | _: DateType | _: BooleanType | _: NullType =>
_: DecimalType | _: DateType | _: BooleanType | _: NullType | _: YearMonthIntervalType =>
true
case dt if isTimeType(dt) =>
true
Expand Down Expand Up @@ -517,6 +519,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
case _: MapType => 15
case _: StructType => 16
case dt if isTimeType(dt) => 17
case _: YearMonthIntervalType => 18
case dt =>
logWarning(s"Cannot serialize Spark data type: $dt")
return None
Expand Down
6 changes: 5 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/datetime.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ package org.apache.comet.serde

import java.util.Locale

import org.apache.spark.sql.catalyst.expressions.{AddMonths, Attribute, ConvertTimezone, DateAdd, DateDiff, DateFormatClass, DateFromUnixDate, DateSub, DayOfMonth, DayOfWeek, DayOfYear, Days, FromUTCTimestamp, GetDateField, GetTimestamp, Hour, Hours, LastDay, Literal, MakeDate, MakeTimestamp, MicrosToTimestamp, MillisToTimestamp, Minute, Month, MonthsBetween, NextDay, Quarter, Second, SecondsToTimestamp, ToUnixTimestamp, ToUTCTimestamp, TruncDate, TruncTimestamp, UnixDate, UnixMicros, UnixMillis, UnixSeconds, UnixTimestamp, WeekDay, WeekOfYear, Year}
import org.apache.spark.sql.catalyst.expressions.{AddMonths, Attribute, ConvertTimezone, DateAdd, DateDiff, DateFormatClass, DateFromUnixDate, DateSub, DayOfMonth, DayOfWeek, DayOfYear, Days, FromUTCTimestamp, GetDateField, GetTimestamp, Hour, Hours, LastDay, Literal, MakeDate, MakeTimestamp, MakeYMInterval, MicrosToTimestamp, MillisToTimestamp, Minute, Month, MonthsBetween, MultiplyYMInterval, NextDay, Quarter, Second, SecondsToTimestamp, ToUnixTimestamp, ToUTCTimestamp, TruncDate, TruncTimestamp, UnixDate, UnixMicros, UnixMillis, UnixSeconds, UnixTimestamp, WeekDay, WeekOfYear, Year}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, DateType, DoubleType, FloatType, IntegerType, LongType, StringType, TimestampNTZType, TimestampType}
import org.apache.spark.unsafe.types.UTF8String
Expand Down Expand Up @@ -870,6 +870,10 @@ object CometMonthsBetween extends CometCodegenDispatch[MonthsBetween]

/**
 * Serde entry for Spark's `MakeTimestamp` expression. The object declares no members of its
 * own; everything it does is inherited from `CometCodegenDispatch`.
 */
object CometMakeTimestamp extends CometCodegenDispatch[MakeTimestamp]

/**
 * Serde entry for Spark's `MakeYMInterval` expression (SQL `make_ym_interval`). The object
 * declares no members of its own; everything it does is inherited from `CometCodegenDispatch`.
 */
object CometMakeYMInterval extends CometCodegenDispatch[MakeYMInterval]

/**
 * Serde entry for Spark's `MultiplyYMInterval` expression (year-month interval
 * multiplication). The object declares no members of its own; everything it does is inherited
 * from `CometCodegenDispatch`.
 */
object CometMultiplyYMInterval extends CometCodegenDispatch[MultiplyYMInterval]

/**
 * Serde entry for Spark's `MicrosToTimestamp` expression. The object declares no members of
 * its own; everything it does is inherited from `CometCodegenDispatch`.
 */
object CometMicrosToTimestamp extends CometCodegenDispatch[MicrosToTimestamp]

/**
 * Serde entry for Spark's `MillisToTimestamp` expression. The object declares no members of
 * its own; everything it does is inherited from `CometCodegenDispatch`.
 */
object CometMillisToTimestamp extends CometCodegenDispatch[MillisToTimestamp]
Expand Down
7 changes: 4 additions & 3 deletions spark/src/main/scala/org/apache/comet/serde/literals.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import java.lang
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DateType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, NullType, ShortType, StringType, TimestampNTZType, TimestampType}
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DateType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, NullType, ShortType, StringType, TimestampNTZType, TimestampType, YearMonthIntervalType}
import org.apache.spark.unsafe.types.UTF8String

import com.google.protobuf.ByteString
Expand Down Expand Up @@ -77,7 +77,8 @@ object CometLiteral extends CometExpressionSerde[Literal] with Logging {
case _: BooleanType => exprBuilder.setBoolVal(value.asInstanceOf[Boolean])
case _: ByteType => exprBuilder.setByteVal(value.asInstanceOf[Byte])
case _: ShortType => exprBuilder.setShortVal(value.asInstanceOf[Short])
case _: IntegerType | _: DateType => exprBuilder.setIntVal(value.asInstanceOf[Int])
case _: IntegerType | _: DateType | _: YearMonthIntervalType =>
exprBuilder.setIntVal(value.asInstanceOf[Int])
case _: LongType | _: TimestampType | _: TimestampNTZType =>
exprBuilder.setLongVal(value.asInstanceOf[Long])
case dt if isTimeType(dt) =>
Expand Down Expand Up @@ -150,7 +151,7 @@ object CometLiteral extends CometExpressionSerde[Literal] with Logging {
else null.asInstanceOf[Integer])
listLiteralBuilder.addNullMask(casted != null)
})
case IntegerType | DateType =>
case IntegerType | DateType | _: YearMonthIntervalType =>
array.foreach(v => {
val casted = v.asInstanceOf[Integer]
listLiteralBuilder.addIntValues(casted)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,10 @@ class CometScalaUDFCodegen extends CometUDF with Logging {
child = specFor(childVec))
}
StructColumnSpec(nullable = true, fieldSpecs)
case _: BitVector | _: TinyIntVector | _: SmallIntVector | _: IntVector | _: BigIntVector |
_: Float4Vector | _: Float8Vector | _: DecimalVector | _: VarCharVector |
_: VarBinaryVector | _: DateDayVector | _: TimeStampMicroVector |
_: TimeStampMicroTZVector =>
case _: BitVector | _: TinyIntVector | _: SmallIntVector | _: IntVector |
_: IntervalYearVector | _: BigIntVector | _: Float4Vector | _: Float8Vector |
_: DecimalVector | _: VarCharVector | _: VarBinaryVector | _: DateDayVector |
_: TimeStampMicroVector | _: TimeStampMicroTZVector =>
ScalarColumnSpec(v.getClass.asInstanceOf[Class[_ <: ValueVector]], nullable = true)
case other =>
throw new UnsupportedOperationException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ object Utils extends CometTypeShim with Logging {
case NullType => ArrowType.Null.INSTANCE
case dt if isTimeType(dt) =>
new ArrowType.Time(TimeUnit.NANOSECOND, 64)
case _: YearMonthIntervalType =>
new ArrowType.Interval(IntervalUnit.YEAR_MONTH)
case _ =>
throw new UnsupportedOperationException(
s"Unsupported data type: [${dt.getClass.getName}] ${dt.catalogString}")
Expand Down
Loading
Loading