Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,59 @@ import io.ktor.utils.io.ByteReadChannel
import io.ktor.utils.io.ByteWriteChannel
import io.ktor.utils.io.readByteArray
import io.ktor.utils.io.writeByteArray
import kotlinx.coroutines.runBlocking
import org.apache.airflow.sdk.ApiError
import org.apache.airflow.sdk.Bundle
import org.apache.airflow.sdk.execution.comm.ErrorResponse
import org.apache.airflow.sdk.execution.comm.StartupDetails
import org.msgpack.core.buffer.MessageBuffer
import org.msgpack.core.buffer.MessageBufferInput
import java.io.IOException
import kotlin.concurrent.atomics.AtomicInt
import kotlin.concurrent.atomics.ExperimentalAtomicApi

/**
 * A [MessageBufferInput] that hands one length-prefixed frame to a MessageUnpacker piece by piece.
 *
 * In total exactly [declaredLength] bytes are pulled from [reader]. They are delivered as a
 * sequence of chunks, none larger than [CHUNK_SIZE]; once every declared byte has been handed
 * out, [next] returns `null` to signal end of input.
 *
 * Every call to [next] allocates at most [CHUNK_SIZE] bytes, and the outstanding byte count is
 * tracked as a `Long`, so a declared length above [Int.MAX_VALUE] (representable in the `UInt`
 * prefix but not as a single ByteArray) can still be streamed.
 *
 * [MessageBufferInput] is a synchronous interface while the ktor read is a suspending call, so
 * each [next] bridges the two with [runBlocking]. This class is intended to be used only from a
 * context where blocking the thread is acceptable (the author's stated usage is `Dispatchers.IO`).
 *
 * @param reader channel the frame payload is read from; it is never closed by this class.
 * @param declaredLength payload size in bytes, as announced by the frame's length prefix.
 */
private class ChannelFrameInput(
    private val reader: ByteReadChannel,
    declaredLength: UInt,
) : MessageBufferInput {
    // Bytes of the frame that have not been handed to the unpacker yet.
    private var unread = declaredLength.toLong()

    /**
     * Reads the next chunk of the frame.
     *
     * @return a buffer wrapping the bytes just read, or `null` when the whole frame was consumed.
     * @throws IOException if the channel yields fewer bytes than requested.
     */
    override fun next(): MessageBuffer? {
        if (unread == 0L) return null

        val chunkSize = unread.coerceAtMost(CHUNK_SIZE).toInt()
        val chunk = runBlocking { reader.readByteArray(chunkSize) }
        if (chunk.size != chunkSize) {
            throw IOException("Truncated frame: expected $chunkSize more bytes, got ${chunk.size}")
        }

        unread -= chunkSize
        return MessageBuffer.wrap(chunk)
    }

    // Intentionally a no-op: the channel's lifecycle belongs to the caller.
    override fun close() = Unit

    private companion object {
        // Upper bound, in bytes, for a single chunk returned by next() (64 KiB).
        const val CHUNK_SIZE = 64L * 1024L
    }
}

/**
 * A frame received from the peer, after decoding.
 *
 * @property id the integer identifier carried by the frame.
 * @property body the decoded content of the frame, or `null` when the frame carried none.
 */
data class IncomingFrame(
    val id: Int,
    val body: Any?,
)

/**
 * A frame to be sent to the peer, before encoding.
 *
 * @property id the integer identifier to put on the frame.
 * @property body the content to send; unlike [IncomingFrame.body] it is never `null`.
 */
data class OutgoingFrame(
    val id: Int,
    val body: Any,
)

