From 84519c44e294ca3bb91d0ed996e23505da25bcba Mon Sep 17 00:00:00 2001 From: TP Date: Mon, 29 Jun 2026 19:18:38 +0800 Subject: [PATCH] Harden frame processing on large data Previously, frame encoding and decoding are done against an in-memory byte array. This is simple, but may cause issues with very large amount of data, since the frame protocol allows 2^32 bytes of data per frame with the potential to clog the entire JVM. This uses the MessagePack library's MessageBuffer helper to encode to and decode from a MessagePack message into multiple lazy buffers, converting each buffer to a byte array separately on demand to reduce peak memory usage. I also cleaned up some abstractions since they are already pretty empty prior to this change. --- .../org/apache/airflow/sdk/execution/Comm.kt | 81 +++++++++++++------ .../org/apache/airflow/sdk/execution/Frame.kt | 35 +++++--- .../apache/airflow/sdk/execution/CommTest.kt | 8 +- 3 files changed, 87 insertions(+), 37 deletions(-) diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Comm.kt b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Comm.kt index 5d917af3291c8..a96b77a9fbfdb 100644 --- a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Comm.kt +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Comm.kt @@ -23,23 +23,59 @@ import io.ktor.utils.io.ByteReadChannel import io.ktor.utils.io.ByteWriteChannel import io.ktor.utils.io.readByteArray import io.ktor.utils.io.writeByteArray +import kotlinx.coroutines.runBlocking import org.apache.airflow.sdk.ApiError import org.apache.airflow.sdk.Bundle import org.apache.airflow.sdk.execution.comm.ErrorResponse import org.apache.airflow.sdk.execution.comm.StartupDetails +import org.msgpack.core.buffer.MessageBuffer +import org.msgpack.core.buffer.MessageBufferInput +import java.io.IOException import kotlin.concurrent.atomics.AtomicInt import kotlin.concurrent.atomics.ExperimentalAtomicApi +/** + * A [MessageBufferInput] that feeds a MessageUnpacker in chunks. + * + * Exactly [declaredLength] bytes is fed from [reader] in each chunk. + * Heap use is bounded by [CHUNK_SIZE], so a frame larger than + * [Int.MAX_VALUE] (which the protocol permits but a single ByteArray + * cannot hold) still decodes. + * + * The MessageBufferInput contract is synchronous while the underlying + * ktor read suspends, so each [next] bridges with [runBlocking]. This + * is fine since we only use this class with `Dispatchers.IO`, which is + * capable of blocking. + */ +private class ChannelFrameInput( + private val reader: ByteReadChannel, + declaredLength: UInt, +) : MessageBufferInput { + private companion object { + const val CHUNK_SIZE = (64 * 1024).toLong() + } + + private var remaining = declaredLength.toLong() + + override fun next(): MessageBuffer? { + if (remaining == 0L) return null + val want = minOf(remaining, CHUNK_SIZE).toInt() + val bytes = runBlocking { reader.readByteArray(want) } + if (bytes.size != want) { + throw IOException("Truncated frame: expected $want more bytes, got ${bytes.size}") + } + remaining -= want + return MessageBuffer.wrap(bytes) + } + + override fun close() {} // No cleanup here. The caller owns the channel's lifecycle. +} + data class IncomingFrame( val id: Int, val body: Any?, ) -data class OutgoingFrame( - val id: Int, - val body: Any, -) - @OptIn(ExperimentalAtomicApi::class) class CoordinatorComm( private val bundle: Bundle, @@ -48,10 +84,6 @@ class CoordinatorComm( ) { internal companion object { private val logger = Logger(CoordinatorComm::class) - - fun encode(outgoing: OutgoingFrame) = Frame.encodeRequest(outgoing.id, outgoing.body) - - fun decode(bytes: ByteArray) = Frame.decode(bytes) } private val nextId = AtomicInt(0) @@ -72,17 +104,18 @@ class CoordinatorComm( return } - val payloadLength = Frame.parseLengthPrefix(prefix) - val payload = reader.readByteArray(payloadLength) - if (payload.size != payloadLength) { // Something is terribly wrong. Let's bail. - logger.error( - "Payload length not right", - mapOf("expect" to payloadLength, "receive" to payload.size), - ) - shutDownRequested = true - return - } - val frame = decode(payload) + val declaredLength = Frame.parseLengthPrefix(prefix) + val frame = + try { + Frame.decode(ChannelFrameInput(reader, declaredLength)) + } catch (e: Exception) { + logger.error( + "Failed to read or decode frame", + mapOf("length" to declaredLength, "exception" to e), + ) + shutDownRequested = true + return + } logger.debug("Handling", mapOf("id" to frame.id)) handle(frame) } @@ -91,10 +124,12 @@ class CoordinatorComm( id: Int, body: Any, ) { - val data = encode(OutgoingFrame(id, body)) + val buffers = Frame.encodeRequest(id, body) logger.debug("Sending", mapOf("id" to id, "body" to body)) - writer.writeByteArray(Frame.lengthPrefix(data.size)) - writer.writeByteArray(data) + writer.writeByteArray(Frame.lengthPrefix(Frame.payloadLength(buffers))) + for (buffer in buffers) { + writer.writeByteArray(buffer.toByteArray()) + } } suspend fun handleIncoming(frame: IncomingFrame) { diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Frame.kt b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Frame.kt index a3815d6140141..278516dd3e050 100644 --- a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Frame.kt +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Frame.kt @@ -26,9 +26,13 @@ import com.fasterxml.jackson.databind.util.StdDateFormat import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule import org.apache.airflow.sdk.execution.comm.Discriminator import org.msgpack.core.MessagePack -import java.io.ByteArrayOutputStream +import org.msgpack.core.MessageUnpacker +import org.msgpack.core.buffer.MessageBuffer +import org.msgpack.core.buffer.MessageBufferInput object Frame { + private const val MAX_FRAME_LENGTH = 0xFFFF_FFFFL + private val mapper = ObjectMapper().apply { configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) @@ -41,23 +45,23 @@ object Frame { fun encodeRequest( id: Int, body: Any, - ): ByteArray = encodeFrame(id, body) + ): List = encodeFrame(id, body) + + fun decode(input: MessageBufferInput): IncomingFrame = MessagePack.newDefaultUnpacker(input).use { decodeFrom(it) } - fun decode(bytes: ByteArray): IncomingFrame { - val unpacker = MessagePack.newDefaultUnpacker(bytes) + private fun decodeFrom(unpacker: MessageUnpacker): IncomingFrame { val headerSize = unpacker.unpackArrayHeader() check(headerSize >= 1) { "Unexpected Task SDK frame arity $headerSize" } val id = unpacker.unpackInt() val rawBody = if (headerSize >= 2) unpacker.unpackAny() else null val rawError = if (headerSize >= 3) unpacker.unpackAny() else null - unpacker.close() val body = decodeMessage(rawError) ?: decodeMessage(rawBody) return IncomingFrame(id, body) } - fun lengthPrefix(length: Int) = + fun lengthPrefix(length: UInt) = byteArrayOf( (length shr 24).toByte(), (length shr 16).toByte(), @@ -65,22 +69,29 @@ object Frame { length.toByte(), ) - fun parseLengthPrefix(prefix: ByteArray): Int { + fun payloadLength(buffers: List): UInt { + val total = buffers.sumOf { it.size().toLong() } + require(total <= MAX_FRAME_LENGTH) { + "Frame payload $total bytes exceeds protocol maximum $MAX_FRAME_LENGTH" + } + return total.toUInt() + } + + fun parseLengthPrefix(prefix: ByteArray): UInt { check(prefix.size == 4) { "Need 4 prefix bytes" } - return prefix.fold(0) { acc, byte -> (acc shl 8) or (byte.toInt() and 0xff) } + return prefix.fold(0u) { acc, byte -> (acc shl 8) or (byte.toUInt() and 0xffu) } } private fun encodeFrame( id: Int, body: Any?, - ): ByteArray { - val payload = ByteArrayOutputStream() - val packer = MessagePack.newDefaultPacker(payload) + ): List { + val packer = MessagePack.newDefaultBufferPacker() packer.packArrayHeader(2) packer.packInt(id) packer.packAny(body?.let(::toBody)) packer.close() - return payload.toByteArray() + return packer.toBufferList() } private fun decodeMessage(raw: Any?): Any? { diff --git a/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/CommTest.kt b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/CommTest.kt index 30a5db40a3389..37b9f271d66cc 100644 --- a/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/CommTest.kt +++ b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/CommTest.kt @@ -23,6 +23,7 @@ import org.apache.airflow.sdk.execution.comm.StartupDetails import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.DisplayName import org.junit.jupiter.api.Test +import org.msgpack.core.buffer.ArrayBufferInput import java.time.OffsetDateTime import java.time.ZoneOffset @@ -62,7 +63,7 @@ class CommsTest { 6f 82 a4 6e 61 6d 65 a8 61 6e 79 2d 6e 61 6d 65 a7 76 65 72 73 69 6f 6e ab 61 6e 79 2d 76 65 72 73 69 6f 6e b2 73 65 6e 74 72 79 5f 69 6e 74 65 67 72 61 74 69 6f 6e a0 c0 """.trimIndent() - val result = CoordinatorComm.decode(byteArrayFromHexString(data)) + val result = Frame.decode(ArrayBufferInput(byteArrayFromHexString(data))) Assertions.assertInstanceOf(IncomingFrame::class.java, result) Assertions.assertInstanceOf(StartupDetails::class.java, result.body) } @@ -71,7 +72,10 @@ class CommsTest { @DisplayName("Should serialize all fields") fun shouldEncodeSucceedTask() { val endDate = OffsetDateTime.of(2024, 12, 1, 1, 0, 0, 0, ZoneOffset.UTC) - val bytes = CoordinatorComm.encode(OutgoingFrame(3, TaskResult.success(endDate = endDate))) + val bytes = + Frame + .encodeRequest(3, TaskResult.success(endDate = endDate)) + .fold(ByteArray(0)) { acc, buffer -> acc + buffer.toByteArray() } val actual = bytes.toHexString(HexFormat { bytes { byteSeparator = " " } }) val expected =