diff --git a/packages/tracecat-ee/tracecat_ee/agent/activities.py b/packages/tracecat-ee/tracecat_ee/agent/activities.py index d6106e4099..c0ca2959de 100644 --- a/packages/tracecat-ee/tracecat_ee/agent/activities.py +++ b/packages/tracecat-ee/tracecat_ee/agent/activities.py @@ -79,6 +79,22 @@ class BuildAgentToolDefsResult(BaseModel): scopes: dict[str, BuildToolDefsResult] +def _mcp_discovery_error_message( + *, + scope: str, + error: BaseException, +) -> str: + """Build a safe, user-facing MCP discovery error for one agent scope.""" + message = f"Failed to discover configured MCP tools for agent scope '{scope}'" + server_name = getattr(error, "server_name", None) + if isinstance(server_name, str) and server_name: + message += f" from MCP server '{server_name}'" + status_code = getattr(error, "status_code", None) + if isinstance(status_code, int): + message += f" (HTTP {status_code})" + return message + + class ToolApprovalPayload(BaseModel): tool_call_id: str tool_name: str @@ -265,14 +281,20 @@ async def _build_scope_tool_definitions( server_count=len(hydrated_servers), ) except Exception as e: + server_name = getattr(e, "server_name", None) + status_code = getattr(e, "status_code", None) logger.error( "Failed to discover user MCP tools", + scope=args.scope, + mcp_server=server_name, + http_status_code=status_code, error_type=type(e).__name__, server_count=len(hydrated_servers), + configured_mcp_servers=[cfg["name"] for cfg in http_servers], ) if args.fail_on_mcp_discovery_error: raise ApplicationError( - "Failed to discover configured MCP tools for agent scope", + _mcp_discovery_error_message(scope=args.scope, error=e), str(e), type="AgentToolDefinitionError", non_retryable=True, diff --git a/tests/unit/test_agent_activities.py b/tests/unit/test_agent_activities.py index 54ad51661a..33b585e451 100644 --- a/tests/unit/test_agent_activities.py +++ b/tests/unit/test_agent_activities.py @@ -210,7 +210,9 @@ async def mock_discover_user_mcp_tools( fail_on_error: bool = False, ) -> dict[str, Any]: discover_fail_flags.append(fail_on_error) - raise RuntimeError("server unavailable") + raise user_client.MCPToolDiscoveryError( + "broken", RuntimeError("server unavailable") + ) class _LockService: async def resolve_lock_with_bindings( @@ -260,11 +262,27 @@ async def __aexit__( assert discover_fail_flags == [True] assert exc_info.value.message == ( - "Failed to discover configured MCP tools for agent scope" + "Failed to discover configured MCP tools for agent scope 'root' " + "from MCP server 'broken'" ) assert exc_info.value.type == "AgentToolDefinitionError" assert exc_info.value.non_retryable is True + def test_mcp_discovery_error_message_includes_scope_server_and_status( + self, + ) -> None: + error = RuntimeError("server unavailable") + error.server_name = "broken" # type: ignore[attr-defined] + error.status_code = 503 # type: ignore[attr-defined] + + assert agent_activities._mcp_discovery_error_message( + scope="analyst", + error=error, + ) == ( + "Failed to discover configured MCP tools for agent scope 'analyst' " + "from MCP server 'broken' (HTTP 503)" + ) + @pytest.mark.anyio async def test_build_agent_tool_definitions_returns_partitioned_scopes( self, diff --git a/tests/unit/test_agent_mcp_user_client.py b/tests/unit/test_agent_mcp_user_client.py index ccd5255ff7..33eafcfd11 100644 --- a/tests/unit/test_agent_mcp_user_client.py +++ b/tests/unit/test_agent_mcp_user_client.py @@ -1,8 +1,14 @@ +from types import SimpleNamespace + import pytest from fastmcp.client.transports import StreamableHttpTransport from tracecat.agent.common.types import MCPHttpServerConfig, MCPToolDefinition -from tracecat.agent.mcp.user_client import UserMCPClient, _create_transport +from tracecat.agent.mcp.user_client import ( + MCPToolDiscoveryError, + UserMCPClient, + _create_transport, +) def _mcp_server(name: str) -> MCPHttpServerConfig: @@ -66,11 +72,47 @@ async def fake_discover_server_tools( client = UserMCPClient([_mcp_server("working"), _mcp_server("broken")]) with pytest.raises( - RuntimeError, + MCPToolDiscoveryError, match="Failed to discover tools from user MCP server 'broken'", - ): + ) as exc_info: await client.discover_tools(fail_on_error=True) + assert exc_info.value.server_name == "broken" + + +@pytest.mark.anyio +async def test_discover_tools_strict_error_includes_http_status( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class FakeHTTPStatusError(Exception): + def __init__(self) -> None: + super().__init__("Service Unavailable") + self.response = SimpleNamespace(status_code=503) + + async def fake_discover_server_tools( + self: UserMCPClient, + server_name: str, + config: MCPHttpServerConfig, + ) -> dict[str, MCPToolDefinition]: + del self, server_name, config + raise FakeHTTPStatusError() + + monkeypatch.setattr( + UserMCPClient, + "_discover_server_tools", + fake_discover_server_tools, + ) + client = UserMCPClient([_mcp_server("broken")]) + + with pytest.raises( + MCPToolDiscoveryError, + match=r"Failed to discover tools from user MCP server 'broken' \(HTTP 503\)", + ) as exc_info: + await client.discover_tools(fail_on_error=True) + + assert exc_info.value.server_name == "broken" + assert exc_info.value.status_code == 503 + # Regression: fastmcp's StreamableHttpTransport.connect_session merges any # inbound `authorization` header (from get_http_headers) with the transport's diff --git a/tracecat/agent/mcp/user_client.py b/tracecat/agent/mcp/user_client.py index a6fc31750b..4cd50fed6a 100644 --- a/tracecat/agent/mcp/user_client.py +++ b/tracecat/agent/mcp/user_client.py @@ -23,6 +23,39 @@ from tracecat.logger import logger +def _http_status_code_from_exception(error: BaseException) -> int | None: + """Best-effort extraction of an HTTP status code from a transport error.""" + seen: set[int] = set() + current: BaseException | None = error + while current is not None and id(current) not in seen: + seen.add(id(current)) + response = getattr(current, "response", None) + status_code = getattr(response, "status_code", None) + if isinstance(status_code, int): + return status_code + status_code = getattr(current, "status_code", None) + if isinstance(status_code, int): + return status_code + current = current.__cause__ or current.__context__ + return None + + +class MCPToolDiscoveryError(RuntimeError): + """Raised when strict user MCP discovery fails for one configured server.""" + + def __init__(self, server_name: str, cause: BaseException): + self.server_name = server_name + self.status_code = _http_status_code_from_exception(cause) + + message = f"Failed to discover tools from user MCP server '{server_name}'" + if self.status_code is not None: + message += f" (HTTP {self.status_code})" + cause_message = str(cause) + if cause_message: + message += f": {cause_message}" + super().__init__(message) + + def _create_transport( url: str, transport_type: Literal["http", "sse"], @@ -101,15 +134,16 @@ async def discover_tools( server_tools = await self._discover_server_tools(server_name, config) tools.update(server_tools) except Exception as e: + status_code = _http_status_code_from_exception(e) logger.error( "Failed to discover tools from user MCP server", server_name=server_name, + http_status_code=status_code, + error_type=type(e).__name__, error=str(e), ) if fail_on_error: - raise RuntimeError( - f"Failed to discover tools from user MCP server '{server_name}'" - ) from e + raise MCPToolDiscoveryError(server_name, e) from e logger.info( "Discovered user MCP tools",