diff --git a/native/core/src/execution/operators/parquet_writer.rs b/native/core/src/execution/operators/parquet_writer.rs index f1168c4a57..2777520abe 100644 --- a/native/core/src/execution/operators/parquet_writer.rs +++ b/native/core/src/execution/operators/parquet_writer.rs @@ -193,10 +193,10 @@ impl ParquetWriter { pub struct ParquetWriterExec { /// Input execution plan input: Arc, - /// Output file path (final destination) + /// Output directory path (final destination), used for display/debugging output_path: String, - /// Working directory for temporary files (used by FileCommitProtocol) - work_dir: String, + /// Full file path to write for this Spark task attempt + task_output_path: String, /// Job ID for tracking this write operation job_id: Option, /// Task attempt ID for this specific task @@ -221,7 +221,7 @@ impl ParquetWriterExec { pub fn try_new( input: Arc, output_path: String, - work_dir: String, + task_output_path: String, job_id: Option, task_attempt_id: Option, compression: CompressionCodec, @@ -242,7 +242,7 @@ impl ParquetWriterExec { Ok(ParquetWriterExec { input, output_path, - work_dir, + task_output_path, job_id, task_attempt_id, compression, @@ -431,7 +431,7 @@ impl ExecutionPlan for ParquetWriterExec { 1 => Ok(Arc::new(ParquetWriterExec::try_new( Arc::clone(&children[0]), self.output_path.clone(), - self.work_dir.clone(), + self.task_output_path.clone(), self.job_id.clone(), self.task_attempt_id, self.compression.clone(), @@ -460,8 +460,7 @@ impl ExecutionPlan for ParquetWriterExec { let runtime_env = context.runtime_env(); let input = self.input.execute(partition, context)?; let input_schema = self.input.schema(); - let work_dir = self.work_dir.clone(); - let task_attempt_id = self.task_attempt_id; + let part_file = self.task_output_path.clone(); let compression = self.compression_to_parquet()?; let column_names = self.column_names.clone(); @@ -476,16 +475,8 @@ impl ExecutionPlan for ParquetWriterExec { .collect(); let output_schema = Arc::new(arrow::datatypes::Schema::new(fields)); - // Generate part file name for this partition - // If using FileCommitProtocol (work_dir is set), include task_attempt_id in the filename - let part_file = if let Some(attempt_id) = task_attempt_id { - format!( - "{}/part-{:05}-{:05}.parquet", - work_dir, self.partition_id, attempt_id - ) - } else { - format!("{}/part-{:05}.parquet", work_dir, self.partition_id) - }; + // The path is generated by Spark's FileCommitProtocol. Writing exactly to this + // path allows Spark to commit or abort the task using its standard protocol. // Configure writer properties let props = WriterProperties::builder() @@ -809,14 +800,16 @@ mod tests { let memory_exec = Arc::new(DataSourceExec::new(Arc::new(memory_source_config))); // Create ParquetWriterExec with DataSourceExec as input - let output_path = "unused".to_string(); - let work_dir = "hdfs://namenode:9000/user/test_parquet_writer_exec".to_string(); + let output_path = "hdfs://namenode:9000/user/test_parquet_writer_exec".to_string(); + let task_output_path = + "hdfs://namenode:9000/user/test_parquet_writer_exec/part-00000-c000.parquet" + .to_string(); let column_names = vec!["id".to_string(), "name".to_string()]; let parquet_writer = ParquetWriterExec::try_new( memory_exec, output_path, - work_dir, + task_output_path, None, // job_id Some(123), // task_attempt_id CompressionCodec::None, diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 553fe5215c..2168090e00 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1662,14 +1662,30 @@ impl PhysicalPlanner { .map(|(k, v)| (k.clone(), v.clone())) .collect(); + let task_output_path = writer + .task_output_path + .clone() + .or_else(|| { + // Compatibility with older serialized plans that only provided a work + // directory. New JVM code always sends task_output_path from Spark's + // FileCommitProtocol so native code does not invent filenames. + writer + .work_dir + .as_ref() + .map(|work_dir| match writer.task_attempt_id { + Some(attempt_id) => format!( + "{}/part-{:05}-{:05}.parquet", + work_dir, self.partition, attempt_id + ), + None => format!("{}/part-{:05}.parquet", work_dir, self.partition), + }) + }) + .expect("task_output_path is provided"); + let parquet_writer = Arc::new(ParquetWriterExec::try_new( Arc::clone(&child.native_plan), writer.output_path.clone(), - writer - .work_dir - .as_ref() - .expect("work_dir is provided") - .clone(), + task_output_path, writer.job_id.clone(), writer.task_attempt_id, codec, diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 9d81d2853b..d957238b45 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -347,8 +347,7 @@ message ParquetWriter { string output_path = 1; CompressionCodec compression = 2; repeated string column_names = 4; - // Working directory for temporary files (used by FileCommitProtocol) - // If not set, files are written directly to output_path + // Working directory for temporary files (legacy path used before task_output_path existed). optional string work_dir = 5; // Job ID for tracking this write operation optional string job_id = 6; @@ -361,6 +360,9 @@ message ParquetWriter { // configuration value "spark.hadoop.fs.s3a.access.key" will be stored as "fs.s3a.access.key" in // the map. map object_store_options = 8; + // Full temporary output file path returned by Spark's FileCommitProtocol for this task. + // The native writer must write exactly to this path so Spark can commit or abort it. + optional string task_output_path = 9; } enum AggregateMode { diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala index 4ae73565c6..0f1e4c6057 100644 --- a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala +++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala @@ -20,16 +20,22 @@ package org.apache.comet.serde.operator import java.net.URI -import java.util.Locale +import java.util.{Locale, UUID} import scala.jdk.CollectionConverters._ -import org.apache.spark.SparkException +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat +import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.comet.{CometNativeExec, CometNativeWriteExec} import org.apache.spark.sql.execution.command.DataWritingCommandExec import org.apache.spark.sql.execution.datasources.{InsertIntoHadoopFsRelationCommand, WriteFilesExec} import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.SerializableConfiguration import org.apache.comet.{CometConf, ConfigEntry} import org.apache.comet.CometSparkSessionExtensions.withFallbackReason @@ -132,8 +138,8 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec .setOutputPath(outputPath) .setCompression(codec) .addAllColumnNames(cmd.query.output.map(_.name).asJava) - // Note: work_dir, job_id, and task_attempt_id will be set at execution time - // in CometNativeWriteExec, as they depend on the Spark task context + // The task_output_path is filled in by CometNativeWriteExec at execution time from + // Spark's FileCommitProtocol, because it depends on the Spark task context. // Collect S3/cloud storage configurations val session = op.session @@ -178,29 +184,35 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec other } - // Create FileCommitProtocol for atomic writes - val jobId = java.util.UUID.randomUUID().toString - val committer = - try { - // Use Spark's SQLHadoopMapReduceCommitProtocol - val committerClass = - classOf[org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol] - val constructor = - committerClass.getConstructor(classOf[String], classOf[String], classOf[Boolean]) - Some( - constructor - .newInstance( - jobId, - outputPath, - java.lang.Boolean.FALSE // dynamicPartitionOverwrite = false for now - ) - .asInstanceOf[org.apache.spark.internal.io.FileCommitProtocol]) - } catch { - case e: Exception => - throw new SparkException(s"Could not instantiate FileCommitProtocol: ${e.getMessage}") - } - - CometNativeWriteExec(nativeOp, childPlan, outputPath, committer, jobId) + val session = op.session + val hadoopConf = session.sessionState.newHadoopConfWithOptions(cmd.options) + val job = Job.getInstance(hadoopConf) + job.setOutputKeyClass(classOf[Void]) + job.setOutputValueClass(classOf[InternalRow]) + FileOutputFormat.setOutputPath(job, new Path(outputPath)) + + val outputWriterFactory = + cmd.fileFormat.prepareWrite(session, job, CaseInsensitiveMap(cmd.options), cmd.query.schema) + + val commitProtocolJobId = UUID.randomUUID().toString + val committer = FileCommitProtocol.instantiate( + session.sessionState.conf.fileCommitProtocolClass, + commitProtocolJobId, + outputPath, + false) + + // Match Spark's FileFormatWriter behavior: propagate a per-write UUID in the Hadoop + // configuration before it is serialized to executors. + job.getConfiguration.set("spark.sql.sources.writeJobUUID", UUID.randomUUID().toString) + + val commitProtocol = CometNativeWriteExec.CommitProtocolConfig( + committer = committer, + serializableHadoopConf = new SerializableConfiguration(job.getConfiguration), + outputWriterFactory = outputWriterFactory, + commitProtocolJobId = commitProtocolJobId, + jobTrackerID = CometNativeWriteExec.newJobTrackerID()) + + CometNativeWriteExec(nativeOp, childPlan, outputPath, Some(commitProtocol)) } private def parseCompressionCodec(cmd: InsertIntoHadoopFsRelationCommand) = { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala index 4cfdd11361..aa86a824f5 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala @@ -19,72 +19,77 @@ package org.apache.spark.sql.comet -import scala.jdk.CollectionConverters._ +import java.util.Date -import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext, TaskAttemptID, TaskID, TaskType} import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.spark.TaskContext -import org.apache.spark.internal.io.{FileCommitProtocol, FileNameSpec} +import org.apache.spark.internal.io.{FileCommitProtocol, FileNameSpec, SparkHadoopWriterUtils} +import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.comet.CometNativeWriteExec.CommitProtocolConfig import org.apache.spark.sql.comet.execution.arrow.CometArrowStream import org.apache.spark.sql.comet.util.{Utils => CometUtils} import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.datasources.OutputWriterFactory import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.Utils +import org.apache.spark.util.SerializableConfiguration import com.google.protobuf.CodedOutputStream import org.apache.comet.CometExecIterator import org.apache.comet.serde.OperatorOuterClass.Operator +object CometNativeWriteExec { + + def newJobTrackerID(): String = SparkHadoopWriterUtils.createJobTrackerID(new Date()) + + /** + * Driver-created objects required to use Spark's FileCommitProtocol from native write tasks. + * The committer instance is serializable by contract and is sent to executors, while the same + * driver-side instance receives task commit messages and commits or aborts the job. + */ + case class CommitProtocolConfig( + committer: FileCommitProtocol, + serializableHadoopConf: SerializableConfiguration, + outputWriterFactory: OutputWriterFactory, + commitProtocolJobId: String, + jobTrackerID: String) + extends Serializable +} + /** - * Comet physical operator for native Parquet write operations with FileCommitProtocol support. - * - * This operator writes data to Parquet files using the native Comet engine. It integrates with - * Spark's FileCommitProtocol to provide atomic writes with proper staging and commit semantics. + * Comet physical operator for native Parquet write operations. * - * The implementation includes support for Spark's file commit protocol through work_dir, job_id, - * and task_attempt_id parameters that can be set in the operator. When work_dir is set, files are - * written to a temporary location that can be atomically committed later. + * When [[commitProtocol]] is present, this operator follows Spark's FileCommitProtocol lifecycle: + * driver setupJob, executor setupTask/newTaskTempFile/commitTask or abortTask, and driver + * commitJob or abortJob. Native code writes exactly to the task temp file returned by Spark's + * commit protocol; it does not generate its own final part-file names. * * @param nativeOp - * The native operator representing the write operation (template, will be modified per task) + * The native operator representing the write operation (template, modified per task) * @param child * The child operator providing the data to write * @param outputPath - * The path where the Parquet file will be written - * @param committer - * FileCommitProtocol for atomic writes. If None, files are written directly. - * @param jobTrackerID - * Unique identifier for this write job + * The final output directory for the write + * @param commitProtocol + * Spark file commit protocol state. If absent, files are written directly under outputPath. */ case class CometNativeWriteExec( nativeOp: Operator, child: SparkPlan, outputPath: String, - committer: Option[FileCommitProtocol] = None, - jobTrackerID: String = Utils.createTempDir().getName) + commitProtocol: Option[CommitProtocolConfig] = None) extends CometNativeExec with UnaryExecNode { override def originalPlan: SparkPlan = child - // Accumulator to collect TaskCommitMessages from all tasks - // Must be eagerly initialized on driver, not lazy - @transient private val taskCommitMessagesAccum = - sparkContext.collectionAccumulator[FileCommitProtocol.TaskCommitMessage]("taskCommitMessages") - - override def serializedPlanOpt: SerializedPlan = { - val size = nativeOp.getSerializedSize - val bytes = new Array[Byte](size) - val codedOutput = CodedOutputStream.newInstance(bytes) - nativeOp.writeTo(codedOutput) - codedOutput.checkNoSpaceLeft() - SerializedPlan(Some(bytes)) - } + override def serializedPlanOpt: SerializedPlan = SerializedPlan( + Some(serializeNativeOp(nativeOp))) override def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) @@ -97,53 +102,72 @@ case class CometNativeWriteExec( "rows_written" -> SQLMetrics.createMetric(sparkContext, "number of written rows")) override def doExecute(): RDD[InternalRow] = { - // Setup job if committer is present - committer.foreach { c => - val jobContext = createJobContext() - c.setupJob(jobContext) - } + executeWriteAndCommit() + // Write operations do not return rows. + sparkContext.emptyRDD[InternalRow] + } - // Execute the native write with commit protocol - val resultRDD = doExecuteColumnar() + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + executeWriteAndCommit() + // Write operations do not return columnar batches. Spark may still ask for columnar output + // while this operator is nested under write planning nodes, so run the terminal write here too. + sparkContext.emptyRDD[ColumnarBatch] + } - // Force execution by consuming all batches - resultRDD - .mapPartitions { iter => - iter.foreach(_.close()) - Iterator.empty - } - .count() - - // Extract write statistics from metrics - val filesWritten = metrics("files_written").value - val bytesWritten = metrics("bytes_written").value - val rowsWritten = metrics("rows_written").value - - // Collect TaskCommitMessages from accumulator - val commitMessages = taskCommitMessagesAccum.value.asScala.toSeq - - // Commit job with collected TaskCommitMessages - committer.foreach { c => - val jobContext = createJobContext() - try { - c.commitJob(jobContext, commitMessages) - logInfo( - s"Successfully committed write job to $outputPath: " + - s"$filesWritten files, $bytesWritten bytes, $rowsWritten rows") - } catch { - case e: Exception => - logError("Failed to commit job, aborting", e) - c.abortJob(jobContext) - throw e - } + private def executeWriteAndCommit(): Unit = { + commitProtocol match { + case Some(protocol) => + val jobContext = createJobContext(protocol) + + // Match Spark's FileFormatWriter lifecycle: setupJob is outside the try block because it + // only initializes the job; failures after this point should abort the job. + protocol.committer.setupJob(jobContext) + + try { + val commitMessages = runNativeWriteJob(protocol) + protocol.committer.commitJob(jobContext, commitMessages.toSeq) + + val filesWritten = metrics("files_written").value + val bytesWritten = metrics("bytes_written").value + val rowsWritten = metrics("rows_written").value + logInfo( + s"Successfully committed native write job to $outputPath: " + + s"$filesWritten files, $bytesWritten bytes, $rowsWritten rows") + } catch { + case t: Throwable => + abortJob(protocol, jobContext, t) + throw t + } + + case None => + // Direct-write fallback for tests or callers that do not provide a commit protocol. + nativeWriteTasks(None).count() } + } - // Return empty RDD as write operations don't return data - sparkContext.emptyRDD[InternalRow] + private def runNativeWriteJob(protocol: CommitProtocolConfig): Array[TaskCommitMessage] = { + val writeRDD = nativeWriteTasks(Some(protocol)) + val ret = new Array[TaskCommitMessage](writeRDD.partitions.length) + + sparkContext.runJob( + writeRDD, + (_: TaskContext, iter: Iterator[TaskCommitMessage]) => { + assert(iter.hasNext, "Native write task did not return a commit message") + val commitMessage = iter.next() + assert(!iter.hasNext, "Native write task returned more than one commit message") + commitMessage + }, + writeRDD.partitions.indices, + (index, commitMessage: TaskCommitMessage) => { + protocol.committer.onTaskCommit(commitMessage) + ret(index) = commitMessage + }) + + ret } - override def doExecuteColumnar(): RDD[ColumnarBatch] = { - // Get the input data from the child operator + private def nativeWriteTasks( + capturedCommitProtocol: Option[CommitProtocolConfig]): RDD[TaskCommitMessage] = { val childRDD = if (child.supportsColumnar) { child.executeColumnar() } else { @@ -156,137 +180,114 @@ case class CometNativeWriteExec( } } - // Capture metadata before the transformation val numPartitions = childRDD.getNumPartitions val numOutputCols = child.output.length - val capturedCommitter = committer - val capturedJobTrackerID = jobTrackerID val capturedNativeOp = nativeOp - val capturedAccumulator = taskCommitMessagesAccum // Capture accumulator for use in tasks + val writeExec = this - // Execute native write operation with task-level commit protocol childRDD.mapPartitionsInternal { iter => - val partitionId = org.apache.spark.TaskContext.getPartitionId() - val taskAttemptId = org.apache.spark.TaskContext.get().taskAttemptId() + val sparkTaskContext = TaskContext.get() + val partitionId = sparkTaskContext.partitionId() + val sparkStageId = sparkTaskContext.stageId() + val taskAttemptId = sparkTaskContext.taskAttemptId() + val sparkAttemptNumber = taskAttemptId.toInt & Integer.MAX_VALUE - // Setup task-level commit protocol if provided - val (workDir, taskContext, commitMsg) = capturedCommitter - .map { committer => - val taskContext = - createTaskContext(capturedJobTrackerID, partitionId, taskAttemptId.toInt) + new Iterator[TaskCommitMessage] { + private var emitted = false - // Setup task - this creates the temporary working directory - committer.setupTask(taskContext) + override def hasNext: Boolean = !emitted - // Get the work directory for temp files - // Spark 4.1 made the (taskContext, dir, ext: String) overload throw by default; - // the FileNameSpec overload is the supported one and exists in 3.4+. - val workPath = committer.newTaskTempFile(taskContext, None, FileNameSpec("", "")) - val workDir = new Path(workPath).getParent.toString + override def next(): TaskCommitMessage = { + if (emitted) { + throw new NoSuchElementException("Native write task already completed") + } + emitted = true - (Some(workDir), Some((committer, taskContext)), null) - } - .getOrElse((None, None, null)) - - // Modify the native operator to include task-specific parameters - val modifiedNativeOp = if (workDir.isDefined) { - val parquetWriter = capturedNativeOp.getParquetWriter.toBuilder - .setWorkDir(workDir.get) - .setJobId(capturedJobTrackerID) - .setTaskAttemptId(taskAttemptId.toInt) - .build() - - capturedNativeOp.toBuilder.setParquetWriter(parquetWriter).build() - } else { - capturedNativeOp - } + var taskCommitter: Option[(FileCommitProtocol, TaskAttemptContext)] = None + var execIterator: CometExecIterator = null - val nativeMetrics = CometMetricNode.fromCometPlan(this) - // Register before CometExecIterator so completion listeners run after iterator close - // (Spark runs task completion callbacks in reverse registration order). - Option(TaskContext.get()).foreach(nativeMetrics.reportNativeWriteOutputMetrics) - - val size = modifiedNativeOp.getSerializedSize - val planBytes = new Array[Byte](size) - val codedOutput = CodedOutputStream.newInstance(planBytes) - modifiedNativeOp.writeTo(codedOutput) - codedOutput.checkNoSpaceLeft() - - val execIterator = new CometExecIterator( - CometExec.newIterId, - CometArrowStream.inputObjects( - iter, - CometUtils.fromAttributes(child.output), - "CometNativeWriteExec"), - numOutputCols, - planBytes, - nativeMetrics, - numPartitions, - partitionId, - None, - Seq.empty) - - // Wrap the iterator to handle task commit/abort and capture TaskCommitMessage - new Iterator[ColumnarBatch] { - private var completed = false - private var thrownException: Option[Throwable] = None - - override def hasNext: Boolean = { - val result = - try { - execIterator.hasNext - } catch { - case e: Throwable => - thrownException = Some(e) - handleTaskEnd() - throw e + try { + val taskOutputPath = capturedCommitProtocol match { + case Some(protocol) => + val taskAttemptContext = + createTaskContext(protocol, sparkStageId, partitionId, sparkAttemptNumber) + protocol.committer.setupTask(taskAttemptContext) + taskCommitter = Some(protocol.committer -> taskAttemptContext) + + val fileExtension = + protocol.outputWriterFactory.getFileExtension(taskAttemptContext) + protocol.committer.newTaskTempFile( + taskAttemptContext, + None, + FileNameSpec("", "-c000" + fileExtension)) + + case None => + directTaskOutputPath(partitionId) } - if (!result && !completed) { - handleTaskEnd() - } + val parquetWriter = capturedNativeOp.getParquetWriter.toBuilder + .setTaskOutputPath(taskOutputPath) + .setTaskAttemptId(sparkAttemptNumber) - result - } + capturedCommitProtocol.foreach { protocol => + parquetWriter.setJobId(protocol.commitProtocolJobId) + } - override def next(): ColumnarBatch = { - try { - execIterator.next() - } catch { - case e: Throwable => - thrownException = Some(e) - handleTaskEnd() - throw e - } - } + val modifiedNativeOp = capturedNativeOp.toBuilder + .setParquetWriter(parquetWriter.build()) + .build() + + val nativeMetrics = CometMetricNode.fromCometPlan(writeExec) + // Register before CometExecIterator so completion listeners run after iterator close + // (Spark runs task completion callbacks in reverse registration order). + Option(TaskContext.get()).foreach(nativeMetrics.reportNativeWriteOutputMetrics) + + execIterator = new CometExecIterator( + CometExec.newIterId, + CometArrowStream.inputObjects( + iter, + CometUtils.fromAttributes(child.output), + "CometNativeWriteExec"), + numOutputCols, + serializeNativeOp(modifiedNativeOp), + nativeMetrics, + numPartitions, + partitionId, + None, + Seq.empty) + + while (execIterator.hasNext) { + execIterator.next().close() + } - private def handleTaskEnd(): Unit = { - if (!completed) { - completed = true - - // Handle commit or abort based on whether an exception was thrown - taskContext.foreach { case (committer, ctx) => - try { - if (thrownException.isEmpty) { - // Commit the task and add message to accumulator - val message = committer.commitTask(ctx) - capturedAccumulator.add(message) - logInfo(s"Task ${ctx.getTaskAttemptID} committed successfully") - } else { - // Abort the task - committer.abortTask(ctx) - val exMsg = thrownException.get.getMessage - logWarning(s"Task ${ctx.getTaskAttemptID} aborted due to exception: $exMsg") + taskCommitter match { + case Some((committer, taskAttemptContext)) => + val message = committer.commitTask(taskAttemptContext) + logInfo(s"Task ${taskAttemptContext.getTaskAttemptID} committed successfully") + message + case None => + FileCommitProtocol.EmptyTaskCommitMessage + } + } catch { + case t: Throwable => + taskCommitter.foreach { case (committer, taskAttemptContext) => + try { + committer.abortTask(taskAttemptContext) + logWarning( + s"Task ${taskAttemptContext.getTaskAttemptID} aborted due to exception: " + + Option(t.getMessage).getOrElse(t.getClass.getName)) + } catch { + case abortError: Throwable => + logWarning( + s"Error aborting task ${taskAttemptContext.getTaskAttemptID}", + abortError) + t.addSuppressed(abortError) } - } catch { - case e: Exception => - // Log the commit/abort exception but don't mask the original exception - logError(s"Error during task commit/abort: ${e.getMessage}", e) - if (thrownException.isEmpty) { - // If no original exception, propagate the commit/abort exception - throw e - } } + throw t + } finally { + if (execIterator != null) { + execIterator.close() } } } @@ -294,22 +295,56 @@ case class CometNativeWriteExec( } } - /** Create a JobContext for the write job */ - private def createJobContext(): Job = { - val job = Job.getInstance() - job.setJobID(new org.apache.hadoop.mapreduce.JobID(jobTrackerID, 0)) - job + private def serializeNativeOp(op: Operator): Array[Byte] = { + val size = op.getSerializedSize + val bytes = new Array[Byte](size) + val codedOutput = CodedOutputStream.newInstance(bytes) + op.writeTo(codedOutput) + codedOutput.checkNoSpaceLeft() + bytes + } + + private def directTaskOutputPath(partitionId: Int): String = { + val separator = if (outputPath.endsWith("/")) "" else "/" + f"${outputPath}${separator}part-$partitionId%05d.parquet" + } + + private def abortJob( + protocol: CommitProtocolConfig, + jobContext: Job, + cause: Throwable): Unit = { + logError("Native write failed, aborting job", cause) + try { + protocol.committer.abortJob(jobContext) + } catch { + case abortError: Throwable => + logWarning("Error aborting native write job", abortError) + cause.addSuppressed(abortError) + } + } + + /** Create a JobContext for the write job using the prepared Hadoop write configuration. */ + private def createJobContext(protocol: CommitProtocolConfig): Job = { + Job.getInstance(new Configuration(protocol.serializableHadoopConf.value)) } - /** Create a TaskAttemptContext for a specific task */ + /** Create a TaskAttemptContext matching Spark's FileFormatWriter task ID setup. */ private def createTaskContext( - jobId: String, - partitionId: Int, - attemptNumber: Int): TaskAttemptContext = { - val job = Job.getInstance() - val taskAttemptID = new TaskAttemptID( - new TaskID(new org.apache.hadoop.mapreduce.JobID(jobId, 0), TaskType.REDUCE, partitionId), - attemptNumber) - new TaskAttemptContextImpl(job.getConfiguration, taskAttemptID) + protocol: CommitProtocolConfig, + sparkStageId: Int, + sparkPartitionId: Int, + sparkAttemptNumber: Int): TaskAttemptContext = { + val hadoopConf = new Configuration(protocol.serializableHadoopConf.value) + val jobId = SparkHadoopWriterUtils.createJobID(protocol.jobTrackerID, sparkStageId) + val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) + val taskAttemptId = new TaskAttemptID(taskId, sparkAttemptNumber) + + hadoopConf.set("mapreduce.job.id", jobId.toString) + hadoopConf.set("mapreduce.task.id", taskAttemptId.getTaskID.toString) + hadoopConf.set("mapreduce.task.attempt.id", taskAttemptId.toString) + hadoopConf.setBoolean("mapreduce.task.ismap", true) + hadoopConf.setInt("mapreduce.task.partition", 0) + + new TaskAttemptContextImpl(hadoopConf, taskAttemptId) } } diff --git a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala index f6795b91a3..1becf1e685 100644 --- a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala +++ b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala @@ -54,6 +54,7 @@ class CometParquetWriterSuite extends CometTestBase { writeWithCometNativeWriteExec(inputPath, outputPath) verifyWrittenFile(outputPath) + verifyCommitProtocolOutput(outputPath) } } } @@ -430,6 +431,28 @@ class CometParquetWriterSuite extends CometTestBase { Some(plan) } + private def verifyCommitProtocolOutput(outputPath: String): Unit = { + val outputDir = new File(outputPath) + val outputFiles = Option(outputDir.listFiles()).getOrElse(Array.empty) + val fileNames = outputFiles.map(_.getName).toSeq + + assert( + fileNames.contains("_SUCCESS"), + s"Expected Spark commit protocol to create _SUCCESS marker, found: ${fileNames.mkString(", ")}") + assert( + !fileNames.contains("_temporary"), + s"Expected temporary commit directory to be cleaned up, found: ${fileNames.mkString(", ")}") + assert( + !fileNames.exists(_.startsWith(".spark-staging-")), + s"Expected staging commit directory to be cleaned up, found: ${fileNames.mkString(", ")}") + + val partFileNames = fileNames.filter(_.startsWith("part-")) + assert(partFileNames.nonEmpty, s"Expected part files, found: ${fileNames.mkString(", ")}") + assert( + partFileNames.forall(name => name.contains("-c000") && name.endsWith(".parquet")), + s"Expected Spark commit protocol part-file names, found: ${partFileNames.mkString(", ")}") + } + private def verifyWrittenFile(outputPath: String): Unit = { // Verify the data was written correctly val resultDf = spark.read.parquet(outputPath)