Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 14 additions & 21 deletions native/core/src/execution/operators/parquet_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,10 @@ impl ParquetWriter {
pub struct ParquetWriterExec {
/// Input execution plan
input: Arc<dyn ExecutionPlan>,
/// Output file path (final destination)
/// Output directory path (final destination), used for display/debugging
output_path: String,
/// Working directory for temporary files (used by FileCommitProtocol)
work_dir: String,
/// Full file path to write for this Spark task attempt
task_output_path: String,
/// Job ID for tracking this write operation
job_id: Option<String>,
/// Task attempt ID for this specific task
Expand All @@ -221,7 +221,7 @@ impl ParquetWriterExec {
pub fn try_new(
input: Arc<dyn ExecutionPlan>,
output_path: String,
work_dir: String,
task_output_path: String,
job_id: Option<String>,
task_attempt_id: Option<i32>,
compression: CompressionCodec,
Expand All @@ -242,7 +242,7 @@ impl ParquetWriterExec {
Ok(ParquetWriterExec {
input,
output_path,
work_dir,
task_output_path,
job_id,
task_attempt_id,
compression,
Expand Down Expand Up @@ -431,7 +431,7 @@ impl ExecutionPlan for ParquetWriterExec {
1 => Ok(Arc::new(ParquetWriterExec::try_new(
Arc::clone(&children[0]),
self.output_path.clone(),
self.work_dir.clone(),
self.task_output_path.clone(),
self.job_id.clone(),
self.task_attempt_id,
self.compression.clone(),
Expand Down Expand Up @@ -460,8 +460,7 @@ impl ExecutionPlan for ParquetWriterExec {
let runtime_env = context.runtime_env();
let input = self.input.execute(partition, context)?;
let input_schema = self.input.schema();
let work_dir = self.work_dir.clone();
let task_attempt_id = self.task_attempt_id;
let part_file = self.task_output_path.clone();
let compression = self.compression_to_parquet()?;
let column_names = self.column_names.clone();

Expand All @@ -476,16 +475,8 @@ impl ExecutionPlan for ParquetWriterExec {
.collect();
let output_schema = Arc::new(arrow::datatypes::Schema::new(fields));

// Generate part file name for this partition
// If using FileCommitProtocol (work_dir is set), include task_attempt_id in the filename
let part_file = if let Some(attempt_id) = task_attempt_id {
format!(
"{}/part-{:05}-{:05}.parquet",
work_dir, self.partition_id, attempt_id
)
} else {
format!("{}/part-{:05}.parquet", work_dir, self.partition_id)
};
// The path is generated by Spark's FileCommitProtocol. Writing exactly to this
// path allows Spark to commit or abort the task using its standard protocol.

// Configure writer properties
let props = WriterProperties::builder()
Expand Down Expand Up @@ -809,14 +800,16 @@ mod tests {
let memory_exec = Arc::new(DataSourceExec::new(Arc::new(memory_source_config)));

// Create ParquetWriterExec with DataSourceExec as input
let output_path = "unused".to_string();
let work_dir = "hdfs://namenode:9000/user/test_parquet_writer_exec".to_string();
let output_path = "hdfs://namenode:9000/user/test_parquet_writer_exec".to_string();
let task_output_path =
"hdfs://namenode:9000/user/test_parquet_writer_exec/part-00000-c000.parquet"
.to_string();
let column_names = vec!["id".to_string(), "name".to_string()];

let parquet_writer = ParquetWriterExec::try_new(
memory_exec,
output_path,
work_dir,
task_output_path,
None, // job_id
Some(123), // task_attempt_id
CompressionCodec::None,
Expand Down
26 changes: 21 additions & 5 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1662,14 +1662,30 @@ impl PhysicalPlanner {
.map(|(k, v)| (k.clone(), v.clone()))
.collect();

// Resolve the exact file this task writes to: use the task_output_path carried by the
// serialized plan when present, otherwise derive a legacy file name from work_dir.
let task_output_path = writer
.task_output_path
.clone()
.or_else(|| {
// Compatibility with older serialized plans that only provided a work
// directory. New JVM code always sends task_output_path from Spark's
// FileCommitProtocol so native code does not invent filenames.
writer
.work_dir
.as_ref()
// Legacy naming: part-<partition>[-<attempt>].parquet under work_dir, with
// both numbers zero-padded to five digits.
.map(|work_dir| match writer.task_attempt_id {
Some(attempt_id) => format!(
"{}/part-{:05}-{:05}.parquet",
work_dir, self.partition, attempt_id
),
None => format!("{}/part-{:05}.parquet", work_dir, self.partition),
})
})
// NOTE(review): this panics when the plan carries neither task_output_path nor
// work_dir. Consider returning an error instead of panicking; confirm the error
// type of the enclosing function first.
.expect("task_output_path is provided");

let parquet_writer = Arc::new(ParquetWriterExec::try_new(
Arc::clone(&child.native_plan),
writer.output_path.clone(),
writer
.work_dir
.as_ref()
.expect("work_dir is provided")
.clone(),
task_output_path,
writer.job_id.clone(),
writer.task_attempt_id,
codec,
Expand Down
6 changes: 4 additions & 2 deletions native/proto/src/proto/operator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,7 @@ message ParquetWriter {
string output_path = 1;
CompressionCodec compression = 2;
repeated string column_names = 4;
// Working directory for temporary files (used by FileCommitProtocol)
// If not set, files are written directly to output_path
// Working directory for temporary files (legacy path used before task_output_path existed).
optional string work_dir = 5;
// Job ID for tracking this write operation
optional string job_id = 6;
Expand All @@ -361,6 +360,9 @@ message ParquetWriter {
// configuration value "spark.hadoop.fs.s3a.access.key" will be stored as "fs.s3a.access.key" in
// the map.
map<string, string> object_store_options = 8;
// Full temporary output file path returned by Spark's FileCommitProtocol for this task.
// The native writer must write exactly to this path so Spark can commit or abort it.
optional string task_output_path = 9;
}

enum AggregateMode {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,22 @@
package org.apache.comet.serde.operator

import java.net.URI
import java.util.Locale
import java.util.{Locale, UUID}

import scala.jdk.CollectionConverters._

import org.apache.spark.SparkException
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.comet.{CometNativeExec, CometNativeWriteExec}
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.datasources.{InsertIntoHadoopFsRelationCommand, WriteFilesExec}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.SerializableConfiguration

import org.apache.comet.{CometConf, ConfigEntry}
import org.apache.comet.CometSparkSessionExtensions.withFallbackReason
Expand Down Expand Up @@ -132,8 +138,8 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec
.setOutputPath(outputPath)
.setCompression(codec)
.addAllColumnNames(cmd.query.output.map(_.name).asJava)
// Note: work_dir, job_id, and task_attempt_id will be set at execution time
// in CometNativeWriteExec, as they depend on the Spark task context
// The task_output_path is filled in by CometNativeWriteExec at execution time from
// Spark's FileCommitProtocol, because it depends on the Spark task context.

// Collect S3/cloud storage configurations
val session = op.session
Expand Down Expand Up @@ -178,29 +184,35 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec
other
}

// Create FileCommitProtocol for atomic writes
val jobId = java.util.UUID.randomUUID().toString
val committer =
try {
// Use Spark's SQLHadoopMapReduceCommitProtocol
val committerClass =
classOf[org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol]
val constructor =
committerClass.getConstructor(classOf[String], classOf[String], classOf[Boolean])
Some(
constructor
.newInstance(
jobId,
outputPath,
java.lang.Boolean.FALSE // dynamicPartitionOverwrite = false for now
)
.asInstanceOf[org.apache.spark.internal.io.FileCommitProtocol])
} catch {
case e: Exception =>
throw new SparkException(s"Could not instantiate FileCommitProtocol: ${e.getMessage}")
}

CometNativeWriteExec(nativeOp, childPlan, outputPath, committer, jobId)
// Build a Hadoop Job for this write: the write options are merged into a new Hadoop
// configuration, then the output key/value classes and the output path are set on the job.
val session = op.session
val hadoopConf = session.sessionState.newHadoopConfWithOptions(cmd.options)
val job = Job.getInstance(hadoopConf)
job.setOutputKeyClass(classOf[Void])
job.setOutputValueClass(classOf[InternalRow])
FileOutputFormat.setOutputPath(job, new Path(outputPath))

// Let the command's file format prepare the job for the query schema; the returned
// factory is handed to the write exec below.
val outputWriterFactory =
cmd.fileFormat.prepareWrite(session, job, CaseInsensitiveMap(cmd.options), cmd.query.schema)

// Instantiate the commit protocol class configured in the session, under a fresh random
// job id and rooted at outputPath.
val commitProtocolJobId = UUID.randomUUID().toString
val committer = FileCommitProtocol.instantiate(
session.sessionState.conf.fileCommitProtocolClass,
commitProtocolJobId,
outputPath,
// dynamicPartitionOverwrite: dynamic partition overwrite is not enabled here.
false)

// Match Spark's FileFormatWriter behavior: propagate a per-write UUID in the Hadoop
// configuration before it is serialized to executors.
// NOTE(review): this UUID is generated independently of commitProtocolJobId; confirm
// that two distinct ids are intended.
job.getConfiguration.set("spark.sql.sources.writeJobUUID", UUID.randomUUID().toString)

// Bundle the committer, the serializable copy of the job's Hadoop configuration, the
// writer factory and the ids into the config passed to CometNativeWriteExec.
val commitProtocol = CometNativeWriteExec.CommitProtocolConfig(
committer = committer,
serializableHadoopConf = new SerializableConfiguration(job.getConfiguration),
outputWriterFactory = outputWriterFactory,
commitProtocolJobId = commitProtocolJobId,
jobTrackerID = CometNativeWriteExec.newJobTrackerID())

CometNativeWriteExec(nativeOp, childPlan, outputPath, Some(commitProtocol))
}

private def parseCompressionCodec(cmd: InsertIntoHadoopFsRelationCommand) = {
Expand Down
Loading
Loading