diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/handler/McpStreamableHttpRequestHandler.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/handler/McpStreamableHttpRequestHandler.java index d55789adb5..8787065bdb 100644 --- a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/handler/McpStreamableHttpRequestHandler.java +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/handler/McpStreamableHttpRequestHandler.java @@ -31,6 +31,8 @@ import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.ReentrantLock; @@ -60,6 +62,7 @@ public class McpStreamableHttpRequestHandler { public static final String UTF_8 = "UTF-8"; public static final String APPLICATION_JSON = "application/json"; public static final String TEXT_EVENT_STREAM = "text/event-stream"; + private static final String SSE_KEEPALIVE_COMMENT = ": keepalive\n\n"; private static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}"; /** @@ -84,6 +87,8 @@ public class McpStreamableHttpRequestHandler { private McpTransportContextExtractor contextExtractor; + private final Duration keepAliveInterval; + /** * Flag indicating if the transport is shutting down. */ @@ -114,6 +119,7 @@ public McpStreamableHttpRequestHandler(ObjectMapper objectMapper, String mcpEndp this.mcpEndpoint = mcpEndpoint; this.disallowDelete = disallowDelete; this.contextExtractor = contextExtractor; + this.keepAliveInterval = keepAliveInterval; if (keepAliveInterval != null) { this.keepAliveScheduler = KeepAliveScheduler @@ -451,7 +457,7 @@ private void handlePostRequest(ChannelHandlerContext ctx, FullHttpRequest reques ctx.writeAndFlush(response); NettyStreamableMcpSessionTransport sessionTransport = new NettyStreamableMcpSessionTransport( - sessionId, ctx); + sessionId, ctx, this.keepAliveInterval); try { session.responseStream(jsonrpcRequest, sessionTransport, transportContext) @@ -565,12 +571,22 @@ private class NettyStreamableMcpSessionTransport implements McpStreamableServerT private final String sessionId; private final ChannelHandlerContext ctx; + private final Duration responseHeartbeatInterval; private final AtomicBoolean closed = new AtomicBoolean(false); private final ReentrantLock lock = new ReentrantLock(); + private volatile ScheduledFuture heartbeatTask; NettyStreamableMcpSessionTransport(String sessionId, ChannelHandlerContext ctx) { + this(sessionId, ctx, null); + } + + NettyStreamableMcpSessionTransport(String sessionId, ChannelHandlerContext ctx, + Duration responseHeartbeatInterval) { this.sessionId = sessionId; this.ctx = ctx; + this.responseHeartbeatInterval = responseHeartbeatInterval; + startResponseHeartbeat(); + this.ctx.channel().closeFuture().addListener(future -> cancelResponseHeartbeat()); logger.debug("Streamable session transport {} initialized", sessionId); } @@ -627,6 +643,7 @@ public CompletableFuture closeGracefully() { public void close() { lock.lock(); try { + cancelResponseHeartbeat(); if (this.closed.get()) { logger.debug("Session transport {} already closed", this.sessionId); return; @@ -664,6 +681,44 @@ private void sendSseEvent(String eventType, String data, String id) { logger.debug("SSE event sent - Type: {}, ID: {}, Data length: {}", eventType, id, data != null ? data.length() : 0); } + + private void startResponseHeartbeat() { + if (this.responseHeartbeatInterval == null || this.responseHeartbeatInterval.isZero() + || this.responseHeartbeatInterval.isNegative()) { + return; + } + long intervalMillis = Math.max(1L, this.responseHeartbeatInterval.toMillis()); + this.heartbeatTask = this.ctx.executor().scheduleAtFixedRate( + this::sendResponseHeartbeat, intervalMillis, intervalMillis, TimeUnit.MILLISECONDS); + } + + private void sendResponseHeartbeat() { + if (this.closed.get() || !this.ctx.channel().isActive()) { + return; + } + lock.lock(); + try { + if (this.closed.get() || !this.ctx.channel().isActive()) { + return; + } + ByteBuf buffer = Unpooled.copiedBuffer(SSE_KEEPALIVE_COMMENT, CharsetUtil.UTF_8); + this.ctx.writeAndFlush(new DefaultHttpContent(buffer)); + logger.trace("SSE heartbeat sent for session {}", this.sessionId); + } catch (Exception e) { + logger.warn("Failed to send SSE heartbeat for session {}: {}", this.sessionId, e.getMessage()); + this.ctx.close(); + } finally { + lock.unlock(); + } + } + + private void cancelResponseHeartbeat() { + ScheduledFuture task = this.heartbeatTask; + if (task != null) { + this.heartbeatTask = null; + task.cancel(false); + } + } } public static Builder builder() { diff --git a/arthas-mcp-server/src/test/java/com/taobao/arthas/mcp/server/protocol/server/handler/McpStreamableHttpRequestHandlerTest.java b/arthas-mcp-server/src/test/java/com/taobao/arthas/mcp/server/protocol/server/handler/McpStreamableHttpRequestHandlerTest.java new file mode 100644 index 0000000000..226bfd8c0b --- /dev/null +++ b/arthas-mcp-server/src/test/java/com/taobao/arthas/mcp/server/protocol/server/handler/McpStreamableHttpRequestHandlerTest.java @@ -0,0 +1,197 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.protocol.server.handler; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.taobao.arthas.mcp.server.CommandExecutor; +import com.taobao.arthas.mcp.server.protocol.server.McpRequestHandler; +import com.taobao.arthas.mcp.server.protocol.spec.DefaultMcpStreamableServerSessionFactory; +import com.taobao.arthas.mcp.server.protocol.spec.HttpHeaders; +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import org.junit.jupiter.api.Test; + +import java.time.Duration; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import static org.assertj.core.api.Assertions.assertThat; + +class McpStreamableHttpRequestHandlerTest { + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static final String MCP_ENDPOINT = "/mcp"; + + @Test + void shouldSendHeartbeatCommentWhilePostResponseIsPending() throws Exception { + CompletableFuture pendingToolCall = new CompletableFuture<>(); + McpStreamableHttpRequestHandler handler = newHandler(Duration.ofMillis(10), pendingToolCall); + String sessionId = initializeSession(handler); + + EmbeddedChannel channel = newChannel(handler); + channel.writeInbound(postRequest(sessionId, new McpSchema.JSONRPCRequest( + McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_TOOLS_CALL, + "call-1", + new McpSchema.CallToolRequest("slow", Collections.emptyMap(), null)))); + + Object headers = channel.readOutbound(); + assertThat(headers).isInstanceOf(HttpResponse.class); + ReferenceCountUtil.release(headers); + Object immediateContent = channel.readOutbound(); + assertThat(immediateContent).isNull(); + + channel.advanceTimeBy(11, TimeUnit.MILLISECONDS); + channel.runScheduledPendingTasks(); + + HttpContent heartbeat = readOutbound(channel, HttpContent.class); + assertThat(heartbeat.content().toString(CharsetUtil.UTF_8)).isEqualTo(": keepalive\n\n"); + ReferenceCountUtil.release(heartbeat); + + channel.close(); + channel.finishAndReleaseAll(); + } + + private static McpStreamableHttpRequestHandler newHandler(Duration keepAliveInterval, + CompletableFuture pendingToolCall) { + McpStreamableHttpRequestHandler handler = McpStreamableHttpRequestHandler.builder() + .objectMapper(OBJECT_MAPPER) + .mcpEndpoint(MCP_ENDPOINT) + .keepAliveInterval(keepAliveInterval) + .build(); + + Map> requestHandlers = new HashMap<>(); + McpRequestHandler toolCallHandler = (exchange, commandContext, params) -> pendingToolCall; + requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolCallHandler); + + handler.setSessionFactory(new DefaultMcpStreamableServerSessionFactory( + Duration.ofSeconds(30), + initializeRequest -> CompletableFuture.completedFuture(new McpSchema.InitializeResult( + initializeRequest.getProtocolVersion(), + McpSchema.ServerCapabilities.builder().build(), + new McpSchema.Implementation("test-server", "1.0.0"), + null)), + requestHandlers, + Collections.emptyMap(), + new StubCommandExecutor(), + null, + null)); + return handler; + } + + private static String initializeSession(McpStreamableHttpRequestHandler handler) throws Exception { + EmbeddedChannel channel = newChannel(handler); + McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest( + "2024-11-05", + new McpSchema.ClientCapabilities(null, null, null, null), + new McpSchema.Implementation("test-client", "1.0.0")); + + channel.writeInbound(postRequest(null, new McpSchema.JSONRPCRequest( + McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_INITIALIZE, + "init-1", + initializeRequest))); + + FullHttpResponse response = readOutbound(channel, FullHttpResponse.class); + String sessionId = response.headers().get(HttpHeaders.MCP_SESSION_ID); + assertThat(sessionId).isNotBlank(); + ReferenceCountUtil.release(response); + channel.finishAndReleaseAll(); + return sessionId; + } + + private static EmbeddedChannel newChannel(McpStreamableHttpRequestHandler handler) { + return new EmbeddedChannel(new SimpleChannelInboundHandler(false) { + @Override + protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) throws Exception { + handler.handle(ctx, request); + } + }); + } + + private static DefaultFullHttpRequest postRequest(String sessionId, McpSchema.JSONRPCMessage message) + throws Exception { + byte[] body = OBJECT_MAPPER.writeValueAsBytes(message); + DefaultFullHttpRequest request = new DefaultFullHttpRequest( + HttpVersion.HTTP_1_1, + HttpMethod.POST, + MCP_ENDPOINT, + Unpooled.wrappedBuffer(body)); + request.headers().set(HttpHeaderNames.ACCEPT, "application/json, text/event-stream"); + request.headers().set(HttpHeaderNames.CONTENT_TYPE, "application/json"); + request.headers().set(HttpHeaderNames.CONTENT_LENGTH, request.content().readableBytes()); + if (sessionId != null) { + request.headers().set(HttpHeaders.MCP_SESSION_ID, sessionId); + } + return request; + } + + private static T readOutbound(EmbeddedChannel channel, Class type) { + Object message = channel.readOutbound(); + assertThat(message).isInstanceOf(type); + return type.cast(message); + } + + private static final class StubCommandExecutor implements CommandExecutor { + + @Override + public Map executeSync(String commandLine, long timeout, String sessionId, Object authSubject, + String userId) { + return Collections.emptyMap(); + } + + @Override + public Map executeAsync(String commandLine, String sessionId) { + return Collections.emptyMap(); + } + + @Override + public Map pullResults(String sessionId, String consumerId) { + return Collections.emptyMap(); + } + + @Override + public Map interruptJob(String sessionId) { + return Collections.emptyMap(); + } + + @Override + public Map createSession(boolean quiet) { + Map result = new HashMap<>(); + result.put("sessionId", "arthas-session-1"); + result.put("consumerId", "consumer-1"); + return result; + } + + @Override + public Map closeSession(String sessionId) { + return Collections.emptyMap(); + } + + @Override + public void setSessionAuth(String sessionId, Object authSubject) { + } + + @Override + public void setSessionUserId(String sessionId, String userId) { + } + } +}