From 3d58fa8130119bac8549dd99593faeb48546e332 Mon Sep 17 00:00:00 2001 From: peterxcli Date: Fri, 26 Jun 2026 19:56:14 +0800 Subject: [PATCH] feat(parquet): use Spark commit protocol for native writes Native Parquet writes now write to the task temp file returned by Spark's FileCommitProtocol, then commit or abort tasks through Spark's standard lifecycle. Pass the task output path through the native Parquet writer proto, wire commit protocol setup in CometNativeWriteExec, and assert committed output naming/cleanup in CometParquetWriterSuite. --- .../src/execution/operators/parquet_writer.rs | 35 +- native/core/src/execution/planner.rs | 26 +- native/proto/src/proto/operator.proto | 6 +- .../operator/CometDataWritingCommand.scala | 66 +-- .../sql/comet/CometNativeWriteExec.scala | 437 ++++++++++-------- .../parquet/CometParquetWriterSuite.scala | 23 + 6 files changed, 337 insertions(+), 256 deletions(-) diff --git a/native/core/src/execution/operators/parquet_writer.rs b/native/core/src/execution/operators/parquet_writer.rs index f1168c4a57..2777520abe 100644 --- a/native/core/src/execution/operators/parquet_writer.rs +++ b/native/core/src/execution/operators/parquet_writer.rs @@ -193,10 +193,10 @@ impl ParquetWriter { pub struct ParquetWriterExec { /// Input execution plan input: Arc, - /// Output file path (final destination) + /// Output directory path (final destination), used for display/debugging output_path: String, - /// Working directory for temporary files (used by FileCommitProtocol) - work_dir: String, + /// Full file path to write for this Spark task attempt + task_output_path: String, /// Job ID for tracking this write operation job_id: Option, /// Task attempt ID for this specific task @@ -221,7 +221,7 @@ impl ParquetWriterExec { pub fn try_new( input: Arc, output_path: String, - work_dir: String, + task_output_path: String, job_id: Option, task_attempt_id: Option, compression: CompressionCodec, @@ -242,7 +242,7 @@ impl ParquetWriterExec { Ok(ParquetWriterExec { input, output_path, - work_dir, + task_output_path, job_id, task_attempt_id, compression, @@ -431,7 +431,7 @@ impl ExecutionPlan for ParquetWriterExec { 1 => Ok(Arc::new(ParquetWriterExec::try_new( Arc::clone(&children[0]), self.output_path.clone(), - self.work_dir.clone(), + self.task_output_path.clone(), self.job_id.clone(), self.task_attempt_id, self.compression.clone(), @@ -460,8 +460,7 @@ impl ExecutionPlan for ParquetWriterExec { let runtime_env = context.runtime_env(); let input = self.input.execute(partition, context)?; let input_schema = self.input.schema(); - let work_dir = self.work_dir.clone(); - let task_attempt_id = self.task_attempt_id; + let part_file = self.task_output_path.clone(); let compression = self.compression_to_parquet()?; let column_names = self.column_names.clone(); @@ -476,16 +475,8 @@ impl ExecutionPlan for ParquetWriterExec { .collect(); let output_schema = Arc::new(arrow::datatypes::Schema::new(fields)); - // Generate part file name for this partition - // If using FileCommitProtocol (work_dir is set), include task_attempt_id in the filename - let part_file = if let Some(attempt_id) = task_attempt_id { - format!( - "{}/part-{:05}-{:05}.parquet", - work_dir, self.partition_id, attempt_id - ) - } else { - format!("{}/part-{:05}.parquet", work_dir, self.partition_id) - }; + // The path is generated by Spark's FileCommitProtocol. Writing exactly to this + // path allows Spark to commit or abort the task using its standard protocol. // Configure writer properties let props = WriterProperties::builder() @@ -809,14 +800,16 @@ mod tests { let memory_exec = Arc::new(DataSourceExec::new(Arc::new(memory_source_config))); // Create ParquetWriterExec with DataSourceExec as input - let output_path = "unused".to_string(); - let work_dir = "hdfs://namenode:9000/user/test_parquet_writer_exec".to_string(); + let output_path = "hdfs://namenode:9000/user/test_parquet_writer_exec".to_string(); + let task_output_path = + "hdfs://namenode:9000/user/test_parquet_writer_exec/part-00000-c000.parquet" + .to_string(); let column_names = vec!["id".to_string(), "name".to_string()]; let parquet_writer = ParquetWriterExec::try_new( memory_exec, output_path, - work_dir, + task_output_path, None, // job_id Some(123), // task_attempt_id CompressionCodec::None, diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 553fe5215c..2168090e00 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1662,14 +1662,30 @@ impl PhysicalPlanner { .map(|(k, v)| (k.clone(), v.clone())) .collect(); + let task_output_path = writer + .task_output_path + .clone() + .or_else(|| { + // Compatibility with older serialized plans that only provided a work + // directory. New JVM code always sends task_output_path from Spark's + // FileCommitProtocol so native code does not invent filenames. + writer + .work_dir + .as_ref() + .map(|work_dir| match writer.task_attempt_id { + Some(attempt_id) => format!( + "{}/part-{:05}-{:05}.parquet", + work_dir, self.partition, attempt_id + ), + None => format!("{}/part-{:05}.parquet", work_dir, self.partition), + }) + }) + .expect("task_output_path is provided"); + let parquet_writer = Arc::new(ParquetWriterExec::try_new( Arc::clone(&child.native_plan), writer.output_path.clone(), - writer - .work_dir - .as_ref() - .expect("work_dir is provided") - .clone(), + task_output_path, writer.job_id.clone(), writer.task_attempt_id, codec, diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 9d81d2853b..d957238b45 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -347,8 +347,7 @@ message ParquetWriter { string output_path = 1; CompressionCodec compression = 2; repeated string column_names = 4; - // Working directory for temporary files (used by FileCommitProtocol) - // If not set, files are written directly to output_path + // Working directory for temporary files (legacy path used before task_output_path existed). optional string work_dir = 5; // Job ID for tracking this write operation optional string job_id = 6; @@ -361,6 +360,9 @@ message ParquetWriter { // configuration value "spark.hadoop.fs.s3a.access.key" will be stored as "fs.s3a.access.key" in // the map. map object_store_options = 8; + // Full temporary output file path returned by Spark's FileCommitProtocol for this task. + // The native writer must write exactly to this path so Spark can commit or abort it. + optional string task_output_path = 9; } enum AggregateMode { diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala index 4ae73565c6..0f1e4c6057 100644 --- a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala +++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala @@ -20,16 +20,22 @@ package org.apache.comet.serde.operator import java.net.URI -import java.util.Locale +import java.util.{Locale, UUID} import scala.jdk.CollectionConverters._ -import org.apache.spark.SparkException +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat +import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.comet.{CometNativeExec, CometNativeWriteExec} import org.apache.spark.sql.execution.command.DataWritingCommandExec import org.apache.spark.sql.execution.datasources.{InsertIntoHadoopFsRelationCommand, WriteFilesExec} import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.SerializableConfiguration import org.apache.comet.{CometConf, ConfigEntry} import org.apache.comet.CometSparkSessionExtensions.withFallbackReason @@ -132,8 +138,8 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec .setOutputPath(outputPath) .setCompression(codec) .addAllColumnNames(cmd.query.output.map(_.name).asJava) - // Note: work_dir, job_id, and task_attempt_id will be set at execution time - // in CometNativeWriteExec, as they depend on the Spark task context + // The task_output_path is filled in by CometNativeWriteExec at execution time from + // Spark's FileCommitProtocol, because it depends on the Spark task context. // Collect S3/cloud storage configurations val session = op.session @@ -178,29 +184,35 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec other } - // Create FileCommitProtocol for atomic writes - val jobId = java.util.UUID.randomUUID().toString - val committer = - try { - // Use Spark's SQLHadoopMapReduceCommitProtocol - val committerClass = - classOf[org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol] - val constructor = - committerClass.getConstructor(classOf[String], classOf[String], classOf[Boolean]) - Some( - constructor - .newInstance( - jobId, - outputPath, - java.lang.Boolean.FALSE // dynamicPartitionOverwrite = false for now - ) - .asInstanceOf[org.apache.spark.internal.io.FileCommitProtocol]) - } catch { - case e: Exception => - throw new SparkException(s"Could not instantiate FileCommitProtocol: ${e.getMessage}") - } - - CometNativeWriteExec(nativeOp, childPlan, outputPath, committer, jobId) + val session = op.session + val hadoopConf = session.sessionState.newHadoopConfWithOptions(cmd.options) + val job = Job.getInstance(hadoopConf) + job.setOutputKeyClass(classOf[Void]) + job.setOutputValueClass(classOf[InternalRow]) + FileOutputFormat.setOutputPath(job, new Path(outputPath)) + + val outputWriterFactory = + cmd.fileFormat.prepareWrite(session, job, CaseInsensitiveMap(cmd.options), cmd.query.schema) + + val commitProtocolJobId = UUID.randomUUID().toString + val committer = FileCommitProtocol.instantiate( + session.sessionState.conf.fileCommitProtocolClass, + commitProtocolJobId, + outputPath, + false) + + // Match Spark's FileFormatWriter behavior: propagate a per-write UUID in the Hadoop + // configuration before it is serialized to executors. + job.getConfiguration.set("spark.sql.sources.writeJobUUID", UUID.randomUUID().toString) + + val commitProtocol = CometNativeWriteExec.CommitProtocolConfig( + committer = committer, + serializableHadoopConf = new SerializableConfiguration(job.getConfiguration), + outputWriterFactory = outputWriterFactory, + commitProtocolJobId = commitProtocolJobId, + jobTrackerID = CometNativeWriteExec.newJobTrackerID()) + + CometNativeWriteExec(nativeOp, childPlan, outputPath, Some(commitProtocol)) } private def parseCompressionCodec(cmd: InsertIntoHadoopFsRelationCommand) = { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala index 4cfdd11361..aa86a824f5 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala @@ -19,72 +19,77 @@ package org.apache.spark.sql.comet -import scala.jdk.CollectionConverters._ +import java.util.Date -import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext, TaskAttemptID, TaskID, TaskType} import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.spark.TaskContext -import org.apache.spark.internal.io.{FileCommitProtocol, FileNameSpec} +import org.apache.spark.internal.io.{FileCommitProtocol, FileNameSpec, SparkHadoopWriterUtils} +import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.comet.CometNativeWriteExec.CommitProtocolConfig import org.apache.spark.sql.comet.execution.arrow.CometArrowStream import org.apache.spark.sql.comet.util.{Utils => CometUtils} import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.datasources.OutputWriterFactory import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.Utils +import org.apache.spark.util.SerializableConfiguration import com.google.protobuf.CodedOutputStream import org.apache.comet.CometExecIterator import org.apache.comet.serde.OperatorOuterClass.Operator +object CometNativeWriteExec { + + def newJobTrackerID(): String = SparkHadoopWriterUtils.createJobTrackerID(new Date()) + + /** + * Driver-created objects required to use Spark's FileCommitProtocol from native write tasks. + * The committer instance is serializable by contract and is sent to executors, while the same + * driver-side instance receives task commit messages and commits or aborts the job. + */ + case class CommitProtocolConfig( + committer: FileCommitProtocol, + serializableHadoopConf: SerializableConfiguration, + outputWriterFactory: OutputWriterFactory, + commitProtocolJobId: String, + jobTrackerID: String) + extends Serializable +} + /** - * Comet physical operator for native Parquet write operations with FileCommitProtocol support. - * - * This operator writes data to Parquet files using the native Comet engine. It integrates with - * Spark's FileCommitProtocol to provide atomic writes with proper staging and commit semantics. + * Comet physical operator for native Parquet write operations. * - * The implementation includes support for Spark's file commit protocol through work_dir, job_id, - * and task_attempt_id parameters that can be set in the operator. When work_dir is set, files are - * written to a temporary location that can be atomically committed later. + * When [[commitProtocol]] is present, this operator follows Spark's FileCommitProtocol lifecycle: + * driver setupJob, executor setupTask/newTaskTempFile/commitTask or abortTask, and driver + * commitJob or abortJob. Native code writes exactly to the task temp file returned by Spark's + * commit protocol; it does not generate its own final part-file names. * * @param nativeOp - * The native operator representing the write operation (template, will be modified per task) + * The native operator representing the write operation (template, modified per task) * @param child * The child operator providing the data to write * @param outputPath - * The path where the Parquet file will be written - * @param committer - * FileCommitProtocol for atomic writes. If None, files are written directly. - * @param jobTrackerID - * Unique identifier for this write job + * The final output directory for the write + * @param commitProtocol + * Spark file commit protocol state. If absent, files are written directly under outputPath. */ case class CometNativeWriteExec( nativeOp: Operator, child: SparkPlan, outputPath: String, - committer: Option[FileCommitProtocol] = None, - jobTrackerID: String = Utils.createTempDir().getName) + commitProtocol: Option[CommitProtocolConfig] = None) extends CometNativeExec with UnaryExecNode { override def originalPlan: SparkPlan = child - // Accumulator to collect TaskCommitMessages from all tasks - // Must be eagerly initialized on driver, not lazy - @transient private val taskCommitMessagesAccum = - sparkContext.collectionAccumulator[FileCommitProtocol.TaskCommitMessage]("taskCommitMessages") - - override def serializedPlanOpt: SerializedPlan = { - val size = nativeOp.getSerializedSize - val bytes = new Array[Byte](size) - val codedOutput = CodedOutputStream.newInstance(bytes) - nativeOp.writeTo(codedOutput) - codedOutput.checkNoSpaceLeft() - SerializedPlan(Some(bytes)) - } + override def serializedPlanOpt: SerializedPlan = SerializedPlan( + Some(serializeNativeOp(nativeOp))) override def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) @@ -97,53 +102,72 @@ case class CometNativeWriteExec( "rows_written" -> SQLMetrics.createMetric(sparkContext, "number of written rows")) override def doExecute(): RDD[InternalRow] = { - // Setup job if committer is present - committer.foreach { c => - val jobContext = createJobContext() - c.setupJob(jobContext) - } + executeWriteAndCommit() + // Write operations do not return rows. + sparkContext.emptyRDD[InternalRow] + } - // Execute the native write with commit protocol - val resultRDD = doExecuteColumnar() + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + executeWriteAndCommit() + // Write operations do not return columnar batches. Spark may still ask for columnar output + // while this operator is nested under write planning nodes, so run the terminal write here too. + sparkContext.emptyRDD[ColumnarBatch] + } - // Force execution by consuming all batches - resultRDD - .mapPartitions { iter => - iter.foreach(_.close()) - Iterator.empty - } - .count() - - // Extract write statistics from metrics - val filesWritten = metrics("files_written").value - val bytesWritten = metrics("bytes_written").value - val rowsWritten = metrics("rows_written").value - - // Collect TaskCommitMessages from accumulator - val commitMessages = taskCommitMessagesAccum.value.asScala.toSeq - - // Commit job with collected TaskCommitMessages - committer.foreach { c => - val jobContext = createJobContext() - try { - c.commitJob(jobContext, commitMessages) - logInfo( - s"Successfully committed write job to $outputPath: " + - s"$filesWritten files, $bytesWritten bytes, $rowsWritten rows") - } catch { - case e: Exception => - logError("Failed to commit job, aborting", e) - c.abortJob(jobContext) - throw e - } + private def executeWriteAndCommit(): Unit = { + commitProtocol match { + case Some(protocol) => + val jobContext = createJobContext(protocol) + + // Match Spark's FileFormatWriter lifecycle: setupJob is outside the try block because it + // only initializes the job; failures after this point should abort the job. + protocol.committer.setupJob(jobContext) + + try { + val commitMessages = runNativeWriteJob(protocol) + protocol.committer.commitJob(jobContext, commitMessages.toSeq) + + val filesWritten = metrics("files_written").value + val bytesWritten = metrics("bytes_written").value + val rowsWritten = metrics("rows_written").value + logInfo( + s"Successfully committed native write job to $outputPath: " + + s"$filesWritten files, $bytesWritten bytes, $rowsWritten rows") + } catch { + case t: Throwable => + abortJob(protocol, jobContext, t) + throw t + } + + case None => + // Direct-write fallback for tests or callers that do not provide a commit protocol. + nativeWriteTasks(None).count() } + } - // Return empty RDD as write operations don't return data - sparkContext.emptyRDD[InternalRow] + private def runNativeWriteJob(protocol: CommitProtocolConfig): Array[TaskCommitMessage] = { + val writeRDD = nativeWriteTasks(Some(protocol)) + val ret = new Array[TaskCommitMessage](writeRDD.partitions.length) + + sparkContext.runJob( + writeRDD, + (_: TaskContext, iter: Iterator[TaskCommitMessage]) => { + assert(iter.hasNext, "Native write task did not return a commit message") + val commitMessage = iter.next() + assert(!iter.hasNext, "Native write task returned more than one commit message") + commitMessage + }, + writeRDD.partitions.indices, + (index, commitMessage: TaskCommitMessage) => { + protocol.committer.onTaskCommit(commitMessage) + ret(index) = commitMessage + }) + + ret } - override def doExecuteColumnar(): RDD[ColumnarBatch] = { - // Get the input data from the child operator + private def nativeWriteTasks( + capturedCommitProtocol: Option[CommitProtocolConfig]): RDD[TaskCommitMessage] = { val childRDD = if (child.supportsColumnar) { child.executeColumnar() } else { @@ -156,137 +180,114 @@ case class CometNativeWriteExec( } } - // Capture metadata before the transformation val numPartitions = childRDD.getNumPartitions val numOutputCols = child.output.length - val capturedCommitter = committer - val capturedJobTrackerID = jobTrackerID val capturedNativeOp = nativeOp - val capturedAccumulator = taskCommitMessagesAccum // Capture accumulator for use in tasks + val writeExec = this - // Execute native write operation with task-level commit protocol childRDD.mapPartitionsInternal { iter => - val partitionId = org.apache.spark.TaskContext.getPartitionId() - val taskAttemptId = org.apache.spark.TaskContext.get().taskAttemptId() + val sparkTaskContext = TaskContext.get() + val partitionId = sparkTaskContext.partitionId() + val sparkStageId = sparkTaskContext.stageId() + val taskAttemptId = sparkTaskContext.taskAttemptId() + val sparkAttemptNumber = taskAttemptId.toInt & Integer.MAX_VALUE - // Setup task-level commit protocol if provided - val (workDir, taskContext, commitMsg) = capturedCommitter - .map { committer => - val taskContext = - createTaskContext(capturedJobTrackerID, partitionId, taskAttemptId.toInt) + new Iterator[TaskCommitMessage] { + private var emitted = false - // Setup task - this creates the temporary working directory - committer.setupTask(taskContext) + override def hasNext: Boolean = !emitted - // Get the work directory for temp files - // Spark 4.1 made the (taskContext, dir, ext: String) overload throw by default; - // the FileNameSpec overload is the supported one and exists in 3.4+. - val workPath = committer.newTaskTempFile(taskContext, None, FileNameSpec("", "")) - val workDir = new Path(workPath).getParent.toString + override def next(): TaskCommitMessage = { + if (emitted) { + throw new NoSuchElementException("Native write task already completed") + } + emitted = true - (Some(workDir), Some((committer, taskContext)), null) - } - .getOrElse((None, None, null)) - - // Modify the native operator to include task-specific parameters - val modifiedNativeOp = if (workDir.isDefined) { - val parquetWriter = capturedNativeOp.getParquetWriter.toBuilder - .setWorkDir(workDir.get) - .setJobId(capturedJobTrackerID) - .setTaskAttemptId(taskAttemptId.toInt) - .build() - - capturedNativeOp.toBuilder.setParquetWriter(parquetWriter).build() - } else { - capturedNativeOp - } + var taskCommitter: Option[(FileCommitProtocol, TaskAttemptContext)] = None + var execIterator: CometExecIterator = null - val nativeMetrics = CometMetricNode.fromCometPlan(this) - // Register before CometExecIterator so completion listeners run after iterator close - // (Spark runs task completion callbacks in reverse registration order). - Option(TaskContext.get()).foreach(nativeMetrics.reportNativeWriteOutputMetrics) - - val size = modifiedNativeOp.getSerializedSize - val planBytes = new Array[Byte](size) - val codedOutput = CodedOutputStream.newInstance(planBytes) - modifiedNativeOp.writeTo(codedOutput) - codedOutput.checkNoSpaceLeft() - - val execIterator = new CometExecIterator( - CometExec.newIterId, - CometArrowStream.inputObjects( - iter, - CometUtils.fromAttributes(child.output), - "CometNativeWriteExec"), - numOutputCols, - planBytes, - nativeMetrics, - numPartitions, - partitionId, - None, - Seq.empty) - - // Wrap the iterator to handle task commit/abort and capture TaskCommitMessage - new Iterator[ColumnarBatch] { - private var completed = false - private var thrownException: Option[Throwable] = None - - override def hasNext: Boolean = { - val result = - try { - execIterator.hasNext - } catch { - case e: Throwable => - thrownException = Some(e) - handleTaskEnd() - throw e + try { + val taskOutputPath = capturedCommitProtocol match { + case Some(protocol) => + val taskAttemptContext = + createTaskContext(protocol, sparkStageId, partitionId, sparkAttemptNumber) + protocol.committer.setupTask(taskAttemptContext) + taskCommitter = Some(protocol.committer -> taskAttemptContext) + + val fileExtension = + protocol.outputWriterFactory.getFileExtension(taskAttemptContext) + protocol.committer.newTaskTempFile( + taskAttemptContext, + None, + FileNameSpec("", "-c000" + fileExtension)) + + case None => + directTaskOutputPath(partitionId) } - if (!result && !completed) { - handleTaskEnd() - } + val parquetWriter = capturedNativeOp.getParquetWriter.toBuilder + .setTaskOutputPath(taskOutputPath) + .setTaskAttemptId(sparkAttemptNumber) - result - } + capturedCommitProtocol.foreach { protocol => + parquetWriter.setJobId(protocol.commitProtocolJobId) + } - override def next(): ColumnarBatch = { - try { - execIterator.next() - } catch { - case e: Throwable => - thrownException = Some(e) - handleTaskEnd() - throw e - } - } + val modifiedNativeOp = capturedNativeOp.toBuilder + .setParquetWriter(parquetWriter.build()) + .build() + + val nativeMetrics = CometMetricNode.fromCometPlan(writeExec) + // Register before CometExecIterator so completion listeners run after iterator close + // (Spark runs task completion callbacks in reverse registration order). + Option(TaskContext.get()).foreach(nativeMetrics.reportNativeWriteOutputMetrics) + + execIterator = new CometExecIterator( + CometExec.newIterId, + CometArrowStream.inputObjects( + iter, + CometUtils.fromAttributes(child.output), + "CometNativeWriteExec"), + numOutputCols, + serializeNativeOp(modifiedNativeOp), + nativeMetrics, + numPartitions, + partitionId, + None, + Seq.empty) + + while (execIterator.hasNext) { + execIterator.next().close() + } - private def handleTaskEnd(): Unit = { - if (!completed) { - completed = true - - // Handle commit or abort based on whether an exception was thrown - taskContext.foreach { case (committer, ctx) => - try { - if (thrownException.isEmpty) { - // Commit the task and add message to accumulator - val message = committer.commitTask(ctx) - capturedAccumulator.add(message) - logInfo(s"Task ${ctx.getTaskAttemptID} committed successfully") - } else { - // Abort the task - committer.abortTask(ctx) - val exMsg = thrownException.get.getMessage - logWarning(s"Task ${ctx.getTaskAttemptID} aborted due to exception: $exMsg") + taskCommitter match { + case Some((committer, taskAttemptContext)) => + val message = committer.commitTask(taskAttemptContext) + logInfo(s"Task ${taskAttemptContext.getTaskAttemptID} committed successfully") + message + case None => + FileCommitProtocol.EmptyTaskCommitMessage + } + } catch { + case t: Throwable => + taskCommitter.foreach { case (committer, taskAttemptContext) => + try { + committer.abortTask(taskAttemptContext) + logWarning( + s"Task ${taskAttemptContext.getTaskAttemptID} aborted due to exception: " + + Option(t.getMessage).getOrElse(t.getClass.getName)) + } catch { + case abortError: Throwable => + logWarning( + s"Error aborting task ${taskAttemptContext.getTaskAttemptID}", + abortError) + t.addSuppressed(abortError) } - } catch { - case e: Exception => - // Log the commit/abort exception but don't mask the original exception - logError(s"Error during task commit/abort: ${e.getMessage}", e) - if (thrownException.isEmpty) { - // If no original exception, propagate the commit/abort exception - throw e - } } + throw t + } finally { + if (execIterator != null) { + execIterator.close() } } } @@ -294,22 +295,56 @@ case class CometNativeWriteExec( } } - /** Create a JobContext for the write job */ - private def createJobContext(): Job = { - val job = Job.getInstance() - job.setJobID(new org.apache.hadoop.mapreduce.JobID(jobTrackerID, 0)) - job + private def serializeNativeOp(op: Operator): Array[Byte] = { + val size = op.getSerializedSize + val bytes = new Array[Byte](size) + val codedOutput = CodedOutputStream.newInstance(bytes) + op.writeTo(codedOutput) + codedOutput.checkNoSpaceLeft() + bytes + } + + private def directTaskOutputPath(partitionId: Int): String = { + val separator = if (outputPath.endsWith("/")) "" else "/" + f"${outputPath}${separator}part-$partitionId%05d.parquet" + } + + private def abortJob( + protocol: CommitProtocolConfig, + jobContext: Job, + cause: Throwable): Unit = { + logError("Native write failed, aborting job", cause) + try { + protocol.committer.abortJob(jobContext) + } catch { + case abortError: Throwable => + logWarning("Error aborting native write job", abortError) + cause.addSuppressed(abortError) + } + } + + /** Create a JobContext for the write job using the prepared Hadoop write configuration. */ + private def createJobContext(protocol: CommitProtocolConfig): Job = { + Job.getInstance(new Configuration(protocol.serializableHadoopConf.value)) } - /** Create a TaskAttemptContext for a specific task */ + /** Create a TaskAttemptContext matching Spark's FileFormatWriter task ID setup. */ private def createTaskContext( - jobId: String, - partitionId: Int, - attemptNumber: Int): TaskAttemptContext = { - val job = Job.getInstance() - val taskAttemptID = new TaskAttemptID( - new TaskID(new org.apache.hadoop.mapreduce.JobID(jobId, 0), TaskType.REDUCE, partitionId), - attemptNumber) - new TaskAttemptContextImpl(job.getConfiguration, taskAttemptID) + protocol: CommitProtocolConfig, + sparkStageId: Int, + sparkPartitionId: Int, + sparkAttemptNumber: Int): TaskAttemptContext = { + val hadoopConf = new Configuration(protocol.serializableHadoopConf.value) + val jobId = SparkHadoopWriterUtils.createJobID(protocol.jobTrackerID, sparkStageId) + val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) + val taskAttemptId = new TaskAttemptID(taskId, sparkAttemptNumber) + + hadoopConf.set("mapreduce.job.id", jobId.toString) + hadoopConf.set("mapreduce.task.id", taskAttemptId.getTaskID.toString) + hadoopConf.set("mapreduce.task.attempt.id", taskAttemptId.toString) + hadoopConf.setBoolean("mapreduce.task.ismap", true) + hadoopConf.setInt("mapreduce.task.partition", 0) + + new TaskAttemptContextImpl(hadoopConf, taskAttemptId) } } diff --git a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala index f6795b91a3..1becf1e685 100644 --- a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala +++ b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala @@ -54,6 +54,7 @@ class CometParquetWriterSuite extends CometTestBase { writeWithCometNativeWriteExec(inputPath, outputPath) verifyWrittenFile(outputPath) + verifyCommitProtocolOutput(outputPath) } } } @@ -430,6 +431,28 @@ class CometParquetWriterSuite extends CometTestBase { Some(plan) } + private def verifyCommitProtocolOutput(outputPath: String): Unit = { + val outputDir = new File(outputPath) + val outputFiles = Option(outputDir.listFiles()).getOrElse(Array.empty) + val fileNames = outputFiles.map(_.getName).toSeq + + assert( + fileNames.contains("_SUCCESS"), + s"Expected Spark commit protocol to create _SUCCESS marker, found: ${fileNames.mkString(", ")}") + assert( + !fileNames.contains("_temporary"), + s"Expected temporary commit directory to be cleaned up, found: ${fileNames.mkString(", ")}") + assert( + !fileNames.exists(_.startsWith(".spark-staging-")), + s"Expected staging commit directory to be cleaned up, found: ${fileNames.mkString(", ")}") + + val partFileNames = fileNames.filter(_.startsWith("part-")) + assert(partFileNames.nonEmpty, s"Expected part files, found: ${fileNames.mkString(", ")}") + assert( + partFileNames.forall(name => name.contains("-c000") && name.endsWith(".parquet")), + s"Expected Spark commit protocol part-file names, found: ${partFileNames.mkString(", ")}") + } + private def verifyWrittenFile(outputPath: String): Unit = { // Verify the data was written correctly val resultDf = spark.read.parquet(outputPath)