@OptIn(ExperimentalAtomicApi::class)
class CoordinatorComm(
private val bundle: Bundle,
Expand All @@ -48,10 +84,6 @@ class CoordinatorComm(
) {
internal companion object {
private val logger = Logger(CoordinatorComm::class)

fun encode(outgoing: OutgoingFrame) = Frame.encodeRequest(outgoing.id, outgoing.body)

fun decode(bytes: ByteArray) = Frame.decode(bytes)
}

private val nextId = AtomicInt(0)
Expand All @@ -72,17 +104,18 @@ class CoordinatorComm(
return
}

val payloadLength = Frame.parseLengthPrefix(prefix)
val payload = reader.readByteArray(payloadLength)
if (payload.size != payloadLength) { // Something is terribly wrong. Let's bail.
logger.error(
"Payload length not right",
mapOf("expect" to payloadLength, "receive" to payload.size),
)
shutDownRequested = true
return
}
val frame = decode(payload)
val declaredLength = Frame.parseLengthPrefix(prefix)
val frame =
try {
Frame.decode(ChannelFrameInput(reader, declaredLength))
} catch (e: Exception) {
logger.error(
"Failed to read or decode frame",
mapOf("length" to declaredLength, "exception" to e),
)
shutDownRequested = true
return
}
Comment on lines +109 to +118

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
try {
Frame.decode(ChannelFrameInput(reader, declaredLength))
} catch (e: Exception) {
logger.error(
"Failed to read or decode frame",
mapOf("length" to declaredLength, "exception" to e),
)
shutDownRequested = true
return
}
try {
Frame.decode(ChannelFrameInput(reader, declaredLength))
} catch (e: CancellationException) {
throw e
} catch (e: Exception) {
logger.error(
"Failed to read or decode frame",
mapOf("length" to declaredLength, "exception" to e),
)
shutDownRequested = true
return
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs `import kotlinx.coroutines.CancellationException`.

`catch (e: Exception)` also catches `CancellationException`, so a cancellation during decode is logged as a decode failure and turned into a clean shutdown instead of propagating to the caller.
This is minor, and the same pattern already exists elsewhere in this file, but the new broad catch makes it worth rethrowing cancellation explicitly.

logger.debug("Handling", mapOf("id" to frame.id))
handle(frame)
}
Expand All @@ -91,10 +124,12 @@ class CoordinatorComm(
id: Int,
body: Any,
) {
val data = encode(OutgoingFrame(id, body))
val buffers = Frame.encodeRequest(id, body)
logger.debug("Sending", mapOf("id" to id, "body" to body))
writer.writeByteArray(Frame.lengthPrefix(data.size))
writer.writeByteArray(data)
writer.writeByteArray(Frame.lengthPrefix(Frame.payloadLength(buffers)))
for (buffer in buffers) {
writer.writeByteArray(buffer.toByteArray())
}
}

suspend fun handleIncoming(frame: IncomingFrame) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,13 @@ import com.fasterxml.jackson.databind.util.StdDateFormat
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule
import org.apache.airflow.sdk.execution.comm.Discriminator
import org.msgpack.core.MessagePack
import java.io.ByteArrayOutputStream
import org.msgpack.core.MessageUnpacker
import org.msgpack.core.buffer.MessageBuffer
import org.msgpack.core.buffer.MessageBufferInput

object Frame {
private const val MAX_FRAME_LENGTH = 0xFFFF_FFFFL

private val mapper =
ObjectMapper().apply {
configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
Expand All @@ -41,46 +45,53 @@ object Frame {
fun encodeRequest(
id: Int,
body: Any,
): ByteArray = encodeFrame(id, body)
): List<MessageBuffer> = encodeFrame(id, body)

/**
 * Decodes a single frame from MessagePack data supplied by [input].
 *
 * A default unpacker is created over [input] and is closed via `use` as soon as decoding
 * finishes, whether it returns normally or throws.
 *
 * @param input source of the frame's MessagePack-encoded bytes.
 * @return the decoded frame.
 */
fun decode(input: MessageBufferInput): IncomingFrame {
    val unpacker = MessagePack.newDefaultUnpacker(input)
    return unpacker.use { source -> decodeFrom(source) }
}

fun decode(bytes: ByteArray): IncomingFrame {
val unpacker = MessagePack.newDefaultUnpacker(bytes)
private fun decodeFrom(unpacker: MessageUnpacker): IncomingFrame {
val headerSize = unpacker.unpackArrayHeader()
check(headerSize >= 1) { "Unexpected Task SDK frame arity $headerSize" }

val id = unpacker.unpackInt()
val rawBody = if (headerSize >= 2) unpacker.unpackAny() else null
val rawError = if (headerSize >= 3) unpacker.unpackAny() else null
unpacker.close()

val body = decodeMessage(rawError) ?: decodeMessage(rawBody)
return IncomingFrame(id, body)
}

fun lengthPrefix(length: Int) =
fun lengthPrefix(length: UInt) =
byteArrayOf(
(length shr 24).toByte(),
(length shr 16).toByte(),
(length shr 8).toByte(),
length.toByte(),
)

fun parseLengthPrefix(prefix: ByteArray): Int {
/**
 * Computes the combined size, in bytes, of all [buffers] making up one frame payload.
 *
 * The sizes are accumulated as a `Long` so the sum cannot wrap before it is range-checked.
 *
 * @param buffers the encoded payload, split across one or more buffers.
 * @return the total byte count as a `UInt`.
 * @throws IllegalArgumentException if the total exceeds `MAX_FRAME_LENGTH`.
 */
fun payloadLength(buffers: List<MessageBuffer>): UInt {
    val byteCount = buffers.fold(0L) { runningTotal, buffer -> runningTotal + buffer.size() }
    require(byteCount <= MAX_FRAME_LENGTH) {
        "Frame payload $byteCount bytes exceeds protocol maximum $MAX_FRAME_LENGTH"
    }
    return byteCount.toUInt()
}

fun parseLengthPrefix(prefix: ByteArray): UInt {
check(prefix.size == 4) { "Need 4 prefix bytes" }
return prefix.fold(0) { acc, byte -> (acc shl 8) or (byte.toInt() and 0xff) }
return prefix.fold(0u) { acc, byte -> (acc shl 8) or (byte.toUInt() and 0xffu) }
}

private fun encodeFrame(
id: Int,
body: Any?,
): ByteArray {
val payload = ByteArrayOutputStream()
val packer = MessagePack.newDefaultPacker(payload)
): List<MessageBuffer> {
val packer = MessagePack.newDefaultBufferPacker()
packer.packArrayHeader(2)
packer.packInt(id)
packer.packAny(body?.let(::toBody))
packer.close()
return payload.toByteArray()
return packer.toBufferList()
}

private fun decodeMessage(raw: Any?): Any? {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.airflow.sdk.execution.comm.StartupDetails
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.DisplayName
import org.junit.jupiter.api.Test
import org.msgpack.core.buffer.ArrayBufferInput
import java.time.OffsetDateTime
import java.time.ZoneOffset

Expand Down Expand Up @@ -62,7 +63,7 @@ class CommsTest {
6f 82 a4 6e 61 6d 65 a8 61 6e 79 2d 6e 61 6d 65 a7 76 65 72 73 69 6f 6e ab 61 6e 79 2d 76 65 72
73 69 6f 6e b2 73 65 6e 74 72 79 5f 69 6e 74 65 67 72 61 74 69 6f 6e a0 c0
""".trimIndent()
val result = CoordinatorComm.decode(byteArrayFromHexString(data))
val result = Frame.decode(ArrayBufferInput(byteArrayFromHexString(data)))
Assertions.assertInstanceOf(IncomingFrame::class.java, result)
Assertions.assertInstanceOf(StartupDetails::class.java, result.body)
}
Expand All @@ -71,7 +72,10 @@ class CommsTest {
@DisplayName("Should serialize all fields")
fun shouldEncodeSucceedTask() {
val endDate = OffsetDateTime.of(2024, 12, 1, 1, 0, 0, 0, ZoneOffset.UTC)
val bytes = CoordinatorComm.encode(OutgoingFrame(3, TaskResult.success(endDate = endDate)))
val bytes =
Frame
.encodeRequest(3, TaskResult.success(endDate = endDate))
.fold(ByteArray(0)) { acc, buffer -> acc + buffer.toByteArray() }
val actual = bytes.toHexString(HexFormat { bytes { byteSeparator = " " } })

val expected =
Expand Down