Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -158,4 +158,17 @@ public void setSessionUserId(String userId) {
logger.debug("Set userId for session {}: {}", binding.getArthasSessionId(), userId);
}
}

/**
* 将已认证主体写入当前绑定的 Arthas session。
*
* @param authSubject 认证主体
*/
public void setSessionAuth(Object authSubject) {
if (binding != null && authSubject != null) {
commandExecutor.setSessionAuth(binding.getArthasSessionId(), authSubject);
logger.debug("Set auth subject for session {}: {}",
binding.getArthasSessionId(), authSubject.getClass().getSimpleName());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.taobao.arthas.mcp.server.session.ArthasCommandContext;
import com.taobao.arthas.mcp.server.session.ArthasCommandSessionManager;
import com.taobao.arthas.mcp.server.session.ArthasCommandSessionManager.CommandSessionBinding;
import com.taobao.arthas.mcp.server.util.McpAuthExtractor;

import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
Expand Down Expand Up @@ -134,7 +135,14 @@ public ArthasCommandContext createIsolatedTaskSession(String taskId) {
throw new IllegalStateException("SessionManager is not available");
}
CommandSessionBinding binding = sessionManager.createIsolatedTaskSession(taskId);
return new ArthasCommandContext(commandContext.getCommandExecutor(), binding);
ArthasCommandContext isolatedContext = new ArthasCommandContext(commandContext.getCommandExecutor(), binding);
if (exchange != null && exchange.getTransportContext() != null) {
Object authSubject = exchange.getTransportContext().get(McpAuthExtractor.MCP_AUTH_SUBJECT_KEY);
if (authSubject != null) {
isolatedContext.setSessionAuth(authSubject);
}
}
return isolatedContext;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package com.taobao.arthas.mcp.server.session;

import com.taobao.arthas.mcp.server.CommandExecutor;
import org.junit.jupiter.api.Test;

import java.util.HashMap;
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;

class ArthasCommandContextAuthTest {

@Test
void setSessionAuthShouldDelegateToBoundSession() {
RecordingCommandExecutor executor = new RecordingCommandExecutor();
ArthasCommandSessionManager.CommandSessionBinding binding =
new ArthasCommandSessionManager.CommandSessionBinding("mcp-session", "arthas-session", "consumer");
ArthasCommandContext context = new ArthasCommandContext(executor, binding);
Object authSubject = new Object();

context.setSessionAuth(authSubject);

assertThat(executor.authSessionId).isEqualTo("arthas-session");
assertThat(executor.authSubject).isSameAs(authSubject);
assertThat(executor.authCallCount).isEqualTo(1);
}

@Test
void setSessionAuthShouldIgnoreMissingBindingOrSubject() {
RecordingCommandExecutor executor = new RecordingCommandExecutor();
ArthasCommandContext temporaryContext = new ArthasCommandContext(executor);
ArthasCommandSessionManager.CommandSessionBinding binding =
new ArthasCommandSessionManager.CommandSessionBinding("mcp-session", "arthas-session", "consumer");
ArthasCommandContext boundContext = new ArthasCommandContext(executor, binding);

temporaryContext.setSessionAuth(new Object());
boundContext.setSessionAuth(null);

assertThat(executor.authCallCount).isZero();
}

private static final class RecordingCommandExecutor implements CommandExecutor {
private String authSessionId;
private Object authSubject;
private int authCallCount;

@Override
public Map<String, Object> executeSync(String commandLine, long timeout, String sessionId,
Object authSubject, String userId) {
return new HashMap<String, Object>();
}

@Override
public Map<String, Object> executeAsync(String commandLine, String sessionId) {
return new HashMap<String, Object>();
}

@Override
public Map<String, Object> pullResults(String sessionId, String consumerId) {
return new HashMap<String, Object>();
}

@Override
public Map<String, Object> interruptJob(String sessionId) {
return new HashMap<String, Object>();
}

@Override
public Map<String, Object> createSession() {
Map<String, Object> result = new HashMap<String, Object>();
result.put("sessionId", "created-session");
result.put("consumerId", "created-consumer");
return result;
}

@Override
public Map<String, Object> closeSession(String sessionId) {
return new HashMap<String, Object>();
}

@Override
public void setSessionAuth(String sessionId, Object authSubject) {
this.authSessionId = sessionId;
this.authSubject = authSubject;
this.authCallCount++;
}

@Override
public void setSessionUserId(String sessionId, String userId) {
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package com.taobao.arthas.mcp.server.task;

import com.taobao.arthas.mcp.server.CommandExecutor;
import com.taobao.arthas.mcp.server.protocol.server.DefaultMcpTransportContext;
import com.taobao.arthas.mcp.server.protocol.server.McpNettyServerExchange;
import com.taobao.arthas.mcp.server.session.ArthasCommandContext;
import com.taobao.arthas.mcp.server.session.ArthasCommandSessionManager;
import com.taobao.arthas.mcp.server.util.McpAuthExtractor;
import org.junit.jupiter.api.Test;

import java.util.HashMap;
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;

class DefaultCreateTaskContextAuthTest {

@Test
void createIsolatedTaskSessionShouldApplyTransportAuthSubject() {
RecordingCommandExecutor executor = new RecordingCommandExecutor();
ArthasCommandContext commandContext = new ArthasCommandContext(executor,
new ArthasCommandSessionManager.CommandSessionBinding("mcp-session", "main-session", "main-consumer"));
ArthasCommandSessionManager sessionManager = new ArthasCommandSessionManager(executor);
DefaultMcpTransportContext transportContext = new DefaultMcpTransportContext();
Object authSubject = new Object();
transportContext.put(McpAuthExtractor.MCP_AUTH_SUBJECT_KEY, authSubject);
McpNettyServerExchange exchange =
new McpNettyServerExchange("mcp-session", null, null, null, transportContext, null);
DefaultCreateTaskContext createTaskContext = new DefaultCreateTaskContext(
null, null, exchange, "mcp-session", null, null, commandContext, sessionManager);

ArthasCommandContext isolatedContext = createTaskContext.createIsolatedTaskSession("task-1");

assertThat(isolatedContext.getSessionId()).isEqualTo("isolated-session");
assertThat(executor.authSessionId).isEqualTo("isolated-session");
assertThat(executor.authSubject).isSameAs(authSubject);
assertThat(executor.authCallCount).isEqualTo(1);
}

@Test
void createIsolatedTaskSessionShouldNotApplyAuthWhenTransportAuthMissing() {
RecordingCommandExecutor executor = new RecordingCommandExecutor();
ArthasCommandContext commandContext = new ArthasCommandContext(executor,
new ArthasCommandSessionManager.CommandSessionBinding("mcp-session", "main-session", "main-consumer"));
ArthasCommandSessionManager sessionManager = new ArthasCommandSessionManager(executor);
McpNettyServerExchange exchange = new McpNettyServerExchange(
"mcp-session", null, null, null, new DefaultMcpTransportContext(), null);
DefaultCreateTaskContext createTaskContext = new DefaultCreateTaskContext(
null, null, exchange, "mcp-session", null, null, commandContext, sessionManager);

createTaskContext.createIsolatedTaskSession("task-1");

assertThat(executor.authCallCount).isZero();
}

private static final class RecordingCommandExecutor implements CommandExecutor {
private String authSessionId;
private Object authSubject;
private int authCallCount;

@Override
public Map<String, Object> executeSync(String commandLine, long timeout, String sessionId,
Object authSubject, String userId) {
return new HashMap<String, Object>();
}

@Override
public Map<String, Object> executeAsync(String commandLine, String sessionId) {
return new HashMap<String, Object>();
}

@Override
public Map<String, Object> pullResults(String sessionId, String consumerId) {
return new HashMap<String, Object>();
}

@Override
public Map<String, Object> interruptJob(String sessionId) {
return new HashMap<String, Object>();
}

@Override
public Map<String, Object> createSession() {
Map<String, Object> result = new HashMap<String, Object>();
result.put("sessionId", "isolated-session");
result.put("consumerId", "isolated-consumer");
return result;
}

@Override
public Map<String, Object> closeSession(String sessionId) {
return new HashMap<String, Object>();
}

@Override
public void setSessionAuth(String sessionId, Object authSubject) {
this.authSessionId = sessionId;
this.authSubject = authSubject;
this.authCallCount++;
}

@Override
public void setSessionUserId(String sessionId, String userId) {
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ protected String executeStreamable(ToolContext toolContext, String commandStr,

logger.info("Starting streamable execution: {}", commandStr);

// 在异步命令创建 Job 前,把 MCP 已认证主体写入当前 Arthas session。
if (execContext.getAuthSubject() != null) {
execContext.getCommandContext().setSessionAuth(execContext.getAuthSubject());
}

// Set userId to session before async execution for stat reporting
if (execContext.getUserId() != null) {
execContext.getCommandContext().setSessionUserId(execContext.getUserId());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.taobao.arthas.core.mcp.tool.function;

import com.taobao.arthas.core.mcp.util.McpAuthExtractor;
import com.taobao.arthas.mcp.server.protocol.server.DefaultMcpTransportContext;
import com.taobao.arthas.mcp.server.session.ArthasCommandContext;
import com.taobao.arthas.mcp.server.tool.ToolContext;
import com.taobao.arthas.mcp.server.tool.ToolContextKeys;
import org.junit.Test;
import org.mockito.InOrder;

import java.util.HashMap;
import java.util.Map;

import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class AbstractArthasToolAuthTest {

@Test
public void executeStreamableShouldApplyAuthSubjectBeforeAsyncExecution() {
ArthasCommandContext commandContext = mock(ArthasCommandContext.class);
String command = "trace demo.MathGame run";
Map<String, Object> asyncResult = new HashMap<String, Object>();
asyncResult.put("success", false);
asyncResult.put("error", "start failed");
when(commandContext.executeAsync(command)).thenReturn(asyncResult);

DefaultMcpTransportContext transportContext = new DefaultMcpTransportContext();
Object authSubject = new Object();
transportContext.put(McpAuthExtractor.MCP_AUTH_SUBJECT_KEY, authSubject);
transportContext.put(McpAuthExtractor.MCP_USER_ID_KEY, "user-1");

Map<String, Object> context = new HashMap<String, Object>();
context.put(ToolContextKeys.COMMAND_CONTEXT, commandContext);
context.put(ToolContextKeys.MCP_TRANSPORT_CONTEXT, transportContext);

new TestArthasTool().callExecuteStreamable(new ToolContext(context), command);

InOrder inOrder = inOrder(commandContext);
inOrder.verify(commandContext).setSessionAuth(authSubject);
inOrder.verify(commandContext).setSessionUserId("user-1");
inOrder.verify(commandContext).executeAsync(command);
}

private static final class TestArthasTool extends AbstractArthasTool {
private String callExecuteStreamable(ToolContext toolContext, String command) {
return executeStreamable(toolContext, command, 1, 10, 1000, null);
}
}
}
Loading