diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/session/ArthasCommandContext.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/session/ArthasCommandContext.java index 503ef8f008b..399400a1e1d 100644 --- a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/session/ArthasCommandContext.java +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/session/ArthasCommandContext.java @@ -158,4 +158,17 @@ public void setSessionUserId(String userId) { logger.debug("Set userId for session {}: {}", binding.getArthasSessionId(), userId); } } + + /** + * 将已认证主体写入当前绑定的 Arthas session。 + * + * @param authSubject 认证主体 + */ + public void setSessionAuth(Object authSubject) { + if (binding != null && authSubject != null) { + commandExecutor.setSessionAuth(binding.getArthasSessionId(), authSubject); + logger.debug("Set auth subject for session {}: {}", + binding.getArthasSessionId(), authSubject.getClass().getSimpleName()); + } + } } diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/DefaultCreateTaskContext.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/DefaultCreateTaskContext.java index eec38f141ee..3ca1aa98a6b 100644 --- a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/DefaultCreateTaskContext.java +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/DefaultCreateTaskContext.java @@ -9,6 +9,7 @@ import com.taobao.arthas.mcp.server.session.ArthasCommandContext; import com.taobao.arthas.mcp.server.session.ArthasCommandSessionManager; import com.taobao.arthas.mcp.server.session.ArthasCommandSessionManager.CommandSessionBinding; +import com.taobao.arthas.mcp.server.util.McpAuthExtractor; import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; @@ -134,7 +135,14 @@ public ArthasCommandContext createIsolatedTaskSession(String taskId) { throw new IllegalStateException("SessionManager is not available"); } CommandSessionBinding binding = sessionManager.createIsolatedTaskSession(taskId); - return new ArthasCommandContext(commandContext.getCommandExecutor(), binding); + ArthasCommandContext isolatedContext = new ArthasCommandContext(commandContext.getCommandExecutor(), binding); + if (exchange != null && exchange.getTransportContext() != null) { + Object authSubject = exchange.getTransportContext().get(McpAuthExtractor.MCP_AUTH_SUBJECT_KEY); + if (authSubject != null) { + isolatedContext.setSessionAuth(authSubject); + } + } + return isolatedContext; } @Override diff --git a/arthas-mcp-server/src/test/java/com/taobao/arthas/mcp/server/session/ArthasCommandContextAuthTest.java b/arthas-mcp-server/src/test/java/com/taobao/arthas/mcp/server/session/ArthasCommandContextAuthTest.java new file mode 100644 index 00000000000..2f2607d45c2 --- /dev/null +++ b/arthas-mcp-server/src/test/java/com/taobao/arthas/mcp/server/session/ArthasCommandContextAuthTest.java @@ -0,0 +1,92 @@ +package com.taobao.arthas.mcp.server.session; + +import com.taobao.arthas.mcp.server.CommandExecutor; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +class ArthasCommandContextAuthTest { + + @Test + void setSessionAuthShouldDelegateToBoundSession() { + RecordingCommandExecutor executor = new RecordingCommandExecutor(); + ArthasCommandSessionManager.CommandSessionBinding binding = + new ArthasCommandSessionManager.CommandSessionBinding("mcp-session", "arthas-session", "consumer"); + ArthasCommandContext context = new ArthasCommandContext(executor, binding); + Object authSubject = new Object(); + + context.setSessionAuth(authSubject); + + assertThat(executor.authSessionId).isEqualTo("arthas-session"); + assertThat(executor.authSubject).isSameAs(authSubject); + assertThat(executor.authCallCount).isEqualTo(1); + } + + @Test + void setSessionAuthShouldIgnoreMissingBindingOrSubject() { + RecordingCommandExecutor executor = new RecordingCommandExecutor(); + ArthasCommandContext temporaryContext = new ArthasCommandContext(executor); + ArthasCommandSessionManager.CommandSessionBinding binding = + new ArthasCommandSessionManager.CommandSessionBinding("mcp-session", "arthas-session", "consumer"); + ArthasCommandContext boundContext = new ArthasCommandContext(executor, binding); + + temporaryContext.setSessionAuth(new Object()); + boundContext.setSessionAuth(null); + + assertThat(executor.authCallCount).isZero(); + } + + private static final class RecordingCommandExecutor implements CommandExecutor { + private String authSessionId; + private Object authSubject; + private int authCallCount; + + @Override + public Map executeSync(String commandLine, long timeout, String sessionId, + Object authSubject, String userId) { + return new HashMap(); + } + + @Override + public Map executeAsync(String commandLine, String sessionId) { + return new HashMap(); + } + + @Override + public Map pullResults(String sessionId, String consumerId) { + return new HashMap(); + } + + @Override + public Map interruptJob(String sessionId) { + return new HashMap(); + } + + @Override + public Map createSession() { + Map result = new HashMap(); + result.put("sessionId", "created-session"); + result.put("consumerId", "created-consumer"); + return result; + } + + @Override + public Map closeSession(String sessionId) { + return new HashMap(); + } + + @Override + public void setSessionAuth(String sessionId, Object authSubject) { + this.authSessionId = sessionId; + this.authSubject = authSubject; + this.authCallCount++; + } + + @Override + public void setSessionUserId(String sessionId, String userId) { + } + } +} diff --git a/arthas-mcp-server/src/test/java/com/taobao/arthas/mcp/server/task/DefaultCreateTaskContextAuthTest.java b/arthas-mcp-server/src/test/java/com/taobao/arthas/mcp/server/task/DefaultCreateTaskContextAuthTest.java new file mode 100644 index 00000000000..7c7ed04cc55 --- /dev/null +++ b/arthas-mcp-server/src/test/java/com/taobao/arthas/mcp/server/task/DefaultCreateTaskContextAuthTest.java @@ -0,0 +1,106 @@ +package com.taobao.arthas.mcp.server.task; + +import com.taobao.arthas.mcp.server.CommandExecutor; +import com.taobao.arthas.mcp.server.protocol.server.DefaultMcpTransportContext; +import com.taobao.arthas.mcp.server.protocol.server.McpNettyServerExchange; +import com.taobao.arthas.mcp.server.session.ArthasCommandContext; +import com.taobao.arthas.mcp.server.session.ArthasCommandSessionManager; +import com.taobao.arthas.mcp.server.util.McpAuthExtractor; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +class DefaultCreateTaskContextAuthTest { + + @Test + void createIsolatedTaskSessionShouldApplyTransportAuthSubject() { + RecordingCommandExecutor executor = new RecordingCommandExecutor(); + ArthasCommandContext commandContext = new ArthasCommandContext(executor, + new ArthasCommandSessionManager.CommandSessionBinding("mcp-session", "main-session", "main-consumer")); + ArthasCommandSessionManager sessionManager = new ArthasCommandSessionManager(executor); + DefaultMcpTransportContext transportContext = new DefaultMcpTransportContext(); + Object authSubject = new Object(); + transportContext.put(McpAuthExtractor.MCP_AUTH_SUBJECT_KEY, authSubject); + McpNettyServerExchange exchange = + new McpNettyServerExchange("mcp-session", null, null, null, transportContext, null); + DefaultCreateTaskContext createTaskContext = new DefaultCreateTaskContext( + null, null, exchange, "mcp-session", null, null, commandContext, sessionManager); + + ArthasCommandContext isolatedContext = createTaskContext.createIsolatedTaskSession("task-1"); + + assertThat(isolatedContext.getSessionId()).isEqualTo("isolated-session"); + assertThat(executor.authSessionId).isEqualTo("isolated-session"); + assertThat(executor.authSubject).isSameAs(authSubject); + assertThat(executor.authCallCount).isEqualTo(1); + } + + @Test + void createIsolatedTaskSessionShouldNotApplyAuthWhenTransportAuthMissing() { + RecordingCommandExecutor executor = new RecordingCommandExecutor(); + ArthasCommandContext commandContext = new ArthasCommandContext(executor, + new ArthasCommandSessionManager.CommandSessionBinding("mcp-session", "main-session", "main-consumer")); + ArthasCommandSessionManager sessionManager = new ArthasCommandSessionManager(executor); + McpNettyServerExchange exchange = new McpNettyServerExchange( + "mcp-session", null, null, null, new DefaultMcpTransportContext(), null); + DefaultCreateTaskContext createTaskContext = new DefaultCreateTaskContext( + null, null, exchange, "mcp-session", null, null, commandContext, sessionManager); + + createTaskContext.createIsolatedTaskSession("task-1"); + + assertThat(executor.authCallCount).isZero(); + } + + private static final class RecordingCommandExecutor implements CommandExecutor { + private String authSessionId; + private Object authSubject; + private int authCallCount; + + @Override + public Map executeSync(String commandLine, long timeout, String sessionId, + Object authSubject, String userId) { + return new HashMap(); + } + + @Override + public Map executeAsync(String commandLine, String sessionId) { + return new HashMap(); + } + + @Override + public Map pullResults(String sessionId, String consumerId) { + return new HashMap(); + } + + @Override + public Map interruptJob(String sessionId) { + return new HashMap(); + } + + @Override + public Map createSession() { + Map result = new HashMap(); + result.put("sessionId", "isolated-session"); + result.put("consumerId", "isolated-consumer"); + return result; + } + + @Override + public Map closeSession(String sessionId) { + return new HashMap(); + } + + @Override + public void setSessionAuth(String sessionId, Object authSubject) { + this.authSessionId = sessionId; + this.authSubject = authSubject; + this.authCallCount++; + } + + @Override + public void setSessionUserId(String sessionId, String userId) { + } + } +} diff --git a/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/AbstractArthasTool.java b/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/AbstractArthasTool.java index 2464767a3ba..856b4e6c52a 100644 --- a/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/AbstractArthasTool.java +++ b/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/AbstractArthasTool.java @@ -122,6 +122,11 @@ protected String executeStreamable(ToolContext toolContext, String commandStr, logger.info("Starting streamable execution: {}", commandStr); + // 在异步命令创建 Job 前,把 MCP 已认证主体写入当前 Arthas session。 + if (execContext.getAuthSubject() != null) { + execContext.getCommandContext().setSessionAuth(execContext.getAuthSubject()); + } + // Set userId to session before async execution for stat reporting if (execContext.getUserId() != null) { execContext.getCommandContext().setSessionUserId(execContext.getUserId()); diff --git a/core/src/test/java/com/taobao/arthas/core/mcp/tool/function/AbstractArthasToolAuthTest.java b/core/src/test/java/com/taobao/arthas/core/mcp/tool/function/AbstractArthasToolAuthTest.java new file mode 100644 index 00000000000..80557234890 --- /dev/null +++ b/core/src/test/java/com/taobao/arthas/core/mcp/tool/function/AbstractArthasToolAuthTest.java @@ -0,0 +1,51 @@ +package com.taobao.arthas.core.mcp.tool.function; + +import com.taobao.arthas.core.mcp.util.McpAuthExtractor; +import com.taobao.arthas.mcp.server.protocol.server.DefaultMcpTransportContext; +import com.taobao.arthas.mcp.server.session.ArthasCommandContext; +import com.taobao.arthas.mcp.server.tool.ToolContext; +import com.taobao.arthas.mcp.server.tool.ToolContextKeys; +import org.junit.Test; +import org.mockito.InOrder; + +import java.util.HashMap; +import java.util.Map; + +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class AbstractArthasToolAuthTest { + + @Test + public void executeStreamableShouldApplyAuthSubjectBeforeAsyncExecution() { + ArthasCommandContext commandContext = mock(ArthasCommandContext.class); + String command = "trace demo.MathGame run"; + Map asyncResult = new HashMap(); + asyncResult.put("success", false); + asyncResult.put("error", "start failed"); + when(commandContext.executeAsync(command)).thenReturn(asyncResult); + + DefaultMcpTransportContext transportContext = new DefaultMcpTransportContext(); + Object authSubject = new Object(); + transportContext.put(McpAuthExtractor.MCP_AUTH_SUBJECT_KEY, authSubject); + transportContext.put(McpAuthExtractor.MCP_USER_ID_KEY, "user-1"); + + Map context = new HashMap(); + context.put(ToolContextKeys.COMMAND_CONTEXT, commandContext); + context.put(ToolContextKeys.MCP_TRANSPORT_CONTEXT, transportContext); + + new TestArthasTool().callExecuteStreamable(new ToolContext(context), command); + + InOrder inOrder = inOrder(commandContext); + inOrder.verify(commandContext).setSessionAuth(authSubject); + inOrder.verify(commandContext).setSessionUserId("user-1"); + inOrder.verify(commandContext).executeAsync(command); + } + + private static final class TestArthasTool extends AbstractArthasTool { + private String callExecuteStreamable(ToolContext toolContext, String command) { + return executeStreamable(toolContext, command, 1, 10, 1000, null); + } + } +}