diff --git a/dev/ensure-jars-have-correct-contents.sh b/dev/ensure-jars-have-correct-contents.sh index 084936475d..e4cd7224b0 100755 --- a/dev/ensure-jars-have-correct-contents.sh +++ b/dev/ensure-jars-have-correct-contents.sh @@ -93,6 +93,7 @@ allowed_expr+="|^org/apache/spark/sql/$" allowed_expr+="|^org/apache/spark/sql/ExtendedExplainGenerator.*$" allowed_expr+="|^org/apache/spark/CometPlugin.class$" allowed_expr+="|^org/apache/spark/CometDriverPlugin.*$" +allowed_expr+="|^org/apache/spark/CometExecutorPlugin.*$" allowed_expr+="|^org/apache/spark/CometSource.*$" allowed_expr+="|^org/apache/spark/CometTaskMemoryManager.class$" allowed_expr+="|^org/apache/spark/CometTaskMemoryManager.*$" diff --git a/docs/source/contributor-guide/development.md b/docs/source/contributor-guide/development.md index 3785358cf3..7e5c20c7ec 100644 --- a/docs/source/contributor-guide/development.md +++ b/docs/source/contributor-guide/development.md @@ -57,9 +57,11 @@ method on any worker thread, and may move it between threads across polls. Any s in the operator struct or be shared via `Arc`. **JNI calls work from any thread, but have overhead.** `JVMClasses::get_env()` calls -`AttachCurrentThread`, which acquires JVM internal locks. The `AttachGuard` detaches the thread -when dropped. Repeated attach/detach cycles on tokio workers add overhead, so avoid calling -into the JVM on hot paths during stream execution. +`AttachCurrentThread`, which acquires JVM internal locks. The attachment is cached in +thread-local storage and is only released when the worker thread itself exits, not when the +`AttachGuard` is dropped. This is why the tokio runtime has to be shut down (releasing its +worker threads) for the JVM to be able to exit. Acquiring the JVM locks on each `get_env()` +call adds overhead, so avoid calling into the JVM on hot paths during stream execution. 
**Do not call `TaskContext.get()` from JVM callbacks during execution.** Spark's `TaskContext` is a `ThreadLocal` on the executor task thread. JVM methods invoked from tokio worker threads will @@ -84,7 +86,9 @@ to unwrap decryption keys during Parquet reads. It uses a stored `GlobalRef` and ### The tokio runtime -The runtime is created once per executor JVM in a `Lazy` static: +The runtime is stored in a `Mutex<Option<Runtime>>` static and created lazily on first use. It +is torn down on plugin shutdown (via `release_runtime`) so that the tokio worker threads exit +and the JVM can shut down cleanly. The runtime is configured as follows: - **Worker threads:** `num_cpus` by default, configurable via `COMET_WORKER_THREADS` - **Max blocking threads:** 512 by default, configurable via `COMET_MAX_BLOCKING_THREADS` diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 6dc00e9cf6..e97b504732 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -91,7 +91,7 @@ use std::collections::HashMap; use std::path::PathBuf; use std::time::{Duration, Instant}; use std::{sync::Arc, task::Poll}; -use tokio::runtime::Runtime; +use tokio::runtime::{Handle, Runtime}; use tokio::sync::mpsc; use crate::execution::memory_pools::{ @@ -118,7 +118,7 @@ use std::sync::OnceLock; #[cfg(feature = "jemalloc")] use tikv_jemalloc_ctl::{epoch, stats}; -static TOKIO_RUNTIME: OnceLock = OnceLock::new(); +static TOKIO_RUNTIME: Mutex> = Mutex::new(None); #[cfg(feature = "jemalloc")] fn log_jemalloc_usage() { @@ -212,12 +212,39 @@ fn build_runtime(default_worker_threads: Option) -> Runtime { /// Initialize the global Tokio runtime with the given default worker thread count. /// If the runtime is already initialized, this is a no-op. 
pub fn init_runtime(default_worker_threads: usize) { - TOKIO_RUNTIME.get_or_init(|| build_runtime(Some(default_worker_threads))); + let mut guard = TOKIO_RUNTIME.lock(); + if guard.is_none() { + *guard = Some(build_runtime(Some(default_worker_threads))); + } +} + +/// Returns a handle to the global Tokio runtime, lazily initializing it if needed. +/// +/// A [`Handle`] is returned (rather than a `&'static Runtime`) so that the runtime +/// can be torn down via [`release_runtime`]. The handle is cheap to clone and can be +/// used with `spawn` / `block_on` just like a `Runtime`. +pub fn get_runtime() -> Handle { + let mut guard = TOKIO_RUNTIME.lock(); + guard + .get_or_insert_with(|| build_runtime(None)) + .handle() + .clone() } -/// Function to get a handle to the global Tokio runtime -pub fn get_runtime() -> &'static Runtime { - TOKIO_RUNTIME.get_or_init(|| build_runtime(None)) +/// Tears down the global Tokio runtime, if it has been initialized. +/// +/// The runtime is moved out of the global slot and shut down with a 3-second timeout, so the +/// calling (JNI) thread waits at most that long for worker threads to finish. Any handles +/// previously returned by [`get_runtime`] will start failing their spawns once the runtime +/// is gone, so this must only be called when no native execution is in flight. +/// +/// Must not be called from within the runtime's own worker threads, otherwise the shutdown +/// would deadlock/panic. +pub fn release_runtime() { + let runtime = TOKIO_RUNTIME.lock().take(); + if let Some(runtime) = runtime { + runtime.shutdown_timeout(Duration::from_secs(3)); + } } /// Returns a short name for an OpStruct variant. 
diff --git a/native/core/src/execution/operators/iceberg_scan.rs b/native/core/src/execution/operators/iceberg_scan.rs index 713b4089b0..090f5813ac 100644 --- a/native/core/src/execution/operators/iceberg_scan.rs +++ b/native/core/src/execution/operators/iceberg_scan.rs @@ -176,16 +176,24 @@ impl IcebergScanExec { let task_stream = futures::stream::iter(tasks.into_iter().map(Ok)).boxed(); - // iceberg-rust's ArrowReader spawns IO/CPU work onto an iceberg::Runtime. execute() runs - // on the JVM-called thread outside any tokio context, so Runtime::current() would panic; - // build it from Comet's global runtime, which is where the stream is later polled. - let reader = - iceberg::arrow::ArrowReaderBuilder::new(file_io, IcebergRuntime::new(get_runtime())) - .with_batch_size(batch_size) - .with_data_file_concurrency_limit(self.data_file_concurrency_limit) - .with_row_selection_enabled(true) - .with_metadata_size_hint(512 * 1024) // Same as DataFusion's default - .build(); + // iceberg-rust's ArrowReader spawns IO/CPU work onto an iceberg::Runtime, which only needs + // a tokio handle. execute() runs on the JVM-called thread outside any tokio context, so we + // enter Comet's global runtime to capture its handle (this is where the stream is later + // polled). Capturing the handle rather than borrowing the runtime keeps it tear-downable + // via release_runtime. + let iceberg_runtime = { + let handle = get_runtime(); + let _guard = handle.enter(); + IcebergRuntime::try_current().map_err(|e| { + DataFusionError::Execution(format!("Failed to build Iceberg runtime: {e}")) + })? 
+ }; + let reader = iceberg::arrow::ArrowReaderBuilder::new(file_io, iceberg_runtime) + .with_batch_size(batch_size) + .with_data_file_concurrency_limit(self.data_file_concurrency_limit) + .with_row_selection_enabled(true) + .with_metadata_size_hint(512 * 1024) // Same as DataFusion's default + .build(); // Pass all tasks to iceberg-rust at once to utilize its flatten_unordered // parallelization, avoiding overhead of single-task streams diff --git a/native/core/src/lib.rs b/native/core/src/lib.rs index 7d15c761ca..48e17bb502 100644 --- a/native/core/src/lib.rs +++ b/native/core/src/lib.rs @@ -125,6 +125,12 @@ pub extern "system" fn Java_org_apache_comet_NativeBase_init( }) } +#[no_mangle] +/// Releases the global Tokio runtime used by Comet native execution. +pub extern "system" fn Java_org_apache_comet_NativeBase_release(_e: EnvUnowned, _class: JClass) { + execution::jni_api::release_runtime(); +} + const LOG_PATTERN: &str = "{d(%y/%m/%d %H:%M:%S)} {l} {f}: {m}{n}"; /// JNI method to check if a specific feature is enabled in the native Rust code. 
diff --git a/spark/src/main/java/org/apache/comet/NativeBase.java b/spark/src/main/java/org/apache/comet/NativeBase.java index e2fcbb24a7..789bef8b9e 100644 --- a/spark/src/main/java/org/apache/comet/NativeBase.java +++ b/spark/src/main/java/org/apache/comet/NativeBase.java @@ -27,6 +27,7 @@ import java.io.InputStreamReader; import java.nio.file.Files; import java.nio.file.StandardCopyOption; +import java.util.concurrent.atomic.AtomicBoolean; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -49,6 +50,7 @@ public abstract class NativeBase { private static boolean loaded = false; private static volatile Throwable loadErr = null; private static final String searchPattern = "libcomet-"; + private static final AtomicBoolean released = new AtomicBoolean(false); static { try { @@ -293,6 +295,20 @@ private static String resourceName() { */ static native void init(String logConfPath, String logLevel); + /** Release native resources through JNI */ + static native void release(); + + /** Release native resources */ + public static void releaseNative() throws Throwable { + if (!isLoaded()) { + return; + } + + if (released.compareAndSet(false, true)) { + release(); + } + } + /** * Check if a specific feature is enabled in the native library. * diff --git a/spark/src/main/scala/org/apache/spark/Plugins.scala b/spark/src/main/scala/org/apache/spark/Plugins.scala index 7290ab436a..43945b97fd 100644 --- a/spark/src/main/scala/org/apache/spark/Plugins.scala +++ b/spark/src/main/scala/org/apache/spark/Plugins.scala @@ -28,8 +28,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD, EXECUTOR_MEMORY_OVERHEAD_FACTOR} import org.apache.spark.sql.internal.StaticSQLConf +import org.apache.comet.{CometSparkSessionExtensions, NativeBase} import org.apache.comet.CometConf.{COMET_METRICS_ENABLED, COMET_ONHEAP_ENABLED} -import org.apache.comet.CometSparkSessionExtensions /** * Comet driver plugin. 
This class is loaded by Spark's plugin framework. It will be instantiated @@ -95,6 +95,8 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl override def shutdown(): Unit = { logInfo("CometDriverPlugin shutdown") + NativeBase.releaseNative() + super.shutdown() } @@ -148,6 +150,24 @@ object CometDriverPlugin extends Logging { } } +class CometExecutorPlugin extends ExecutorPlugin with Logging { + + override def init(ctx: PluginContext, extraConf: ju.Map[String, String]): Unit = { + logInfo("CometExecutorPlugin init") + + super.init(ctx, extraConf) + } + + override def shutdown(): Unit = { + logInfo("CometExecutorPlugin shutdown") + + NativeBase.releaseNative() + + super.shutdown() + } + +} + /** * The Comet plugin for Spark. To enable this plugin, set the config "spark.plugins" to * `org.apache.spark.CometPlugin` @@ -155,5 +175,5 @@ object CometDriverPlugin extends Logging { class CometPlugin extends SparkPlugin with Logging { override def driverPlugin(): DriverPlugin = new CometDriverPlugin - override def executorPlugin(): ExecutorPlugin = null + override def executorPlugin(): ExecutorPlugin = new CometExecutorPlugin }