diff --git a/pkg/vmcp/cli/serve.go b/pkg/vmcp/cli/serve.go index d8e459de33..85cf93c4e2 100644 --- a/pkg/vmcp/cli/serve.go +++ b/pkg/vmcp/cli/serve.go @@ -40,7 +40,6 @@ import ( vmcpclient "github.com/stacklok/toolhive/pkg/vmcp/client" "github.com/stacklok/toolhive/pkg/vmcp/codemode" "github.com/stacklok/toolhive/pkg/vmcp/config" - "github.com/stacklok/toolhive/pkg/vmcp/discovery" "github.com/stacklok/toolhive/pkg/vmcp/health" "github.com/stacklok/toolhive/pkg/vmcp/k8s" "github.com/stacklok/toolhive/pkg/vmcp/optimizer" @@ -210,11 +209,6 @@ func Serve(ctx context.Context, cfg ServeConfig) error { // DynamicRegistry tracks backends for dynamic discovery in Kubernetes mode. dynamicRegistry := vmcp.NewDynamicRegistry(backends) backendRegistry := vmcp.BackendRegistry(dynamicRegistry) - - discoveryMgr, err := discovery.NewManager(agg) - if err != nil { - return fmt.Errorf("failed to create discovery manager: %w", err) - } slog.Info("dynamic backend registry enabled for Kubernetes environment") // Backend watcher for dynamic backend discovery. @@ -249,8 +243,9 @@ func Serve(ctx context.Context, cfg ServeConfig) error { slog.Info("kubernetes backend watcher started for dynamic backend discovery") } - // Create router. - rtr := vmcprouter.NewDefaultRouter() + // Workflow validation in core.New needs a non-nil Router, but the core routes per-call + // via NewSessionRouter and validation does not route — so an empty session router suffices. + rtr := vmcprouter.NewSessionRouter(&vmcp.RoutingTable{}) slog.Info(fmt.Sprintf("Setting up incoming authentication (type: %s)", vmcpCfg.IncomingAuth.Type)) @@ -462,8 +457,8 @@ func Serve(ctx context.Context, cfg ServeConfig) error { slog.Info(fmt.Sprintf("Loaded %d composite tool workflow definitions", len(workflowDefs))) } - // Create server with discovery manager, backend registry, and workflow definitions. - srv, err := vmcpserver.New(ctx, serverCfg, rtr, backendClient, discoveryMgr, backendRegistry, workflowDefs) + // Create server with the backend registry and workflow definitions. + srv, err := vmcpserver.New(ctx, serverCfg, rtr, backendClient, backendRegistry, workflowDefs) if err != nil { return fmt.Errorf("failed to create Virtual MCP Server: %w", err) } diff --git a/pkg/vmcp/core/core_vmcp.go b/pkg/vmcp/core/core_vmcp.go index fe1d381e69..616e132ae7 100644 --- a/pkg/vmcp/core/core_vmcp.go +++ b/pkg/vmcp/core/core_vmcp.go @@ -518,10 +518,9 @@ func workflowsRequireElicitation(defs map[string]*composer.WorkflowDefinition) b // health monitor is used (respects circuit breaker state). When nil, falls back // to the initial health status from the backend registry. // -// This is an intentional, temporary duplication of discovery.filterHealthyBackends -// (discovery/middleware.go:157, rules at 185-188): the discovery middleware keeps -// its own copy on the legacy server.New path until that path is removed in Phase 3 -// (#5442/#5445). Keep the include/exclude rules identical across both copies. +// This filtering previously had a second copy in the discovery middleware on the +// legacy server.New path. That path and its copy were removed in Phase 3 (#5445), +// so this is now the single source of truth for backend health filtering. func filterHealthyBackends(backends []vmcp.Backend, healthStatusProvider health.StatusProvider) []vmcp.Backend { if len(backends) == 0 { return backends diff --git a/pkg/vmcp/discovery/context.go b/pkg/vmcp/discovery/context.go deleted file mode 100644 index e356e10ca0..0000000000 --- a/pkg/vmcp/discovery/context.go +++ /dev/null @@ -1,38 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -// Package discovery provides lazy per-user capability discovery for vMCP servers. -// -// This package handles context-based storage and retrieval of discovered backend -// capabilities within request-scoped contexts. The discovery process occurs -// asynchronously when a request arrives, and results are cached in the context -// to avoid redundant aggregation operations during request handling. -package discovery - -import ( - "context" - - "github.com/stacklok/toolhive/pkg/vmcp/aggregator" -) - -// contextKey is an unexported type for context keys to avoid collisions. -type contextKey struct{} - -// discoveredCapabilitiesKey is the context key for storing aggregated capabilities. -var discoveredCapabilitiesKey = contextKey{} - -// WithDiscoveredCapabilities returns a new context with discovered capabilities attached. -// If capabilities is nil, the original context is returned unchanged. -func WithDiscoveredCapabilities(ctx context.Context, capabilities *aggregator.AggregatedCapabilities) context.Context { - if capabilities == nil { - return ctx - } - return context.WithValue(ctx, discoveredCapabilitiesKey, capabilities) -} - -// DiscoveredCapabilitiesFromContext retrieves discovered capabilities from the context. -// Returns (nil, false) if capabilities are not found in the context. -func DiscoveredCapabilitiesFromContext(ctx context.Context) (*aggregator.AggregatedCapabilities, bool) { - capabilities, ok := ctx.Value(discoveredCapabilitiesKey).(*aggregator.AggregatedCapabilities) - return capabilities, ok -} diff --git a/pkg/vmcp/discovery/context_test.go b/pkg/vmcp/discovery/context_test.go deleted file mode 100644 index 49fc551044..0000000000 --- a/pkg/vmcp/discovery/context_test.go +++ /dev/null @@ -1,91 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -package discovery - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/stacklok/toolhive/pkg/vmcp/aggregator" -) - -func TestWithDiscoveredCapabilities(t *testing.T) { - t.Parallel() - - t.Run("no context value returns nil, false", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - - retrieved, ok := DiscoveredCapabilitiesFromContext(ctx) - - assert.False(t, ok) - assert.Nil(t, retrieved) - }) - - t.Run("capabilities stored in context", func(t *testing.T) { - t.Parallel() - - caps := &aggregator.AggregatedCapabilities{ - Metadata: &aggregator.AggregationMetadata{ - BackendCount: 1, - }, - } - - ctx := context.Background() - enrichedCtx := WithDiscoveredCapabilities(ctx, caps) - - require.NotNil(t, enrichedCtx) - - // Verify we can retrieve the capabilities - retrieved, ok := DiscoveredCapabilitiesFromContext(enrichedCtx) - assert.True(t, ok) - require.NotNil(t, retrieved) - assert.Equal(t, caps, retrieved) - }) - - t.Run("nil capabilities returns original context", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - enrichedCtx := WithDiscoveredCapabilities(ctx, nil) - - // Should return original context unchanged - assert.Equal(t, ctx, enrichedCtx) - - // Attempting to retrieve should return nil, false - retrieved, ok := DiscoveredCapabilitiesFromContext(enrichedCtx) - assert.False(t, ok) - assert.Nil(t, retrieved) - }) - - t.Run("capabilities can be overwritten", func(t *testing.T) { - t.Parallel() - - caps1 := &aggregator.AggregatedCapabilities{ - Metadata: &aggregator.AggregationMetadata{ - BackendCount: 1, - }, - } - - caps2 := &aggregator.AggregatedCapabilities{ - Metadata: &aggregator.AggregationMetadata{ - BackendCount: 2, - }, - } - - ctx := context.Background() - ctx = WithDiscoveredCapabilities(ctx, caps1) - ctx = WithDiscoveredCapabilities(ctx, caps2) - - retrieved, ok := DiscoveredCapabilitiesFromContext(ctx) - assert.True(t, ok) - require.NotNil(t, retrieved) - assert.Equal(t, caps2, retrieved) - assert.NotEqual(t, caps1, retrieved) - }) -} diff --git a/pkg/vmcp/discovery/manager.go b/pkg/vmcp/discovery/manager.go deleted file mode 100644 index 731b096faa..0000000000 --- a/pkg/vmcp/discovery/manager.go +++ /dev/null @@ -1,84 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -// Package discovery provides lazy per-user capability discovery for vMCP servers. -// -// This package implements per-request capability discovery with user-specific -// authentication context, enabling truly multi-tenant operation where different -// users may see different capabilities based on their permissions. -package discovery - -import ( - "context" - "errors" - "fmt" - "log/slog" - - "github.com/stacklok/toolhive/pkg/auth" - "github.com/stacklok/toolhive/pkg/vmcp" - "github.com/stacklok/toolhive/pkg/vmcp/aggregator" -) - -//go:generate mockgen -destination=mocks/mock_manager.go -package=mocks -source=manager.go Manager - -var ( - // ErrAggregatorNil is returned when aggregator is nil. - ErrAggregatorNil = errors.New("aggregator cannot be nil") - // ErrDiscoveryFailed is returned when capability discovery fails. - ErrDiscoveryFailed = errors.New("capability discovery failed") - // ErrNoIdentity is returned when user identity is not found in context. - ErrNoIdentity = errors.New("user identity not found in context") -) - -// Manager performs capability discovery with user context. -type Manager interface { - // Discover performs capability aggregation for the given backends with user context. - Discover(ctx context.Context, backends []vmcp.Backend) (*aggregator.AggregatedCapabilities, error) - // Stop gracefully stops the manager and cleans up resources. - Stop() -} - -// DefaultManager is the default implementation of Manager. -type DefaultManager struct { - aggregator aggregator.Aggregator -} - -// NewManager creates a new discovery manager with the given aggregator. -func NewManager(agg aggregator.Aggregator) (Manager, error) { - if agg == nil { - return nil, ErrAggregatorNil - } - - return &DefaultManager{ - aggregator: agg, - }, nil -} - -// Discover performs capability aggregation for the given backends. -// -// Results are computed fresh on each call — no caching is performed. New MCP -// sessions are infrequent and aggregation latency is negligible compared to -// LLM round-trips, so caching adds complexity without meaningful benefit. -// -// The context must contain an authenticated user identity (set by auth middleware). -// Returns ErrNoIdentity if user identity is not found in context. -func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend) (*aggregator.AggregatedCapabilities, error) { - // Validate user identity is present (set by auth middleware) - // This ensures discovery happens with proper user authentication context - identity, ok := auth.IdentityFromContext(ctx) - if !ok { - return nil, fmt.Errorf("%w: ensure auth middleware runs before discovery middleware", ErrNoIdentity) - } - - slog.Debug("performing capability discovery", "user", identity.Subject, "backends", len(backends)) - - caps, err := m.aggregator.AggregateCapabilities(ctx, backends) - if err != nil { - return nil, fmt.Errorf("%w: %w", ErrDiscoveryFailed, err) - } - - return caps, nil -} - -// Stop is a no-op. Retained for interface compatibility. -func (*DefaultManager) Stop() {} diff --git a/pkg/vmcp/discovery/manager_test.go b/pkg/vmcp/discovery/manager_test.go deleted file mode 100644 index 7d0063d73a..0000000000 --- a/pkg/vmcp/discovery/manager_test.go +++ /dev/null @@ -1,174 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -package discovery - -import ( - "context" - "errors" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - - "github.com/stacklok/toolhive/pkg/auth" - "github.com/stacklok/toolhive/pkg/vmcp" - "github.com/stacklok/toolhive/pkg/vmcp/aggregator" - aggmocks "github.com/stacklok/toolhive/pkg/vmcp/aggregator/mocks" -) - -func TestNewManager(t *testing.T) { - t.Parallel() - - t.Run("success with valid aggregator", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockAgg := aggmocks.NewMockAggregator(ctrl) - mgr, err := NewManager(mockAgg) - - require.NoError(t, err) - assert.NotNil(t, mgr) - assert.IsType(t, &DefaultManager{}, mgr) - }) - - t.Run("error with nil aggregator", func(t *testing.T) { - t.Parallel() - - mgr, err := NewManager(nil) - - require.Error(t, err) - assert.Nil(t, mgr) - assert.ErrorIs(t, err, ErrAggregatorNil) - }) -} - -func TestDefaultManager_Discover(t *testing.T) { - t.Parallel() - - t.Run("successful discovery", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockAgg := aggmocks.NewMockAggregator(ctrl) - backends := []vmcp.Backend{newTestBackend("backend1")} - expectedCaps := &aggregator.AggregatedCapabilities{ - Tools: []vmcp.Tool{newTestTool("tool1", "backend1")}, - } - - mockAgg.EXPECT(). - AggregateCapabilities(gomock.Any(), backends). - Return(expectedCaps, nil) - - mgr, err := NewManager(mockAgg) - require.NoError(t, err) - defer mgr.Stop() - - // Create context with user identity - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user123", Name: "Test User"}} - ctx := auth.WithIdentity(context.Background(), identity) - - caps, err := mgr.Discover(ctx, backends) - - require.NoError(t, err) - assert.Equal(t, expectedCaps, caps) - }) - - t.Run("error when user identity missing from context", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockAgg := aggmocks.NewMockAggregator(ctrl) - backends := []vmcp.Backend{newTestBackend("backend1")} - - // No expectation on mockAgg - should fail before calling aggregator - - mgr, err := NewManager(mockAgg) - require.NoError(t, err) - - // Use context without user identity - caps, err := mgr.Discover(context.Background(), backends) - - require.Error(t, err) - assert.Nil(t, caps) - assert.ErrorIs(t, err, ErrNoIdentity) - assert.Contains(t, err.Error(), "ensure auth middleware runs before discovery middleware") - }) - - t.Run("discovery failure from aggregator", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockAgg := aggmocks.NewMockAggregator(ctrl) - backends := []vmcp.Backend{ - newTestBackend("backend1"), - } - - expectedErr := errors.New("aggregation failed: connection timeout") - - mockAgg.EXPECT(). - AggregateCapabilities(gomock.Any(), backends). - Return(nil, expectedErr) - - mgr, err := NewManager(mockAgg) - require.NoError(t, err) - defer mgr.Stop() - - // Create context with user identity - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user456"}} - ctx := auth.WithIdentity(context.Background(), identity) - - caps, err := mgr.Discover(ctx, backends) - - require.Error(t, err) - assert.Nil(t, caps) - assert.ErrorIs(t, err, ErrDiscoveryFailed) - }) -} - -func TestDefaultManager_Stop(t *testing.T) { - t.Parallel() - - t.Run("stop is safe to call", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockAgg := aggmocks.NewMockAggregator(ctrl) - - mgr, err := NewManager(mockAgg) - require.NoError(t, err) - - // Stop should be a no-op and not panic - mgr.Stop() - // Calling Stop multiple times should also be safe - mgr.Stop() - }) -} - -// Test helpers - -func newTestBackend(id string) vmcp.Backend { - return vmcp.Backend{ - ID: id, - Name: id, - BaseURL: "http://localhost:8080", - TransportType: "streamable-http", - HealthStatus: vmcp.BackendHealthy, - } -} - -//nolint:unparam // name parameter kept for flexibility in future tests -func newTestTool(name, backendID string) vmcp.Tool { - return vmcp.Tool{ - Name: name, - Description: name + " description", - InputSchema: map[string]any{"type": "object"}, - BackendID: backendID, - } -} diff --git a/pkg/vmcp/discovery/middleware.go b/pkg/vmcp/discovery/middleware.go deleted file mode 100644 index 5ebdf7f555..0000000000 --- a/pkg/vmcp/discovery/middleware.go +++ /dev/null @@ -1,327 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -// Package discovery provides lazy per-user capability discovery for vMCP servers. -// -// Capabilities are discovered at session initialization and cached in the session for -// its lifetime. This ensures deterministic behavior and prevents notification spam from -// redundant capability updates when backends haven't changed. -// -// For MultiSession requests, the middleware injects routing context from the session's -// routing table so that composite tool workflow steps can route backend tool calls correctly. -// Tool routing for non-composite tools is handled by session-scoped handlers registered -// with AddSessionTools. -// -// Future enhancement: Add manager-level capability cache to share discoveries across -// sessions, plus separate background refresh worker (not in middleware request path) -// that periodically rediscovers capabilities, detects changes via hash comparison, and -// pushes updates to active sessions via MCP tools/list_changed notifications. Middleware -// flow remains unchanged - still just retrieves from session cache on subsequent requests. -package discovery - -import ( - "context" - "errors" - "fmt" - "log/slog" - "net/http" - "time" - - "github.com/stacklok/toolhive/pkg/vmcp" - "github.com/stacklok/toolhive/pkg/vmcp/aggregator" - "github.com/stacklok/toolhive/pkg/vmcp/health" - vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" -) - -const ( - // discoveryTimeout is the maximum time for capability discovery. - discoveryTimeout = 15 * time.Second -) - -// MultiSessionGetter retrieves a fully-formed MultiSession by session ID. -// Returns (nil, false) if the session does not exist or has not yet been initialized. -// This interface decouples the discovery middleware from the concrete session manager. -type MultiSessionGetter interface { - GetMultiSession(sessionID string) (vmcpsession.MultiSession, bool) -} - -// middlewareConfig holds optional configuration for Middleware. -type middlewareConfig struct { - sessionScopedRouting bool - timeout time.Duration -} - -// MiddlewareOption configures Middleware behaviour. -type MiddlewareOption func(*middlewareConfig) - -// WithSessionScopedRouting disables backend capability discovery for any request -// that arrives without an Mcp-Session-Id header (i.e. initialize requests). -// Use this when tools are registered per-session via AddSessionTools rather -// than through the discovery pipeline. -func WithSessionScopedRouting() MiddlewareOption { - return func(c *middlewareConfig) { - c.sessionScopedRouting = true - } -} - -// WithDiscoveryTimeout overrides the default discovery timeout. -func WithDiscoveryTimeout(timeout time.Duration) MiddlewareOption { - return func(c *middlewareConfig) { - c.timeout = timeout - } -} - -// Middleware performs capability discovery on session initialization and injects -// routing context for subsequent requests. Must be placed after auth middleware. -// -// Initialize requests (no session ID): discovers capabilities and stores in context. -// Subsequent requests (MultiSession): injects routing table from session into context -// so composite tool workflow steps can route backend tool calls correctly. -// -// Returns HTTP 504 for timeouts, HTTP 503 for discovery errors. -// -// The registry parameter provides the current list of backends. For dynamic environments -// (Kubernetes with DynamicRegistry), backends are fetched on each initialize request to -// ensure the latest backend list is used for capability discovery. -// -// The healthStatusProvider parameter (optional, can be nil) enables filtering backends -// based on current health status from the health monitor. When provided, only healthy and -// degraded backends are included in capability aggregation; unhealthy, unknown, and -// unauthenticated backends are excluded (which includes backends with OPEN circuit breakers). -// When nil (health monitoring disabled), the initial health status from the registry is used. -func Middleware( - manager Manager, - registry vmcp.BackendRegistry, - multiSessionGetter MultiSessionGetter, - healthStatusProvider health.StatusProvider, - opts ...MiddlewareOption, -) func(http.Handler) http.Handler { - cfg := middlewareConfig{ - timeout: discoveryTimeout, - } - for _, o := range opts { - o(&cfg) - } - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - sessionID := r.Header.Get("Mcp-Session-Id") - - if sessionID == "" { - if cfg.sessionScopedRouting { - // Session-scoped routing registers capabilities via the OnRegisterSession - // hook rather than through discovery. Skip discovery on initialize. - next.ServeHTTP(w, r) - return - } - // Initialize request: discover and cache capabilities in session. - var err error - ctx, err = handleInitializeRequest(ctx, r, manager, registry, healthStatusProvider, cfg.timeout) - if err != nil { - handleDiscoveryError(w, r, err) - return - } - } else { - // Subsequent request: inject routing context if the session is ready. - ctx = handleSubsequentRequest(ctx, r, sessionID, multiSessionGetter) - } - - next.ServeHTTP(w, r.WithContext(ctx)) - }) - } -} - -// filterHealthyBackends filters backends to only include those that are healthy -// or degraded. Backends that are unhealthy, unknown, or unauthenticated are excluded -// from capability aggregation to prevent exposing tools from unavailable backends. -// -// A note on BackendUnauthenticated: a 401/403 from a backend that has an outgoing -// auth strategy configured is treated as BackendHealthy by the health checker -// (health probes deliberately do not carry user credentials — the challenge proves -// reachability). BackendUnauthenticated therefore indicates a misconfiguration: -// the backend requires authentication but no outgoing auth strategy is configured -// on the backend target. Excluding such backends from capability aggregation is -// the correct behavior — their capabilities cannot be safely exposed. -// -// Health status filtering: -// - healthy: included (fully operational) -// - degraded: included (slow but working) -// - empty/zero-value: included (assume healthy when health monitoring is disabled) -// - unhealthy: excluded (not responding, circuit breaker may be open) -// - unknown: excluded (status not yet determined) -// - unauthenticated: excluded (misconfiguration: backend requires auth but none configured) -// -// When healthStatusProvider is provided, the current health status from the health -// monitor is used (respects circuit breaker state). When nil, falls back to the -// initial health status from the backend registry. -func filterHealthyBackends(backends []vmcp.Backend, healthStatusProvider health.StatusProvider) []vmcp.Backend { - if len(backends) == 0 { - return backends - } - - healthy := make([]vmcp.Backend, 0, len(backends)) - excluded := 0 - - for i := range backends { - backend := &backends[i] - - // Get current health status from health monitor if available - // This ensures circuit breaker state is respected during capability aggregation - var healthStatus vmcp.BackendHealthStatus - if healthStatusProvider != nil { - if status, exists := healthStatusProvider.QueryBackendStatus(backend.ID); exists { - healthStatus = status - } else { - // Backend not tracked by health monitor - use registry status - healthStatus = backend.HealthStatus - } - } else { - // Health monitoring disabled - use registry status - healthStatus = backend.HealthStatus - } - - // Include healthy, degraded, and empty/zero-value (assume healthy) backends. - // Explicitly exclude unhealthy, unknown, and unauthenticated backends. - if healthStatus == "" || - healthStatus == vmcp.BackendHealthy || - healthStatus == vmcp.BackendDegraded { - healthy = append(healthy, *backend) - } else { - excluded++ - //nolint:gosec // G706: backend fields are internal, not user-controlled - slog.Debug("excluding backend from capability aggregation due to health status", - "backend_name", backend.Name, - "backend_id", backend.ID, - "health_status", healthStatus, - "source", func() string { - if healthStatusProvider != nil { - return "health_monitor" - } - return "registry" - }()) - } - } - - if excluded > 0 { - //nolint:gosec // G706: values are internal counts, not user-controlled - slog.Debug("filtered backends for capability aggregation", - "total_backends", len(backends), - "healthy_backends", len(healthy), - "excluded_backends", excluded) - } - - return healthy -} - -// handleInitializeRequest performs capability discovery for initialize requests. -// Returns updated context with discovered capabilities or an error. -// -// For dynamic environments, backends are fetched from the registry on each request -// to ensure the latest backend list is used (e.g., when backends are added/removed). -// -// When healthStatusProvider is provided, backends are filtered based on current health -// status from the health monitor (respects circuit breaker state). When nil, the initial -// health status from the backend registry is used. -func handleInitializeRequest( - ctx context.Context, - r *http.Request, - manager Manager, - registry vmcp.BackendRegistry, - healthStatusProvider health.StatusProvider, - timeout time.Duration, -) (context.Context, error) { - discoveryCtx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - - // Get current backend list from registry (supports dynamic backend changes) - allBackends := registry.List(discoveryCtx) - - // Filter to only include healthy/degraded backends for capability aggregation - // Uses current health status from health monitor when available - backends := filterHealthyBackends(allBackends, healthStatusProvider) - - //nolint:gosec // G706: request method/path are standard HTTP fields, not injection vectors - slog.Debug("starting capability discovery for initialize request", - "method", r.Method, - "path", r.URL.Path, - "total_backend_count", len(allBackends), - "healthy_backend_count", len(backends)) - - capabilities, err := manager.Discover(discoveryCtx, backends) - if err != nil { - //nolint:gosec // G706: request method/path are standard HTTP fields, not injection vectors - slog.Error("capability discovery failed", - "error", err, - "method", r.Method, - "path", r.URL.Path) - return ctx, fmt.Errorf("discovery failed: %w", err) - } - - //nolint:gosec // G706: request method/path are standard HTTP fields, not injection vectors - slog.Debug("capability discovery completed", - "method", r.Method, - "path", r.URL.Path, - "tool_count", len(capabilities.Tools), - "resource_count", len(capabilities.Resources), - "prompt_count", len(capabilities.Prompts)) - - return WithDiscoveredCapabilities(ctx, capabilities), nil -} - -// handleSubsequentRequest retrieves cached capabilities from the session. -// Returns the updated context; never returns an error. -func handleSubsequentRequest( - ctx context.Context, - r *http.Request, - sessionID string, - multiSessionGetter MultiSessionGetter, -) context.Context { - //nolint:gosec // G706: session ID and request fields are not injection vectors - slog.Debug("retrieving capabilities from session for subsequent request", - "session_id", sessionID, - "method", r.Method, - "path", r.URL.Path) - - // Look up the fully-formed MultiSession. Returns (nil, false) if the session does - // not exist yet or is still a placeholder (CreateSession not yet complete). In either - // case, skip capability injection and let the SDK validate/reject the request — the - // SDK's own SessionIdManager.Validate() returns 404 for unknown session IDs. - multiSess, ok := multiSessionGetter.GetMultiSession(sessionID) - if !ok { - //nolint:gosec // G706: session ID is not an injection vector - slog.Debug("session not found or still initialising, skipping capability injection", - "session_id", sessionID) - return ctx - } - - routingTable := multiSess.GetRoutingTable() - if routingTable == nil { - // Session initialisation not yet complete; no capabilities to inject. - // Composite tool calls will fail routing, but backend tool calls are - // already registered with the SDK and will succeed. - //nolint:gosec // G706: session ID is not an injection vector - slog.Debug("multi-session routing table not yet initialised; skipping capability injection", - "session_id", sessionID) - return ctx - } - //nolint:gosec // G706: session ID is not an injection vector - slog.Debug("injecting capabilities from multi-session routing table for composite tool routing", - "session_id", sessionID, - "tool_count", len(routingTable.Tools)) - capabilities := &aggregator.AggregatedCapabilities{ - RoutingTable: routingTable, - Tools: multiSess.Tools(), - } - return WithDiscoveredCapabilities(ctx, capabilities) -} - -// handleDiscoveryError writes appropriate HTTP error responses based on the error type. -func handleDiscoveryError(w http.ResponseWriter, _ *http.Request, err error) { - if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { - http.Error(w, http.StatusText(http.StatusGatewayTimeout), http.StatusGatewayTimeout) - return - } - - // Default to service unavailable for other errors - http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable) -} diff --git a/pkg/vmcp/discovery/middleware_test.go b/pkg/vmcp/discovery/middleware_test.go deleted file mode 100644 index 456fd5b751..0000000000 --- a/pkg/vmcp/discovery/middleware_test.go +++ /dev/null @@ -1,693 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -package discovery - -import ( - "context" - "errors" - "io" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - - "github.com/stacklok/toolhive/pkg/vmcp" - "github.com/stacklok/toolhive/pkg/vmcp/aggregator" - "github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks" - vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" - sessionmocks "github.com/stacklok/toolhive/pkg/vmcp/session/types/mocks" -) - -// Ensure stubMultiSessionGetter implements MultiSessionGetter. -var _ MultiSessionGetter = (*stubMultiSessionGetter)(nil) - -// stubMultiSessionGetter is a simple in-memory MultiSessionGetter for tests. -type stubMultiSessionGetter struct { - sessions map[string]vmcpsession.MultiSession -} - -func newStubMultiSessionGetter() *stubMultiSessionGetter { - return &stubMultiSessionGetter{sessions: make(map[string]vmcpsession.MultiSession)} -} - -func (s *stubMultiSessionGetter) GetMultiSession(sessionID string) (vmcpsession.MultiSession, bool) { - sess, ok := s.sessions[sessionID] - return sess, ok -} - -func (s *stubMultiSessionGetter) add(sessionID string, sess vmcpsession.MultiSession) { - s.sessions[sessionID] = sess -} - -// unorderedBackendsMatcher is a gomock matcher that compares backend slices without caring about order. -// This is needed because ImmutableRegistry.List() iterates over a map which doesn't guarantee order. -type unorderedBackendsMatcher struct { - expected []vmcp.Backend -} - -func (m unorderedBackendsMatcher) Matches(x any) bool { - actual, ok := x.([]vmcp.Backend) - if !ok { - return false - } - if len(actual) != len(m.expected) { - return false - } - - // Create maps for comparison - expectedMap := make(map[string]vmcp.Backend) - for _, b := range m.expected { - expectedMap[b.ID] = b - } - - actualMap := make(map[string]vmcp.Backend) - for _, b := range actual { - actualMap[b.ID] = b - } - - // Check all expected backends are present - for id, expectedBackend := range expectedMap { - actualBackend, found := actualMap[id] - if !found { - return false - } - if expectedBackend.ID != actualBackend.ID || expectedBackend.Name != actualBackend.Name { - return false - } - } - - return true -} - -func (unorderedBackendsMatcher) String() string { - return "matches backends regardless of order" -} - -func TestMiddleware_InitializeRequest(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockMgr := mocks.NewMockManager(ctrl) - - backends := []vmcp.Backend{ - { - ID: "backend1", - Name: "Backend 1", - BaseURL: "http://backend1:8080", - TransportType: "streamable-http", - HealthStatus: vmcp.BackendHealthy, - }, - } - - expectedCaps := &aggregator.AggregatedCapabilities{ - Tools: []vmcp.Tool{ - {Name: "tool1", BackendID: "backend1"}, - }, - Resources: []vmcp.Resource{}, - Prompts: []vmcp.Prompt{}, - RoutingTable: &vmcp.RoutingTable{ - Tools: map[string]*vmcp.BackendTarget{ - "tool1": {WorkloadID: "backend1"}, - }, - Resources: make(map[string]*vmcp.BackendTarget), - Prompts: make(map[string]*vmcp.BackendTarget), - }, - Metadata: &aggregator.AggregationMetadata{ - BackendCount: 1, - ToolCount: 1, - }, - } - - // Expect discovery to be called for initialize request (no session ID) - mockMgr.EXPECT(). - Discover(gomock.Any(), unorderedBackendsMatcher{backends}). - Return(expectedCaps, nil) - - // Create a test handler that verifies capabilities are in context - handlerCalled := false - testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - handlerCalled = true - - // Verify capabilities are in context - caps, ok := DiscoveredCapabilitiesFromContext(r.Context()) - assert.True(t, ok, "capabilities should be in context") - assert.NotNil(t, caps, "capabilities should not be nil") - assert.Equal(t, expectedCaps, caps, "capabilities should match expected") - - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("success")) - }) - - // Wrap handler with middleware - backendRegistry := vmcp.NewImmutableRegistry(backends) - middleware := Middleware(mockMgr, backendRegistry, newStubMultiSessionGetter(), nil) - wrappedHandler := middleware(testHandler) - - // Create initialize request (no session ID header) - req := httptest.NewRequest(http.MethodPost, "/mcp/v1/initialize", nil) - rec := httptest.NewRecorder() - - // Execute request - wrappedHandler.ServeHTTP(rec, req) - - // Verify response - assert.True(t, handlerCalled, "handler should have been called") - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "success", rec.Body.String()) -} - -func TestMiddleware_SubsequentRequest_SkipsDiscovery(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockMgr := mocks.NewMockManager(ctrl) - - backends := []vmcp.Backend{ - { - ID: "backend1", - Name: "Backend 1", - BaseURL: "http://backend1:8080", - TransportType: "streamable-http", - HealthStatus: vmcp.BackendHealthy, - }, - } - - // NO EXPECTATION for Discover - it should not be called for subsequent requests - // If Discover is called, the test will fail due to unexpected call - - handlerCalled := false - testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - handlerCalled = true - - // Verify capabilities ARE in context (retrieved from session, not discovered) - caps, ok := DiscoveredCapabilitiesFromContext(r.Context()) - assert.True(t, ok, "capabilities should be in context from session") - assert.NotNil(t, caps, "capabilities should not be nil") - assert.NotNil(t, caps.RoutingTable, "routing table should not be nil") - assert.Len(t, caps.RoutingTable.Tools, 1, "should have 1 tool from session") - - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("success")) - }) - - // Create a routing table for this session - routingTable := &vmcp.RoutingTable{ - Tools: map[string]*vmcp.BackendTarget{"tool1": {WorkloadID: "backend1"}}, - Resources: make(map[string]*vmcp.BackendTarget), - Prompts: make(map[string]*vmcp.BackendTarget), - } - - // Add a MockMultiSession with the routing table - mockSess := sessionmocks.NewMockMultiSession(ctrl) - mockSess.EXPECT().GetRoutingTable().Return(routingTable).AnyTimes() - mockSess.EXPECT().Tools().Return(nil).AnyTimes() - - sessionMgr := newStubMultiSessionGetter() - sessionMgr.add("dddddddd-1001-1001-1001-000000000001", mockSess) - - // Wrap handler with middleware - backendRegistry := vmcp.NewImmutableRegistry(backends) - middleware := Middleware(mockMgr, backendRegistry, sessionMgr, nil) - wrappedHandler := middleware(testHandler) - - // Create subsequent request (with session ID header) - req := httptest.NewRequest(http.MethodPost, "/mcp/v1/tools/list", nil) - req.Header.Set("Mcp-Session-Id", "dddddddd-1001-1001-1001-000000000001") - rec := httptest.NewRecorder() - - // Execute request - wrappedHandler.ServeHTTP(rec, req) - - // Verify response - assert.True(t, handlerCalled, "handler should have been called") - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "success", rec.Body.String()) -} - -func TestMiddleware_DiscoveryTimeout(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockMgr := mocks.NewMockManager(ctrl) - - backends := []vmcp.Backend{ - {ID: "backend1", Name: "Backend 1", HealthStatus: vmcp.BackendHealthy}, - } - - // Simulate timeout by returning context.DeadlineExceeded - mockMgr.EXPECT(). - Discover(gomock.Any(), backends). - Return(nil, context.DeadlineExceeded) - - handlerCalled := false - testHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - handlerCalled = true - w.WriteHeader(http.StatusOK) - }) - - backendRegistry := vmcp.NewImmutableRegistry(backends) - middleware := Middleware(mockMgr, backendRegistry, newStubMultiSessionGetter(), nil) - wrappedHandler := middleware(testHandler) - - // Initialize request (no session ID) - discovery should happen - req := httptest.NewRequest(http.MethodPost, "/mcp/v1/initialize", nil) - rec := httptest.NewRecorder() - - wrappedHandler.ServeHTTP(rec, req) - - // Verify timeout response - assert.False(t, handlerCalled, "handler should not be called on timeout") - assert.Equal(t, http.StatusGatewayTimeout, rec.Code) - body, _ := io.ReadAll(rec.Body) - assert.Contains(t, string(body), http.StatusText(http.StatusGatewayTimeout)) -} - -func TestMiddleware_DiscoveryFailure(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockMgr := mocks.NewMockManager(ctrl) - - backends := []vmcp.Backend{ - {ID: "backend1", Name: "Backend 1", HealthStatus: vmcp.BackendHealthy}, - } - - // Simulate non-timeout error - discoveryErr := errors.New("backend connection failed") - mockMgr.EXPECT(). - Discover(gomock.Any(), backends). - Return(nil, discoveryErr) - - handlerCalled := false - testHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - handlerCalled = true - w.WriteHeader(http.StatusOK) - }) - - backendRegistry := vmcp.NewImmutableRegistry(backends) - middleware := Middleware(mockMgr, backendRegistry, newStubMultiSessionGetter(), nil) - wrappedHandler := middleware(testHandler) - - // Initialize request (no session ID) - discovery should happen - req := httptest.NewRequest(http.MethodPost, "/mcp/v1/initialize", nil) - rec := httptest.NewRecorder() - - wrappedHandler.ServeHTTP(rec, req) - - // Verify service unavailable response - assert.False(t, handlerCalled, "handler should not be called on failure") - assert.Equal(t, http.StatusServiceUnavailable, rec.Code) - body, _ := io.ReadAll(rec.Body) - assert.Contains(t, string(body), http.StatusText(http.StatusServiceUnavailable)) -} - -func TestMiddleware_CapabilitiesInContext(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockMgr := mocks.NewMockManager(ctrl) - - backends := []vmcp.Backend{ - {ID: "backend1", Name: "Backend 1", HealthStatus: vmcp.BackendHealthy}, - {ID: "backend2", Name: "Backend 2", HealthStatus: vmcp.BackendHealthy}, - } - - expectedCaps := &aggregator.AggregatedCapabilities{ - Tools: []vmcp.Tool{ - {Name: "tool1", BackendID: "backend1"}, - {Name: "tool2", BackendID: "backend2"}, - }, - Resources: []vmcp.Resource{ - {URI: "test://resource1", BackendID: "backend1"}, - }, - Prompts: []vmcp.Prompt{ - {Name: "prompt1", BackendID: "backend2"}, - }, - SupportsLogging: true, - SupportsSampling: false, - RoutingTable: &vmcp.RoutingTable{ - Tools: map[string]*vmcp.BackendTarget{ - "tool1": {WorkloadID: "backend1"}, - "tool2": {WorkloadID: "backend2"}, - }, - Resources: map[string]*vmcp.BackendTarget{ - "test://resource1": {WorkloadID: "backend1"}, - }, - Prompts: map[string]*vmcp.BackendTarget{ - "prompt1": {WorkloadID: "backend2"}, - }, - }, - Metadata: &aggregator.AggregationMetadata{ - BackendCount: 2, - ToolCount: 2, - ResourceCount: 1, - PromptCount: 1, - }, - } - - mockMgr.EXPECT(). - Discover(gomock.Any(), unorderedBackendsMatcher{backends}). - Return(expectedCaps, nil) - - // Create handler that inspects context in detail - testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - caps, ok := DiscoveredCapabilitiesFromContext(r.Context()) - require.True(t, ok, "capabilities must be in context") - require.NotNil(t, caps, "capabilities must not be nil") - - // Verify all fields are accessible - assert.Len(t, caps.Tools, 2) - assert.Equal(t, "tool1", caps.Tools[0].Name) - assert.Equal(t, "tool2", caps.Tools[1].Name) - - assert.Len(t, caps.Resources, 1) - assert.Equal(t, "test://resource1", caps.Resources[0].URI) - - assert.Len(t, caps.Prompts, 1) - assert.Equal(t, "prompt1", caps.Prompts[0].Name) - - assert.True(t, caps.SupportsLogging) - assert.False(t, caps.SupportsSampling) - - assert.NotNil(t, caps.RoutingTable) - assert.Contains(t, caps.RoutingTable.Tools, "tool1") - assert.Contains(t, caps.RoutingTable.Tools, "tool2") - assert.Contains(t, caps.RoutingTable.Resources, "test://resource1") - assert.Contains(t, caps.RoutingTable.Prompts, "prompt1") - - assert.Equal(t, 2, caps.Metadata.BackendCount) - assert.Equal(t, 2, caps.Metadata.ToolCount) - assert.Equal(t, 1, caps.Metadata.ResourceCount) - assert.Equal(t, 1, caps.Metadata.PromptCount) - - w.WriteHeader(http.StatusOK) - }) - - backendRegistry := vmcp.NewImmutableRegistry(backends) - middleware := Middleware(mockMgr, backendRegistry, newStubMultiSessionGetter(), nil) - wrappedHandler := middleware(testHandler) - - // Initialize request (no session ID) - discovery should happen - req := httptest.NewRequest(http.MethodPost, "/mcp/v1/initialize", nil) - rec := httptest.NewRecorder() - - wrappedHandler.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusOK, rec.Code) -} - -func TestMiddleware_PreservesUserContext(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockMgr := mocks.NewMockManager(ctrl) - - backends := []vmcp.Backend{ - {ID: "backend1", Name: "Backend 1", HealthStatus: vmcp.BackendHealthy}, - } - - expectedCaps := &aggregator.AggregatedCapabilities{ - Tools: []vmcp.Tool{ - {Name: "tool1", BackendID: "backend1"}, - }, - RoutingTable: &vmcp.RoutingTable{ - Tools: make(map[string]*vmcp.BackendTarget), - Resources: make(map[string]*vmcp.BackendTarget), - Prompts: make(map[string]*vmcp.BackendTarget), - }, - Metadata: &aggregator.AggregationMetadata{ - BackendCount: 1, - ToolCount: 1, - }, - } - - // Define the key type - type userIDKey string - - mockMgr.EXPECT(). - Discover(gomock.Any(), backends). - DoAndReturn(func(ctx context.Context, _ []vmcp.Backend) (*aggregator.AggregatedCapabilities, error) { - // Verify user context is passed through - userID := ctx.Value(userIDKey("user_id")) - assert.Equal(t, "test_user", userID) - return expectedCaps, nil - }) - - testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Verify user context is preserved after middleware - userID := r.Context().Value(userIDKey("user_id")) - assert.Equal(t, "test_user", userID, "user context should be preserved") - - // Verify capabilities are also in context - caps, ok := DiscoveredCapabilitiesFromContext(r.Context()) - assert.True(t, ok) - assert.NotNil(t, caps) - - w.WriteHeader(http.StatusOK) - }) - - backendRegistry := vmcp.NewImmutableRegistry(backends) - middleware := Middleware(mockMgr, backendRegistry, newStubMultiSessionGetter(), nil) - wrappedHandler := middleware(testHandler) - - // Create initialize request with user context (as auth middleware would) - req := httptest.NewRequest(http.MethodPost, "/mcp/v1/initialize", nil) - ctx := context.WithValue(req.Context(), userIDKey("user_id"), "test_user") - req = req.WithContext(ctx) - - rec := httptest.NewRecorder() - - wrappedHandler.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusOK, rec.Code) -} - -func TestMiddleware_ContextTimeoutHandling(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockMgr := mocks.NewMockManager(ctrl) - - backends := []vmcp.Backend{ - {ID: "backend1", Name: "Backend 1", HealthStatus: vmcp.BackendHealthy}, - } - - testTimeout := 100 * time.Millisecond - - // Simulate slow discovery that takes longer than timeout - mockMgr.EXPECT(). - Discover(gomock.Any(), backends). - DoAndReturn(func(ctx context.Context, _ []vmcp.Backend) (*aggregator.AggregatedCapabilities, error) { - // Verify timeout context is set - deadline, ok := ctx.Deadline() - assert.True(t, ok, "context should have a deadline") - assert.True(t, time.Until(deadline) <= testTimeout, "timeout should be set correctly") - - // Simulate slow operation that exceeds the timeout - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(5 * time.Second): - return nil, errors.New("operation completed without timeout") - } - }) - - testHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusOK) - }) - - backendRegistry := vmcp.NewImmutableRegistry(backends) - middleware := Middleware(mockMgr, backendRegistry, newStubMultiSessionGetter(), nil, WithDiscoveryTimeout(testTimeout)) - wrappedHandler := middleware(testHandler) - - // Initialize request (no session ID) - discovery should happen - req := httptest.NewRequest(http.MethodPost, "/mcp/v1/initialize", nil) - rec := httptest.NewRecorder() - - wrappedHandler.ServeHTTP(rec, req) - - // Verify timeout response (should be 504 Gateway Timeout) - assert.Equal(t, http.StatusGatewayTimeout, rec.Code) -} - -func TestFilterHealthyBackends(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - backends []vmcp.Backend - expectedBackends []string // backend IDs that should be included - }{ - { - name: "empty backends list", - backends: []vmcp.Backend{}, - expectedBackends: []string{}, - }, - { - name: "all healthy backends", - backends: []vmcp.Backend{ - {ID: "backend1", Name: "Backend 1", HealthStatus: vmcp.BackendHealthy}, - {ID: "backend2", Name: "Backend 2", HealthStatus: vmcp.BackendHealthy}, - {ID: "backend3", Name: "Backend 3", HealthStatus: vmcp.BackendHealthy}, - }, - expectedBackends: []string{"backend1", "backend2", "backend3"}, - }, - { - name: "all unhealthy backends", - backends: []vmcp.Backend{ - {ID: "backend1", Name: "Backend 1", HealthStatus: vmcp.BackendUnhealthy}, - {ID: "backend2", Name: "Backend 2", HealthStatus: vmcp.BackendUnhealthy}, - }, - expectedBackends: []string{}, - }, - { - name: "mixed healthy and unhealthy backends", - backends: []vmcp.Backend{ - {ID: "backend1", Name: "Backend 1", HealthStatus: vmcp.BackendHealthy}, - {ID: "backend2", Name: "Backend 2", HealthStatus: vmcp.BackendUnhealthy}, - {ID: "backend3", Name: "Backend 3", HealthStatus: vmcp.BackendHealthy}, - {ID: "backend4", Name: "Backend 4", HealthStatus: vmcp.BackendUnhealthy}, - }, - expectedBackends: []string{"backend1", "backend3"}, - }, - { - name: "include degraded backends", - backends: []vmcp.Backend{ - {ID: "backend1", Name: "Backend 1", HealthStatus: vmcp.BackendHealthy}, - {ID: "backend2", Name: "Backend 2", HealthStatus: vmcp.BackendDegraded}, - {ID: "backend3", Name: "Backend 3", HealthStatus: vmcp.BackendUnhealthy}, - }, - expectedBackends: []string{"backend1", "backend2"}, - }, - { - name: "exclude unknown status backends", - backends: []vmcp.Backend{ - {ID: "backend1", Name: "Backend 1", HealthStatus: vmcp.BackendHealthy}, - {ID: "backend2", Name: "Backend 2", HealthStatus: vmcp.BackendUnknown}, - {ID: "backend3", Name: "Backend 3", HealthStatus: vmcp.BackendHealthy}, - }, - expectedBackends: []string{"backend1", "backend3"}, - }, - { - name: "exclude unauthenticated backends", - backends: []vmcp.Backend{ - {ID: "backend1", Name: "Backend 1", HealthStatus: vmcp.BackendHealthy}, - {ID: "backend2", Name: "Backend 2", HealthStatus: vmcp.BackendUnauthenticated}, - }, - expectedBackends: []string{"backend1"}, - }, - { - name: "include backends with empty/zero-value health status (assume healthy)", - backends: []vmcp.Backend{ - {ID: "backend1", Name: "Backend 1"}, // No HealthStatus set (zero value = "") - {ID: "backend2", Name: "Backend 2", HealthStatus: vmcp.BackendHealthy}, - {ID: "backend3", Name: "Backend 3"}, // No HealthStatus set - }, - expectedBackends: []string{"backend1", "backend2", "backend3"}, - }, - { - name: "all status types", - backends: []vmcp.Backend{ - {ID: "backend1", Name: "Backend 1", HealthStatus: vmcp.BackendHealthy}, - {ID: "backend2", Name: "Backend 2", HealthStatus: vmcp.BackendDegraded}, - {ID: "backend3", Name: "Backend 3", HealthStatus: vmcp.BackendUnhealthy}, - {ID: "backend4", Name: "Backend 4", HealthStatus: vmcp.BackendUnknown}, - {ID: "backend5", Name: "Backend 5", HealthStatus: vmcp.BackendUnauthenticated}, - }, - expectedBackends: []string{"backend1", "backend2"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - // Test with nil health status provider (health monitoring disabled) - // This tests the fallback to registry-based health status - result := filterHealthyBackends(tt.backends, nil) - - assert.Equal(t, len(tt.expectedBackends), len(result), "unexpected number of backends returned") - - // Verify only expected backends are included - resultIDs := make([]string, len(result)) - for i, backend := range result { - resultIDs[i] = backend.ID - } - assert.ElementsMatch(t, tt.expectedBackends, resultIDs, "unexpected backends in result") - - // Verify all returned backends have healthy, degraded, or empty (assume healthy) status - for _, backend := range result { - assert.True(t, - backend.HealthStatus == "" || - backend.HealthStatus == vmcp.BackendHealthy || - backend.HealthStatus == vmcp.BackendDegraded, - "backend %s has unexpected status: %s", backend.ID, backend.HealthStatus) - } - }) - } -} - -// TestFilterHealthyBackends_WithHealthMonitor verifies that filterHealthyBackends -// uses the health status provider when available, overriding registry health status. -func TestFilterHealthyBackends_WithHealthMonitor(t *testing.T) { - t.Parallel() - - // Create backends with "healthy" status in registry - backends := []vmcp.Backend{ - {ID: "backend1", Name: "Backend 1", HealthStatus: vmcp.BackendHealthy}, - {ID: "backend2", Name: "Backend 2", HealthStatus: vmcp.BackendHealthy}, - {ID: "backend3", Name: "Backend 3", HealthStatus: vmcp.BackendHealthy}, - } - - // Create mock health status provider that overrides health status - mockHealthProvider := &mockHealthStatusProvider{ - statuses: map[string]vmcp.BackendHealthStatus{ - "backend1": vmcp.BackendHealthy, // Healthy in both registry and monitor - "backend2": vmcp.BackendUnhealthy, // Healthy in registry, unhealthy in monitor (circuit breaker OPEN) - // backend3 not in monitor - should use registry status (healthy) - }, - } - - // Filter with health monitor - result := filterHealthyBackends(backends, mockHealthProvider) - - // Should include backend1 (healthy in monitor) and backend3 (not monitored, falls back to registry) - // Should exclude backend2 (unhealthy in monitor, circuit breaker may be OPEN) - assert.Equal(t, 2, len(result), "expected 2 backends (backend1 and backend3)") - - resultIDs := make([]string, len(result)) - for i, backend := range result { - resultIDs[i] = backend.ID - } - assert.ElementsMatch(t, []string{"backend1", "backend3"}, resultIDs, - "expected backend1 and backend3 to be included") -} - -// mockHealthStatusProvider is a test helper that implements health.StatusProvider -type mockHealthStatusProvider struct { - statuses map[string]vmcp.BackendHealthStatus -} - -func (m *mockHealthStatusProvider) QueryBackendStatus(backendID string) (vmcp.BackendHealthStatus, bool) { - status, exists := m.statuses[backendID] - return status, exists -} diff --git a/pkg/vmcp/discovery/mocks/mock_manager.go b/pkg/vmcp/discovery/mocks/mock_manager.go deleted file mode 100644 index 06f24778c8..0000000000 --- a/pkg/vmcp/discovery/mocks/mock_manager.go +++ /dev/null @@ -1,70 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: manager.go -// -// Generated by this command: -// -// mockgen -destination=mocks/mock_manager.go -package=mocks -source=manager.go Manager -// - -// Package mocks is a generated GoMock package. -package mocks - -import ( - context "context" - reflect "reflect" - - vmcp "github.com/stacklok/toolhive/pkg/vmcp" - aggregator "github.com/stacklok/toolhive/pkg/vmcp/aggregator" - gomock "go.uber.org/mock/gomock" -) - -// MockManager is a mock of Manager interface. -type MockManager struct { - ctrl *gomock.Controller - recorder *MockManagerMockRecorder - isgomock struct{} -} - -// MockManagerMockRecorder is the mock recorder for MockManager. -type MockManagerMockRecorder struct { - mock *MockManager -} - -// NewMockManager creates a new mock instance. -func NewMockManager(ctrl *gomock.Controller) *MockManager { - mock := &MockManager{ctrl: ctrl} - mock.recorder = &MockManagerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockManager) EXPECT() *MockManagerMockRecorder { - return m.recorder -} - -// Discover mocks base method. -func (m *MockManager) Discover(ctx context.Context, backends []vmcp.Backend) (*aggregator.AggregatedCapabilities, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Discover", ctx, backends) - ret0, _ := ret[0].(*aggregator.AggregatedCapabilities) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Discover indicates an expected call of Discover. -func (mr *MockManagerMockRecorder) Discover(ctx, backends any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Discover", reflect.TypeOf((*MockManager)(nil).Discover), ctx, backends) -} - -// Stop mocks base method. -func (m *MockManager) Stop() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Stop") -} - -// Stop indicates an expected call of Stop. -func (mr *MockManagerMockRecorder) Stop() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockManager)(nil).Stop)) -} diff --git a/pkg/vmcp/router/default_router.go b/pkg/vmcp/router/default_router.go deleted file mode 100644 index 3e6f3bfba7..0000000000 --- a/pkg/vmcp/router/default_router.go +++ /dev/null @@ -1,125 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -package router - -import ( - "context" - "fmt" - "log/slog" - - "github.com/stacklok/toolhive/pkg/vmcp" - "github.com/stacklok/toolhive/pkg/vmcp/discovery" -) - -// defaultRouter is a stateless router implementation that retrieves routing -// information from the request context. With lazy discovery, capabilities are -// discovered per-request and stored in context by the discovery middleware. -// -// This router is thread-safe by design since it maintains no mutable state. -type defaultRouter struct { - // No fields - routing table comes from request context -} - -// NewDefaultRouter creates a new default router instance. -func NewDefaultRouter() Router { - return &defaultRouter{} -} - -// routeCapability is a generic helper that implements the common routing logic -// for tools, resources, and prompts. It extracts capabilities from context, -// validates the routing table, and looks up the key in the specified map. -// -// Parameters: -// - ctx: The request context containing discovered capabilities -// - key: The capability identifier (tool name, resource URI, or prompt name) -// - getMap: Function to extract the specific map from the routing table -// - mapName: Name of the map for error messages (e.g., "tools", "resources", "prompts") -// - entityType: Type of entity for log messages (e.g., "tool", "resource", "prompt") -// - notFoundErr: The specific error to wrap when the key is not found -func routeCapability( - ctx context.Context, - key string, - getMap func(*vmcp.RoutingTable) map[string]*vmcp.BackendTarget, - mapName string, - entityType string, - notFoundErr error, -) (*vmcp.BackendTarget, error) { - // Defensive nil check - prevent panic if context is nil - if ctx == nil { - return nil, fmt.Errorf("context cannot be nil") - } - - // Get capabilities from context (set by discovery middleware) - capabilities, ok := discovery.DiscoveredCapabilitiesFromContext(ctx) - if !ok || capabilities == nil { - return nil, fmt.Errorf("capabilities not found in context - discovery middleware may not have run") - } - - if capabilities.RoutingTable == nil { - return nil, fmt.Errorf("routing table not initialized in discovered capabilities") - } - - capabilityMap := getMap(capabilities.RoutingTable) - if capabilityMap == nil { - return nil, fmt.Errorf("routing table %s map not initialized", mapName) - } - - target, exists := capabilityMap[key] - if !exists { - slog.Debug("not found in routing table", "type", entityType, "key", key) - return nil, fmt.Errorf("%w: %s", notFoundErr, key) - } - - slog.Debug("routed capability to backend", "type", entityType, "key", key, "backend", target.WorkloadID) - return target, nil -} - -// RouteTool resolves a tool name to its backend target. -// With lazy discovery, this method gets capabilities from the request context -// instead of using a cached routing table. -func (*defaultRouter) RouteTool(ctx context.Context, toolName string) (*vmcp.BackendTarget, error) { - return routeCapability( - ctx, - toolName, - func(rt *vmcp.RoutingTable) map[string]*vmcp.BackendTarget { return rt.Tools }, - "tools", - "Tool", - ErrToolNotFound, - ) -} - -// ResolveToolName returns toolName unchanged. The defaultRouter has no static -// routing table, so dot-convention resolution is not available; the caller -// should already be using resolved names when working with this router. -func (*defaultRouter) ResolveToolName(_ context.Context, toolName string) string { - return toolName -} - -// RouteResource resolves a resource URI to its backend target. -// With lazy discovery, this method gets capabilities from the request context -// instead of using a cached routing table. -func (*defaultRouter) RouteResource(ctx context.Context, uri string) (*vmcp.BackendTarget, error) { - return routeCapability( - ctx, - uri, - func(rt *vmcp.RoutingTable) map[string]*vmcp.BackendTarget { return rt.Resources }, - "resources", - "Resource", - ErrResourceNotFound, - ) -} - -// RoutePrompt resolves a prompt name to its backend target. -// With lazy discovery, this method gets capabilities from the request context -// instead of using a cached routing table. -func (*defaultRouter) RoutePrompt(ctx context.Context, name string) (*vmcp.BackendTarget, error) { - return routeCapability( - ctx, - name, - func(rt *vmcp.RoutingTable) map[string]*vmcp.BackendTarget { return rt.Prompts }, - "prompts", - "Prompt", - ErrPromptNotFound, - ) -} diff --git a/pkg/vmcp/router/default_router_test.go b/pkg/vmcp/router/default_router_test.go deleted file mode 100644 index 9292c944e3..0000000000 --- a/pkg/vmcp/router/default_router_test.go +++ /dev/null @@ -1,339 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -package router_test - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/stacklok/toolhive/pkg/vmcp" - "github.com/stacklok/toolhive/pkg/vmcp/aggregator" - "github.com/stacklok/toolhive/pkg/vmcp/discovery" - "github.com/stacklok/toolhive/pkg/vmcp/router" -) - -func TestDefaultRouter_RouteTool(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - setupTable *vmcp.RoutingTable - toolName string - expectedID string - expectError bool - errorContains string - }{ - { - name: "route to existing tool", - setupTable: &vmcp.RoutingTable{ - Tools: map[string]*vmcp.BackendTarget{ - "test_tool": { - WorkloadID: "backend1", - WorkloadName: "Backend 1", - BaseURL: "http://backend1:8080", - }, - }, - Resources: make(map[string]*vmcp.BackendTarget), - Prompts: make(map[string]*vmcp.BackendTarget), - }, - toolName: "test_tool", - expectedID: "backend1", - expectError: false, - }, - { - name: "tool not found", - setupTable: &vmcp.RoutingTable{ - Tools: make(map[string]*vmcp.BackendTarget), - Resources: make(map[string]*vmcp.BackendTarget), - Prompts: make(map[string]*vmcp.BackendTarget), - }, - toolName: "nonexistent_tool", - expectError: true, - errorContains: "tool not found", - }, - { - name: "capabilities not in context", - setupTable: nil, - toolName: "test_tool", - expectError: true, - errorContains: "capabilities not found in context", - }, - { - name: "routing table tools map is nil", - setupTable: &vmcp.RoutingTable{ - Tools: nil, // nil map - Resources: make(map[string]*vmcp.BackendTarget), - Prompts: make(map[string]*vmcp.BackendTarget), - }, - toolName: "test_tool", - expectError: true, - errorContains: "routing table tools map not initialized", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - r := router.NewDefaultRouter() - - // Setup routing table in context if provided - if tt.setupTable != nil { - caps := &aggregator.AggregatedCapabilities{ - RoutingTable: tt.setupTable, - } - ctx = discovery.WithDiscoveredCapabilities(ctx, caps) - } - - // Test routing - target, err := r.RouteTool(ctx, tt.toolName) - - if tt.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errorContains) - assert.Nil(t, target) - } else { - require.NoError(t, err) - require.NotNil(t, target) - assert.Equal(t, tt.expectedID, target.WorkloadID) - } - }) - } -} - -func TestDefaultRouter_RouteResource(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - setupTable *vmcp.RoutingTable - uri string - expectedID string - expectError bool - errorContains string - }{ - { - name: "route to existing resource", - setupTable: &vmcp.RoutingTable{ - Tools: make(map[string]*vmcp.BackendTarget), - Resources: map[string]*vmcp.BackendTarget{ - "file:///path/to/resource": { - WorkloadID: "backend2", - WorkloadName: "Backend 2", - BaseURL: "http://backend2:8080", - }, - }, - Prompts: make(map[string]*vmcp.BackendTarget), - }, - uri: "file:///path/to/resource", - expectedID: "backend2", - expectError: false, - }, - { - name: "resource not found", - setupTable: &vmcp.RoutingTable{ - Tools: make(map[string]*vmcp.BackendTarget), - Resources: make(map[string]*vmcp.BackendTarget), - Prompts: make(map[string]*vmcp.BackendTarget), - }, - uri: "file:///nonexistent", - expectError: true, - errorContains: "resource not found", - }, - { - name: "capabilities not in context", - setupTable: nil, - uri: "file:///test", - expectError: true, - errorContains: "capabilities not found in context", - }, - { - name: "routing table resources map is nil", - setupTable: &vmcp.RoutingTable{ - Tools: make(map[string]*vmcp.BackendTarget), - Resources: nil, // nil map - Prompts: make(map[string]*vmcp.BackendTarget), - }, - uri: "file:///test", - expectError: true, - errorContains: "routing table resources map not initialized", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - r := router.NewDefaultRouter() - - // Setup routing table in context if provided - if tt.setupTable != nil { - caps := &aggregator.AggregatedCapabilities{ - RoutingTable: tt.setupTable, - } - ctx = discovery.WithDiscoveredCapabilities(ctx, caps) - } - - // Test routing - target, err := r.RouteResource(ctx, tt.uri) - - if tt.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errorContains) - assert.Nil(t, target) - } else { - require.NoError(t, err) - require.NotNil(t, target) - assert.Equal(t, tt.expectedID, target.WorkloadID) - } - }) - } -} - -func TestDefaultRouter_RoutePrompt(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - setupTable *vmcp.RoutingTable - promptName string - expectedID string - expectError bool - errorContains string - }{ - { - name: "route to existing prompt", - setupTable: &vmcp.RoutingTable{ - Tools: make(map[string]*vmcp.BackendTarget), - Resources: make(map[string]*vmcp.BackendTarget), - Prompts: map[string]*vmcp.BackendTarget{ - "greeting": { - WorkloadID: "backend3", - WorkloadName: "Backend 3", - BaseURL: "http://backend3:8080", - }, - }, - }, - promptName: "greeting", - expectedID: "backend3", - expectError: false, - }, - { - name: "prompt not found", - setupTable: &vmcp.RoutingTable{ - Tools: make(map[string]*vmcp.BackendTarget), - Resources: make(map[string]*vmcp.BackendTarget), - Prompts: make(map[string]*vmcp.BackendTarget), - }, - promptName: "nonexistent", - expectError: true, - errorContains: "prompt not found", - }, - { - name: "capabilities not in context", - setupTable: nil, - promptName: "test", - expectError: true, - errorContains: "capabilities not found in context", - }, - { - name: "routing table prompts map is nil", - setupTable: &vmcp.RoutingTable{ - Tools: make(map[string]*vmcp.BackendTarget), - Resources: make(map[string]*vmcp.BackendTarget), - Prompts: nil, // nil map - }, - promptName: "test", - expectError: true, - errorContains: "routing table prompts map not initialized", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - r := router.NewDefaultRouter() - - // Setup routing table in context if provided - if tt.setupTable != nil { - caps := &aggregator.AggregatedCapabilities{ - RoutingTable: tt.setupTable, - } - ctx = discovery.WithDiscoveredCapabilities(ctx, caps) - } - - // Test routing - target, err := r.RoutePrompt(ctx, tt.promptName) - - if tt.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errorContains) - assert.Nil(t, target) - } else { - require.NoError(t, err) - require.NotNil(t, target) - assert.Equal(t, tt.expectedID, target.WorkloadID) - } - }) - } -} - -func TestDefaultRouter_ConcurrentAccess(t *testing.T) { - t.Parallel() - - // Setup routing table - table := &vmcp.RoutingTable{ - Tools: map[string]*vmcp.BackendTarget{ - "tool1": {WorkloadID: "backend1"}, - "tool2": {WorkloadID: "backend2"}, - }, - Resources: map[string]*vmcp.BackendTarget{ - "res1": {WorkloadID: "backend1"}, - }, - Prompts: map[string]*vmcp.BackendTarget{ - "prompt1": {WorkloadID: "backend2"}, - }, - } - - caps := &aggregator.AggregatedCapabilities{ - RoutingTable: table, - } - ctx := discovery.WithDiscoveredCapabilities(context.Background(), caps) - - r := router.NewDefaultRouter() - - // Run concurrent readers - router is stateless so this should be safe - const numGoroutines = 10 - const numOperations = 100 - - done := make(chan bool, numGoroutines) - - for i := 0; i < numGoroutines; i++ { - go func() { - for j := 0; j < numOperations; j++ { - _, _ = r.RouteTool(ctx, "tool1") - _, _ = r.RouteResource(ctx, "res1") - _, _ = r.RoutePrompt(ctx, "prompt1") - } - done <- true - }() - } - - // Wait for all goroutines to complete - for i := 0; i < numGoroutines; i++ { - <-done - } - - // Verify router still works correctly - target, err := r.RouteTool(ctx, "tool1") - require.NoError(t, err) - assert.Equal(t, "backend1", target.WorkloadID) -} diff --git a/pkg/vmcp/server/authz_integration_test.go b/pkg/vmcp/server/authz_integration_test.go index 762efcd84a..20a8e18405 100644 --- a/pkg/vmcp/server/authz_integration_test.go +++ b/pkg/vmcp/server/authz_integration_test.go @@ -25,7 +25,6 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp/auth/strategies" authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" vmcpclient "github.com/stacklok/toolhive/pkg/vmcp/client" - discoveryMocks "github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks" "github.com/stacklok/toolhive/pkg/vmcp/mocks" "github.com/stacklok/toolhive/pkg/vmcp/router" "github.com/stacklok/toolhive/pkg/vmcp/server" @@ -44,7 +43,6 @@ func newCedarAuthzTestServer(t *testing.T, backendURL string, policies ...string ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) mockBackendRegistry := mocks.NewMockBackendRegistry(ctrl) backend := vmcp.Backend{ @@ -55,9 +53,6 @@ func newCedarAuthzTestServer(t *testing.T, backendURL string, policies ...string } mockBackendRegistry.EXPECT().List(gomock.Any()).Return([]vmcp.Backend{backend}).AnyTimes() mockBackendRegistry.EXPECT().Get(gomock.Any(), gomock.Any()).Return(&backend).AnyTimes() - mockDiscoveryMgr.EXPECT().Discover(gomock.Any(), gomock.Any()). - Return(&aggregator.AggregatedCapabilities{}, nil).AnyTimes() - mockDiscoveryMgr.EXPECT().Stop().AnyTimes() authReg := vmcpauth.NewDefaultOutgoingAuthRegistry() require.NoError(t, authReg.RegisterStrategy( @@ -109,9 +104,8 @@ func newCedarAuthzTestServer(t *testing.T, backendURL string, policies ...string AuthMiddleware: identityMiddleware, Authz: authzCfg, }, - router.NewDefaultRouter(), + router.NewSessionRouter(&vmcp.RoutingTable{}), backendClient, - mockDiscoveryMgr, mockBackendRegistry, nil, ) diff --git a/pkg/vmcp/server/body_limit_test.go b/pkg/vmcp/server/body_limit_test.go index 391825e850..d1762a1e72 100644 --- a/pkg/vmcp/server/body_limit_test.go +++ b/pkg/vmcp/server/body_limit_test.go @@ -14,7 +14,6 @@ import ( "go.uber.org/mock/gomock" "github.com/stacklok/toolhive/pkg/bodylimit" - discoveryMocks "github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks" "github.com/stacklok/toolhive/pkg/vmcp/mocks" routerMocks "github.com/stacklok/toolhive/pkg/vmcp/router/mocks" "github.com/stacklok/toolhive/pkg/vmcp/server" @@ -32,18 +31,15 @@ func TestHandler_RejectsOversizedBody(t *testing.T) { mockRouter := routerMocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) mockBackendRegistry := mocks.NewMockBackendRegistry(ctrl) mockBackendRegistry.EXPECT().List(gomock.Any()).Return(nil).AnyTimes() - mockDiscoveryMgr.EXPECT().Discover(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() srv, err := server.New( t.Context(), &server.Config{Host: "127.0.0.1", Port: 0, SessionFactory: newNoopMockFactory(t), Aggregator: newStubAggregator(nil)}, mockRouter, mockBackendClient, - mockDiscoveryMgr, mockBackendRegistry, nil, ) diff --git a/pkg/vmcp/server/health_monitoring_test.go b/pkg/vmcp/server/health_monitoring_test.go index 3b1253a4da..b380aa33ae 100644 --- a/pkg/vmcp/server/health_monitoring_test.go +++ b/pkg/vmcp/server/health_monitoring_test.go @@ -16,7 +16,6 @@ import ( "go.uber.org/mock/gomock" "github.com/stacklok/toolhive/pkg/vmcp" - discoverymocks "github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks" "github.com/stacklok/toolhive/pkg/vmcp/health" "github.com/stacklok/toolhive/pkg/vmcp/mocks" routermocks "github.com/stacklok/toolhive/pkg/vmcp/router/mocks" @@ -31,7 +30,6 @@ func TestServer_HealthMonitoring_Disabled(t *testing.T) { mockRouter := routermocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoverymocks.NewMockManager(ctrl) backends := []vmcp.Backend{ {ID: "backend-1", Name: "Backend 1", BaseURL: "http://localhost:8080"}, @@ -48,7 +46,7 @@ func TestServer_HealthMonitoring_Disabled(t *testing.T) { } backendRegistry := vmcp.NewImmutableRegistry(backends) - srv, err := New(context.Background(), cfg, mockRouter, mockBackendClient, mockDiscoveryMgr, backendRegistry, nil) + srv, err := New(context.Background(), cfg, mockRouter, mockBackendClient, backendRegistry, nil) require.NoError(t, err) require.NotNil(t, srv) @@ -83,7 +81,6 @@ func TestServer_HealthMonitoring_Enabled(t *testing.T) { mockRouter := routermocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoverymocks.NewMockManager(ctrl) backends := []vmcp.Backend{ {ID: "backend-1", Name: "Backend 1", BaseURL: "http://localhost:8080", TransportType: "sse"}, @@ -117,7 +114,7 @@ func TestServer_HealthMonitoring_Enabled(t *testing.T) { } backendRegistry := vmcp.NewImmutableRegistry(backends) - srv, err := New(context.Background(), cfg, mockRouter, mockBackendClient, mockDiscoveryMgr, backendRegistry, nil) + srv, err := New(context.Background(), cfg, mockRouter, mockBackendClient, backendRegistry, nil) require.NoError(t, err) require.NotNil(t, srv) @@ -125,7 +122,6 @@ func TestServer_HealthMonitoring_Enabled(t *testing.T) { assert.NotNil(t, srv.backendHealth()) // Start server in background - mockDiscoveryMgr.EXPECT().Stop().AnyTimes() ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -191,7 +187,6 @@ func TestServer_HealthMonitoring_StartupFailure(t *testing.T) { mockRouter := routermocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoverymocks.NewMockManager(ctrl) backends := []vmcp.Backend{ {ID: "backend-1", Name: "Backend 1", BaseURL: "http://localhost:8080"}, @@ -213,7 +208,7 @@ func TestServer_HealthMonitoring_StartupFailure(t *testing.T) { // This should fail during New() because of invalid health monitor config backendRegistry := vmcp.NewImmutableRegistry(backends) - srv, err := New(context.Background(), cfg, mockRouter, mockBackendClient, mockDiscoveryMgr, backendRegistry, nil) + srv, err := New(context.Background(), cfg, mockRouter, mockBackendClient, backendRegistry, nil) require.Error(t, err) require.Nil(t, srv) assert.Contains(t, err.Error(), "failed to create health monitor") @@ -228,7 +223,6 @@ func TestServer_HandleBackendHealth_Disabled(t *testing.T) { mockRouter := routermocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoverymocks.NewMockManager(ctrl) backends := []vmcp.Backend{ {ID: "backend-1", Name: "Backend 1", BaseURL: "http://localhost:8080"}, @@ -245,7 +239,7 @@ func TestServer_HandleBackendHealth_Disabled(t *testing.T) { } backendRegistry := vmcp.NewImmutableRegistry(backends) - srv, err := New(context.Background(), cfg, mockRouter, mockBackendClient, mockDiscoveryMgr, backendRegistry, nil) + srv, err := New(context.Background(), cfg, mockRouter, mockBackendClient, backendRegistry, nil) require.NoError(t, err) // Create test request @@ -277,7 +271,6 @@ func TestServer_HandleBackendHealth_Enabled(t *testing.T) { mockRouter := routermocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoverymocks.NewMockManager(ctrl) backends := []vmcp.Backend{ {ID: "backend-1", Name: "Backend 1", BaseURL: "http://localhost:8080", TransportType: "sse"}, @@ -304,11 +297,10 @@ func TestServer_HandleBackendHealth_Enabled(t *testing.T) { } backendRegistry := vmcp.NewImmutableRegistry(backends) - srv, err := New(context.Background(), cfg, mockRouter, mockBackendClient, mockDiscoveryMgr, backendRegistry, nil) + srv, err := New(context.Background(), cfg, mockRouter, mockBackendClient, backendRegistry, nil) require.NoError(t, err) // Start server - mockDiscoveryMgr.EXPECT().Stop().AnyTimes() ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -372,7 +364,6 @@ func TestServer_Stop_StopsHealthMonitor(t *testing.T) { mockRouter := routermocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoverymocks.NewMockManager(ctrl) backends := []vmcp.Backend{ {ID: "backend-1", Name: "Backend 1", BaseURL: "http://localhost:8080", TransportType: "sse"}, @@ -399,11 +390,10 @@ func TestServer_Stop_StopsHealthMonitor(t *testing.T) { } backendRegistry := vmcp.NewImmutableRegistry(backends) - srv, err := New(context.Background(), cfg, mockRouter, mockBackendClient, mockDiscoveryMgr, backendRegistry, nil) + srv, err := New(context.Background(), cfg, mockRouter, mockBackendClient, backendRegistry, nil) require.NoError(t, err) // Start server - mockDiscoveryMgr.EXPECT().Stop().Times(1) ctx, cancel := context.WithCancel(context.Background()) errCh := make(chan error, 1) diff --git a/pkg/vmcp/server/health_test.go b/pkg/vmcp/server/health_test.go index 9a3ac48270..26a22794d2 100644 --- a/pkg/vmcp/server/health_test.go +++ b/pkg/vmcp/server/health_test.go @@ -17,8 +17,6 @@ import ( "github.com/stacklok/toolhive/pkg/networking" "github.com/stacklok/toolhive/pkg/vmcp" - "github.com/stacklok/toolhive/pkg/vmcp/aggregator" - discoveryMocks "github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks" "github.com/stacklok/toolhive/pkg/vmcp/mocks" "github.com/stacklok/toolhive/pkg/vmcp/router" "github.com/stacklok/toolhive/pkg/vmcp/server" @@ -33,8 +31,7 @@ func createTestServer(t *testing.T) *server.Server { t.Cleanup(ctrl.Finish) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) - rt := router.NewDefaultRouter() + rt := router.NewSessionRouter(&vmcp.RoutingTable{}) // Find an available port for parallel test execution port := networking.FindAvailable() @@ -44,23 +41,8 @@ func createTestServer(t *testing.T) *server.Server { backends := []vmcp.Backend{} // Mock discovery manager to return empty capabilities - mockDiscoveryMgr.EXPECT(). - Discover(gomock.Any(), gomock.Any()). - Return(&aggregator.AggregatedCapabilities{ - Tools: []vmcp.Tool{}, - Resources: []vmcp.Resource{}, - Prompts: []vmcp.Prompt{}, - RoutingTable: &vmcp.RoutingTable{ - Tools: make(map[string]*vmcp.BackendTarget), - Resources: make(map[string]*vmcp.BackendTarget), - Prompts: make(map[string]*vmcp.BackendTarget), - }, - Metadata: &aggregator.AggregationMetadata{}, - }, nil). - AnyTimes() // Mock Stop to be called during server shutdown - mockDiscoveryMgr.EXPECT().Stop().AnyTimes() // Create context for server ctx, cancel := context.WithCancel(t.Context()) @@ -72,7 +54,7 @@ func createTestServer(t *testing.T) *server.Server { Host: "127.0.0.1", Port: port, SessionFactory: newNoopMockFactory(t), Aggregator: newStubAggregator(nil), - }, rt, mockBackendClient, mockDiscoveryMgr, backendRegistry, nil) + }, rt, mockBackendClient, backendRegistry, nil) require.NoError(t, err) // Start server in background @@ -178,8 +160,7 @@ func TestServer_SessionManager(t *testing.T) { t.Cleanup(ctrl.Finish) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) - rt := router.NewDefaultRouter() + rt := router.NewSessionRouter(&vmcp.RoutingTable{}) backendRegistry := vmcp.NewImmutableRegistry([]vmcp.Backend{}) srv, err := server.New(context.Background(), &server.Config{ @@ -187,7 +168,7 @@ func TestServer_SessionManager(t *testing.T) { Version: "1.0.0", SessionTTL: 10 * time.Minute, SessionFactory: newNoopMockFactory(t), Aggregator: newStubAggregator(nil), - }, rt, mockBackendClient, mockDiscoveryMgr, backendRegistry, nil) + }, rt, mockBackendClient, backendRegistry, nil) require.NoError(t, err) // SessionManager should be accessible @@ -202,8 +183,7 @@ func TestServer_SessionManager(t *testing.T) { t.Cleanup(ctrl.Finish) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) - rt := router.NewDefaultRouter() + rt := router.NewSessionRouter(&vmcp.RoutingTable{}) customTTL := 15 * time.Minute backendRegistry := vmcp.NewImmutableRegistry([]vmcp.Backend{}) @@ -212,7 +192,7 @@ func TestServer_SessionManager(t *testing.T) { Version: "1.0.0", SessionTTL: customTTL, SessionFactory: newNoopMockFactory(t), Aggregator: newStubAggregator(nil), - }, rt, mockBackendClient, mockDiscoveryMgr, backendRegistry, nil) + }, rt, mockBackendClient, backendRegistry, nil) require.NoError(t, err) mgr := srv.SessionManager() diff --git a/pkg/vmcp/server/integration_test.go b/pkg/vmcp/server/integration_test.go index e9753230bd..9b441b9795 100644 --- a/pkg/vmcp/server/integration_test.go +++ b/pkg/vmcp/server/integration_test.go @@ -26,8 +26,6 @@ import ( transportsession "github.com/stacklok/toolhive/pkg/transport/session" "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/aggregator" - "github.com/stacklok/toolhive/pkg/vmcp/discovery" - discoveryMocks "github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks" "github.com/stacklok/toolhive/pkg/vmcp/mocks" "github.com/stacklok/toolhive/pkg/vmcp/router" "github.com/stacklok/toolhive/pkg/vmcp/server" @@ -166,42 +164,12 @@ func TestIntegration_AggregatorToRouterToServer(t *testing.T) { assert.Equal(t, 1, len(aggregatedCaps.RoutingTable.Resources)) assert.Equal(t, 1, len(aggregatedCaps.RoutingTable.Prompts)) - // Step 4: Create router and add capabilities to context - rt := router.NewDefaultRouter() + // Step 4: Create a router for the server. Per-call routing is exercised by the + // core's SessionRouter (see router/session_router_test.go); this end-to-end test + // only needs a router instance to construct the server. + rt := router.NewSessionRouter(&vmcp.RoutingTable{}) - // Add discovered capabilities to context - ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, aggregatedCaps) - - // Step 5: Verify router can route to correct backends (using context with capabilities) - target, err := rt.RouteTool(ctxWithCaps, "github_create_issue") - require.NoError(t, err) - assert.Equal(t, "github", target.WorkloadID) - assert.Equal(t, "http://github-mcp:8080", target.BaseURL) - - target, err = rt.RouteTool(ctxWithCaps, "jira_create_issue") - require.NoError(t, err) - assert.Equal(t, "jira", target.WorkloadID) - assert.Equal(t, "http://jira-mcp:8080", target.BaseURL) - - target, err = rt.RouteResource(ctxWithCaps, "file:///github/repos") - require.NoError(t, err) - assert.Equal(t, "github", target.WorkloadID) - - target, err = rt.RoutePrompt(ctxWithCaps, "code_review") - require.NoError(t, err) - assert.Equal(t, "github", target.WorkloadID) - - // Step 6: Create discovery manager and server - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) - - // Mock discovery to return our aggregated capabilities - mockDiscoveryMgr.EXPECT(). - Discover(gomock.Any(), gomock.Any()). - Return(aggregatedCaps, nil). - AnyTimes() - - // Mock Stop to be called during server shutdown - mockDiscoveryMgr.EXPECT().Stop().Times(1) + // Step 5: Create the server srv, err := server.New(ctx, &server.Config{ Name: "test-vmcp", @@ -210,13 +178,13 @@ func TestIntegration_AggregatorToRouterToServer(t *testing.T) { Port: 4484, SessionFactory: newNoopMockFactory(t), Aggregator: agg, - }, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry(backends), nil) + }, rt, mockBackendClient, vmcp.NewImmutableRegistry(backends), nil) require.NoError(t, err) // Validate server address assert.Equal(t, "127.0.0.1:4484", srv.Address()) - // Step 7: Start server and validate it's running + // Step 6: Start server and validate it's running serverCtx, cancelServer := context.WithCancel(ctx) t.Cleanup(cancelServer) @@ -454,19 +422,7 @@ func TestIntegration_AuditLogging(t *testing.T) { } // Create router - rt := router.NewDefaultRouter() - - // Create discovery manager - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) - mockDiscoveryMgr.EXPECT(). - Discover(gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, _ []vmcp.Backend) (*aggregator.AggregatedCapabilities, error) { - resolver := aggregator.NewPrefixConflictResolver("{workload}_") - agg := aggregator.NewDefaultAggregator(mockBackendClient, resolver, nil, nil) - return agg.AggregateCapabilities(ctx, backends) - }). - AnyTimes() - mockDiscoveryMgr.EXPECT().Stop().AnyTimes() + rt := router.NewSessionRouter(&vmcp.RoutingTable{}) // Helper function to read audit log file readAuditLog := func() string { @@ -548,7 +504,7 @@ func TestIntegration_AuditLogging(t *testing.T) { AuditConfig: auditConfig, SessionFactory: auditSessionFactory, Aggregator: auditAgg, - }, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry(backends), nil) + }, rt, mockBackendClient, vmcp.NewImmutableRegistry(backends), nil) require.NoError(t, err) // Start server @@ -771,27 +727,11 @@ func TestIntegration_AuditLoggingWithAuth(t *testing.T) { mockBackendClient := mocks.NewMockBackendClient(ctrl) // Create mock discovery manager - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) - mockDiscoveryMgr.EXPECT(). - Discover(gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, _ []vmcp.Backend) (*aggregator.AggregatedCapabilities, error) { - return &aggregator.AggregatedCapabilities{ - Tools: []vmcp.Tool{ - { - Name: "test_tool", - Description: "A test tool", - InputSchema: map[string]any{"type": "object"}, - }, - }, - }, nil - }). - AnyTimes() - mockDiscoveryMgr.EXPECT().Stop().AnyTimes() backends := []vmcp.Backend{} // Create router - rt := router.NewDefaultRouter() + rt := router.NewSessionRouter(&vmcp.RoutingTable{}) // Create identity middleware for auth identityMiddleware := func(next http.Handler) http.Handler { @@ -835,7 +775,7 @@ func TestIntegration_AuditLoggingWithAuth(t *testing.T) { AuthMiddleware: identityMiddleware, SessionFactory: newNoopMockFactory(t), Aggregator: newStubAggregator(nil), - }, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry(backends), nil) + }, rt, mockBackendClient, vmcp.NewImmutableRegistry(backends), nil) require.NoError(t, err) // Start server diff --git a/pkg/vmcp/server/readiness_test.go b/pkg/vmcp/server/readiness_test.go index badb4c6ed6..cbfc0227f4 100644 --- a/pkg/vmcp/server/readiness_test.go +++ b/pkg/vmcp/server/readiness_test.go @@ -16,8 +16,6 @@ import ( "github.com/stacklok/toolhive/pkg/networking" "github.com/stacklok/toolhive/pkg/vmcp" - "github.com/stacklok/toolhive/pkg/vmcp/aggregator" - discoveryMocks "github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks" "github.com/stacklok/toolhive/pkg/vmcp/mocks" "github.com/stacklok/toolhive/pkg/vmcp/router" "github.com/stacklok/toolhive/pkg/vmcp/server" @@ -38,28 +36,11 @@ func TestReadinessEndpoint_StaticMode(t *testing.T) { t.Cleanup(ctrl.Finish) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) - rt := router.NewDefaultRouter() + rt := router.NewSessionRouter(&vmcp.RoutingTable{}) port := networking.FindAvailable() require.NotZero(t, port, "Failed to find available port") - mockDiscoveryMgr.EXPECT(). - Discover(gomock.Any(), gomock.Any()). - Return(&aggregator.AggregatedCapabilities{ - Tools: []vmcp.Tool{}, - Resources: []vmcp.Resource{}, - Prompts: []vmcp.Prompt{}, - RoutingTable: &vmcp.RoutingTable{ - Tools: make(map[string]*vmcp.BackendTarget), - Resources: make(map[string]*vmcp.BackendTarget), - Prompts: make(map[string]*vmcp.BackendTarget), - }, - Metadata: &aggregator.AggregationMetadata{}, - }, nil). - AnyTimes() - mockDiscoveryMgr.EXPECT().Stop().AnyTimes() - ctx, cancel := context.WithCancel(t.Context()) // Create server without Watcher (static mode) @@ -70,7 +51,7 @@ func TestReadinessEndpoint_StaticMode(t *testing.T) { Port: port, Watcher: nil, // Static mode SessionFactory: newNoopMockFactory(t), Aggregator: newStubAggregator(nil), - }, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry([]vmcp.Backend{}), nil) + }, rt, mockBackendClient, vmcp.NewImmutableRegistry([]vmcp.Backend{}), nil) require.NoError(t, err) t.Cleanup(cancel) @@ -111,28 +92,11 @@ func TestReadinessEndpoint_DynamicMode_CacheSynced(t *testing.T) { t.Cleanup(ctrl.Finish) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) - rt := router.NewDefaultRouter() + rt := router.NewSessionRouter(&vmcp.RoutingTable{}) port := networking.FindAvailable() require.NotZero(t, port, "Failed to find available port") - mockDiscoveryMgr.EXPECT(). - Discover(gomock.Any(), gomock.Any()). - Return(&aggregator.AggregatedCapabilities{ - Tools: []vmcp.Tool{}, - Resources: []vmcp.Resource{}, - Prompts: []vmcp.Prompt{}, - RoutingTable: &vmcp.RoutingTable{ - Tools: make(map[string]*vmcp.BackendTarget), - Resources: make(map[string]*vmcp.BackendTarget), - Prompts: make(map[string]*vmcp.BackendTarget), - }, - Metadata: &aggregator.AggregationMetadata{}, - }, nil). - AnyTimes() - mockDiscoveryMgr.EXPECT().Stop().AnyTimes() - ctx, cancel := context.WithCancel(t.Context()) // Create mock watcher with cache synced @@ -146,7 +110,7 @@ func TestReadinessEndpoint_DynamicMode_CacheSynced(t *testing.T) { Port: port, Watcher: mockWatcher, // Dynamic mode with synced cache SessionFactory: newNoopMockFactory(t), Aggregator: newStubAggregator(nil), - }, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewDynamicRegistry([]vmcp.Backend{}), nil) + }, rt, mockBackendClient, vmcp.NewDynamicRegistry([]vmcp.Backend{}), nil) require.NoError(t, err) t.Cleanup(cancel) @@ -187,28 +151,11 @@ func TestReadinessEndpoint_DynamicMode_CacheNotSynced(t *testing.T) { t.Cleanup(ctrl.Finish) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) - rt := router.NewDefaultRouter() + rt := router.NewSessionRouter(&vmcp.RoutingTable{}) port := networking.FindAvailable() require.NotZero(t, port, "Failed to find available port") - mockDiscoveryMgr.EXPECT(). - Discover(gomock.Any(), gomock.Any()). - Return(&aggregator.AggregatedCapabilities{ - Tools: []vmcp.Tool{}, - Resources: []vmcp.Resource{}, - Prompts: []vmcp.Prompt{}, - RoutingTable: &vmcp.RoutingTable{ - Tools: make(map[string]*vmcp.BackendTarget), - Resources: make(map[string]*vmcp.BackendTarget), - Prompts: make(map[string]*vmcp.BackendTarget), - }, - Metadata: &aggregator.AggregationMetadata{}, - }, nil). - AnyTimes() - mockDiscoveryMgr.EXPECT().Stop().AnyTimes() - ctx, cancel := context.WithCancel(t.Context()) // Create mock watcher with cache NOT synced @@ -222,7 +169,7 @@ func TestReadinessEndpoint_DynamicMode_CacheNotSynced(t *testing.T) { Port: port, Watcher: mockWatcher, // Dynamic mode with unsynced cache SessionFactory: newNoopMockFactory(t), Aggregator: newStubAggregator(nil), - }, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewDynamicRegistry([]vmcp.Backend{}), nil) + }, rt, mockBackendClient, vmcp.NewDynamicRegistry([]vmcp.Backend{}), nil) require.NoError(t, err) t.Cleanup(cancel) diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index 392db84541..48dde86070 100644 --- a/pkg/vmcp/server/server.go +++ b/pkg/vmcp/server/server.go @@ -40,7 +40,6 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp/composer" vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config" "github.com/stacklok/toolhive/pkg/vmcp/core" - "github.com/stacklok/toolhive/pkg/vmcp/discovery" "github.com/stacklok/toolhive/pkg/vmcp/headerforward" "github.com/stacklok/toolhive/pkg/vmcp/health" "github.com/stacklok/toolhive/pkg/vmcp/optimizer" @@ -247,20 +246,17 @@ type Config struct { type Server struct { config *Config - // core is the domain VMCP, set only on the Serve path (nil on the legacy - // server.New path). When non-nil it is the single source of truth for the - // advertised capability set and call routing: session registration sources + // core is the domain VMCP and the single source of truth for the advertised + // capability set and call routing: session registration sources // tools/resources from core.ListTools/ListResources, request handlers - // delegate to core.CallTool/ReadResource, and the discovery middleware + - // context-based audit enrichment are guarded off (the core applies the - // admission filter the legacy authz/discovery path applied). The nil/non-nil - // value is the branch selector throughout this file ("s.core == nil" == legacy). + // delegate to core.CallTool/ReadResource, and authorization is enforced by + // the core admission seam. Serve always sets it, so it is non-nil for every + // server (BackendHealth keeps a defensive nil guard regardless). core core.VMCP // optimizerFactory builds a per-session optimizer over the core's advertised - // tools. Set only on the Serve path when the optimizer is enabled (nil otherwise, - // including the entire legacy server.New path, which decorates the session factory - // instead). When non-nil, Serve-path session registration advertises find_tool/ + // tools. Set only when the optimizer is enabled (nil otherwise). When non-nil, + // session registration advertises find_tool/ // call_tool in place of the raw core tools and dispatches call_tool's inner // invocation through core.CallTool. The shared store and cleanup are owned by the // session manager; this is the resolved factory surfaced via Manager.OptimizerFactory. @@ -276,11 +272,6 @@ type Server struct { listener net.Listener listenerMu sync.RWMutex - // Discovery manager — retained only so Stop() drives its (now no-op) cleanup. The core - // replaced discovery's aggregation on the New/Serve path, so New no longer wires the - // router / backend client / handler factory / capability adapter the legacy body held. - discoveryMgr discovery.Manager - // Backend registry for capability discovery // For static mode (CLI), this is an immutable registry created from initial backends. // For dynamic mode (K8s), this is a DynamicRegistry updated by the operator. @@ -358,16 +349,17 @@ func buildSessionDataStorage(ctx context.Context, cfg *Config) (transportsession // It is the composition root for the in-memory Config form: it builds the backend // health monitor (A2), assembles the core via core.New using the config projected by // deriveCoreConfig, and hands that core to Serve with the transport config projected by -// deriveServerConfig. The 7-param signature is retained unchanged for existing callers -// (cli/serve.go and external embedders); the transport/core wiring it once performed -// inline now lives behind core.New + Serve. +// deriveServerConfig. The transport/core wiring it once performed inline now lives +// behind core.New + Serve. // // The backendRegistry parameter provides the list of available backends: // - For static mode (CLI), pass an immutable registry created from initial backends // - For dynamic mode (K8s), pass a DynamicRegistry that will be updated by the operator // -// Runtime-contract change for embedders (the signature is unchanged, but the body now -// routes through core.New + Serve): +// Signature/contract changes for embedders (the body now routes through core.New + Serve): +// - The discovery.Manager parameter was dropped when the discovery middleware was +// removed; capability discovery is the core's responsibility. Existing callers must +// drop that argument. // - Config.Aggregator is now REQUIRED (core.New rejects a nil aggregator); the core is // the single source of the advertised capability set. // - Config.AuthzMiddleware is vestigial: the Serve path never applies it. Authorization @@ -380,7 +372,6 @@ func New( cfg *Config, rt router.Router, backendClient vmcp.BackendClient, - discoveryMgr discovery.Manager, backendRegistry vmcp.BackendRegistry, workflowDefs map[string]*composer.WorkflowDefinition, ) (*Server, error) { @@ -480,12 +471,6 @@ func New( return nil, err } - // Retain the discovery manager only so Stop() drives its (now no-op) cleanup for - // parity with the legacy path. The core replaced discovery's aggregation, so it is - // otherwise unused here: the shared Handler guards the discovery middleware to - // s.core == nil, which never holds for a Serve-built server. - srv.discoveryMgr = discoveryMgr - // Bind the elicitation adapter to the SDK server Serve built so composite-workflow // elicitation reaches the same mcp-go server that serves client traffic. elicitation.bind(NewSDKElicitationAdapter(srv.MCPServer())) @@ -498,8 +483,8 @@ func New( // This enables embedding the vmcp server inside another HTTP server or framework. // // The returned handler includes all routes (health, metrics, well-known, MCP) -// and the full middleware chain (recovery, header validation, auth, audit, -// discovery, backend enrichment, MCP parsing, telemetry). +// and the full middleware chain (recovery, body limit, header validation, auth, +// rate limit, audit, MCP parsing, telemetry). // // Each call builds a fresh handler. The method is safe to call multiple times. // All returned handlers share the same underlying MCPServer and SessionManager, @@ -551,16 +536,14 @@ func (s *Server) Handler(_ context.Context) (http.Handler, error) { } // MCP endpoint - apply middleware chain (wrapping order, execution happens in reverse): - // Code wraps: auth+parser → rate-limit → audit → discovery → backend-enrichment → - // MCP-parsing → telemetry - // Execution order: recovery → header-val → auth+parser → rate-limit → audit → - // discovery → backend-enrichment → MCP-parsing → telemetry → handler + // Code wraps: auth → rate-limit → audit → MCP-parsing → telemetry + // Execution order: recovery → body-limit → header-val → auth → rate-limit → + // audit → MCP-parsing → telemetry → handler // - // The legacy HTTP authz and annotation-enrichment layers have been removed: every caller - // now routes through Serve, so authorization is enforced by the core admission seam - // (#5438) rather than HTTP middleware. The remaining legacy-only blocks (backend - // enrichment, discovery) are guarded by s.core == nil and are removed in the 302b - // follow-up. + // The legacy HTTP authz, annotation-enrichment, and discovery layers have all been + // removed: every caller now routes through Serve, so authorization is enforced by the + // core admission seam (#5438) and capability/health filtering by the core, rather than + // by HTTP middleware. var mcpHandler http.Handler = streamableServer @@ -578,32 +561,6 @@ func (s *Server) Handler(_ context.Context) (http.Handler, error) { // when auth middleware is nil. mcpHandler = mcpparser.ParsingMiddleware(mcpHandler) - // Apply discovery middleware (runs after audit/auth middleware) — legacy path only. - // Discovery middleware performs per-request capability aggregation with user context, - // injecting the routing table into the request context (the discovery-into-context seam). - // vmcpSessionMgr (MultiSessionGetter) is used to retrieve the fully-formed MultiSession - // for subsequent requests so the routing table can be injected into context. - // The backend registry provides a dynamic backend list (supports DynamicRegistry for K8s). - // The health monitor enables filtering based on current health status (respects circuit breaker). - // - // Guarded to the legacy server.New path (s.core == nil). On the Serve path the core is - // the single source of truth: session registration aggregates once via core.ListTools and - // handlers route through the core (#5442), so discovery is skipped — applying it would - // also nil-deref, since a Serve-built server has a nil discoveryMgr. WithSessionScopedRouting's - // initialize-skip behavior is preserved here on the legacy path; physical removal of the - // middleware and the context seam is deferred to Phase 3 (#5445). - if s.core == nil { - // Dead on the Serve path: s.core is never nil after #5445 routes New through Serve, - // and the health monitor now lives in the core, so no StatusProvider is available - // here. Kept (with nil health) only until 302b removes the discovery middleware and - // its context seam (anti-pattern #1). - mcpHandler = discovery.Middleware( - s.discoveryMgr, s.backendRegistry, s.vmcpSessionMgr, nil, - discovery.WithSessionScopedRouting(), - )(mcpHandler) - slog.Info("discovery middleware enabled for lazy per-user capability discovery") - } - // Apply audit middleware if configured (runs after auth, before discovery) if s.config.AuditConfig != nil { if err := s.config.AuditConfig.Validate(); err != nil { @@ -817,11 +774,6 @@ func (s *Server) Stop(ctx context.Context) error { } } - // Stop discovery manager to clean up background goroutines - if s.discoveryMgr != nil { - s.discoveryMgr.Stop() - } - // Close session data storage last: HTTP server is down (no new in-flight requests), // all other components have stopped (no further restore or liveness checks). if s.sessionDataStorage != nil { @@ -1107,7 +1059,7 @@ func (s *Server) handleSessionRegistration( // 3. Registers backend tools, composite tools, and resources with the SDK for the session. // // Tool and resource calls are routed directly through the session's backend connections -// rather than through the global router and discovery middleware. +// rather than through a global router. // Composite tool executors use the shared backend client and router. // // # Current capability surface diff --git a/pkg/vmcp/server/server_test.go b/pkg/vmcp/server/server_test.go index e3497accfa..fe78b94edb 100644 --- a/pkg/vmcp/server/server_test.go +++ b/pkg/vmcp/server/server_test.go @@ -21,7 +21,6 @@ import ( "github.com/stacklok/toolhive/pkg/authz/authorizers/cedar" mcpparser "github.com/stacklok/toolhive/pkg/mcp" "github.com/stacklok/toolhive/pkg/vmcp" - discoveryMocks "github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks" "github.com/stacklok/toolhive/pkg/vmcp/mocks" "github.com/stacklok/toolhive/pkg/vmcp/optimizer" routerMocks "github.com/stacklok/toolhive/pkg/vmcp/router/mocks" @@ -65,7 +64,6 @@ func TestServerStartFailsWhenReporterStartFails(t *testing.T) { t.Cleanup(ctrl.Finish) mockRouter := routerMocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) mockBackendRegistry := mocks.NewMockBackendRegistry(ctrl) srv, err := server.New( @@ -73,7 +71,6 @@ func TestServerStartFailsWhenReporterStartFails(t *testing.T) { &server.Config{Host: "127.0.0.1", Port: 0, StatusReporter: sr, SessionFactory: newNoopMockFactory(t), Aggregator: newStubAggregator(nil)}, mockRouter, mockBackendClient, - mockDiscoveryMgr, mockBackendRegistry, nil, ) @@ -94,16 +91,13 @@ func TestServerStopRunsReporterShutdown(t *testing.T) { t.Cleanup(ctrl.Finish) mockRouter := routerMocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) mockBackendRegistry := mocks.NewMockBackendRegistry(ctrl) - mockDiscoveryMgr.EXPECT().Stop().Times(1) srv, err := server.New( context.Background(), &server.Config{Host: "127.0.0.1", Port: 0, StatusReporter: sr, SessionFactory: newNoopMockFactory(t), Aggregator: newStubAggregator(nil)}, mockRouter, mockBackendClient, - mockDiscoveryMgr, mockBackendRegistry, nil, ) @@ -202,9 +196,8 @@ func TestNew(t *testing.T) { mockRouter := routerMocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) - s, err := server.New(context.Background(), tt.config, mockRouter, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry([]vmcp.Backend{}), nil) + s, err := server.New(context.Background(), tt.config, mockRouter, mockBackendClient, vmcp.NewImmutableRegistry([]vmcp.Backend{}), nil) require.NoError(t, err) require.NotNil(t, s) @@ -339,9 +332,8 @@ func TestServer_Address(t *testing.T) { mockRouter := routerMocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) - s, err := server.New(context.Background(), tt.config, mockRouter, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry([]vmcp.Backend{}), nil) + s, err := server.New(context.Background(), tt.config, mockRouter, mockBackendClient, vmcp.NewImmutableRegistry([]vmcp.Backend{}), nil) require.NoError(t, err) addr := s.Address() assert.Equal(t, tt.expected, addr) @@ -360,10 +352,8 @@ func TestServer_Stop(t *testing.T) { mockRouter := routerMocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) - mockDiscoveryMgr.EXPECT().Stop().Times(1) - s, err := server.New(context.Background(), &server.Config{SessionFactory: newNoopMockFactory(t), Aggregator: newStubAggregator(nil)}, mockRouter, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry([]vmcp.Backend{}), nil) + s, err := server.New(context.Background(), &server.Config{SessionFactory: newNoopMockFactory(t), Aggregator: newStubAggregator(nil)}, mockRouter, mockBackendClient, vmcp.NewImmutableRegistry([]vmcp.Backend{}), nil) require.NoError(t, err) err = s.Stop(context.Background()) require.NoError(t, err) @@ -378,7 +368,6 @@ func TestNew_NilSessionFactory_ReturnsError(t *testing.T) { mockRouter := routerMocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) _, err := server.New( context.Background(), @@ -386,7 +375,7 @@ func TestNew_NilSessionFactory_ReturnsError(t *testing.T) { SessionFactory: nil, // deliberately omitted Aggregator: newStubAggregator(nil), }, - mockRouter, mockBackendClient, mockDiscoveryMgr, + mockRouter, mockBackendClient, vmcp.NewImmutableRegistry([]vmcp.Backend{}), nil, ) require.Error(t, err) @@ -404,7 +393,6 @@ func TestNew_NilAggregator_ReturnsError(t *testing.T) { mockRouter := routerMocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) _, err := server.New( context.Background(), @@ -412,7 +400,7 @@ func TestNew_NilAggregator_ReturnsError(t *testing.T) { SessionFactory: newNoopMockFactory(t), Aggregator: nil, // deliberately omitted: now a required field }, - mockRouter, mockBackendClient, mockDiscoveryMgr, + mockRouter, mockBackendClient, vmcp.NewImmutableRegistry([]vmcp.Backend{}), nil, ) require.Error(t, err) @@ -488,14 +476,13 @@ func TestNew_WithAuditConfig(t *testing.T) { mockRouter := routerMocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) config := &server.Config{ AuditConfig: tt.auditConfig, SessionFactory: newNoopMockFactory(t), Aggregator: newStubAggregator(nil), } - s, err := server.New(context.Background(), config, mockRouter, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry([]vmcp.Backend{}), nil) + s, err := server.New(context.Background(), config, mockRouter, mockBackendClient, vmcp.NewImmutableRegistry([]vmcp.Backend{}), nil) if tt.wantErr { require.Error(t, err) @@ -518,17 +505,13 @@ func TestServerStopClosesOptimizerStore(t *testing.T) { t.Cleanup(ctrl.Finish) mockRouter := routerMocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) mockBackendRegistry := mocks.NewMockBackendRegistry(ctrl) - mockDiscoveryMgr.EXPECT().Stop().Times(1) - srv, err := server.New( context.Background(), &server.Config{Host: "127.0.0.1", Port: 0, OptimizerConfig: &optimizer.Config{}, SessionFactory: newNoopMockFactory(t), Aggregator: newStubAggregator(nil)}, mockRouter, mockBackendClient, - mockDiscoveryMgr, mockBackendRegistry, nil, ) @@ -569,19 +552,16 @@ func TestHandler_ReturnsNonNilHandler(t *testing.T) { mockRouter := routerMocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) mockBackendRegistry := mocks.NewMockBackendRegistry(ctrl) // Allow discovery middleware calls mockBackendRegistry.EXPECT().List(gomock.Any()).Return(nil).AnyTimes() - mockDiscoveryMgr.EXPECT().Discover(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() srv, err := server.New( t.Context(), &server.Config{Host: "127.0.0.1", Port: 0, SessionFactory: newNoopMockFactory(t), Aggregator: newStubAggregator(nil)}, mockRouter, mockBackendClient, - mockDiscoveryMgr, mockBackendRegistry, nil, ) @@ -607,7 +587,6 @@ func TestHandler_ReturnsErrorOnInvalidAuditConfig(t *testing.T) { mockRouter := routerMocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) mockBackendRegistry := mocks.NewMockBackendRegistry(ctrl) // AuditConfig with negative MaxDataSize fails validation inside Handler() @@ -624,7 +603,6 @@ func TestHandler_ReturnsErrorOnInvalidAuditConfig(t *testing.T) { }, mockRouter, mockBackendClient, - mockDiscoveryMgr, mockBackendRegistry, nil, ) @@ -648,18 +626,15 @@ func TestHandler_CanBeCalledMultipleTimes(t *testing.T) { mockRouter := routerMocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) mockBackendRegistry := mocks.NewMockBackendRegistry(ctrl) mockBackendRegistry.EXPECT().List(gomock.Any()).Return(nil).AnyTimes() - mockDiscoveryMgr.EXPECT().Discover(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() srv, err := server.New( t.Context(), &server.Config{Host: "127.0.0.1", Port: 0, SessionFactory: newNoopMockFactory(t), Aggregator: newStubAggregator(nil)}, mockRouter, mockBackendClient, - mockDiscoveryMgr, mockBackendRegistry, nil, ) @@ -690,11 +665,9 @@ func TestHandler_RegistersWellKnownRoutes(t *testing.T) { mockRouter := routerMocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) mockBackendRegistry := mocks.NewMockBackendRegistry(ctrl) mockBackendRegistry.EXPECT().List(gomock.Any()).Return(nil).AnyTimes() - mockDiscoveryMgr.EXPECT().Discover(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() // Stub AuthInfoHandler that responds with a fixed JSON body. authInfoHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -718,7 +691,6 @@ func TestHandler_RegistersWellKnownRoutes(t *testing.T) { }, mockRouter, mockBackendClient, - mockDiscoveryMgr, mockBackendRegistry, nil, ) @@ -815,18 +787,15 @@ func TestAcceptHeaderValidation(t *testing.T) { mockRouter := routerMocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) mockBackendRegistry := mocks.NewMockBackendRegistry(ctrl) mockBackendRegistry.EXPECT().List(gomock.Any()).Return(nil).AnyTimes() - mockDiscoveryMgr.EXPECT().Discover(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() srv, err := server.New( t.Context(), &server.Config{Host: "127.0.0.1", Port: 0, SessionFactory: newNoopMockFactory(t), Aggregator: newStubAggregator(nil)}, mockRouter, mockBackendClient, - mockDiscoveryMgr, mockBackendRegistry, nil, ) @@ -914,7 +883,6 @@ func TestNew_AuthzMiddlewareWithoutAuthz_ReturnsError(t *testing.T) { t.Cleanup(ctrl.Finish) mockRouter := routerMocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) _, err := server.New(t.Context(), &server.Config{ @@ -922,7 +890,7 @@ func TestNew_AuthzMiddlewareWithoutAuthz_ReturnsError(t *testing.T) { Aggregator: newStubAggregator(nil), AuthzMiddleware: func(h http.Handler) http.Handler { return h }, // set without Authz }, - mockRouter, mockBackendClient, mockDiscoveryMgr, + mockRouter, mockBackendClient, vmcp.NewImmutableRegistry([]vmcp.Backend{}), nil, ) require.Error(t, err) @@ -940,7 +908,6 @@ func TestNew_AuthzWithOptimizer_ReturnsError(t *testing.T) { t.Cleanup(ctrl.Finish) mockRouter := routerMocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) _, err := server.New(t.Context(), &server.Config{ @@ -950,7 +917,7 @@ func TestNew_AuthzWithOptimizer_ReturnsError(t *testing.T) { Authz: newTestAuthzConfig(t), OptimizerConfig: &optimizer.Config{}, }, - mockRouter, mockBackendClient, mockDiscoveryMgr, + mockRouter, mockBackendClient, vmcp.NewImmutableRegistry([]vmcp.Backend{}), nil, ) require.Error(t, err) @@ -969,7 +936,6 @@ func TestNew_AuthzWithoutName_ReturnsError(t *testing.T) { t.Cleanup(ctrl.Finish) mockRouter := routerMocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) _, err := server.New(t.Context(), &server.Config{ @@ -978,7 +944,7 @@ func TestNew_AuthzWithoutName_ReturnsError(t *testing.T) { Aggregator: newStubAggregator(nil), Authz: newTestAuthzConfig(t), }, - mockRouter, mockBackendClient, mockDiscoveryMgr, + mockRouter, mockBackendClient, vmcp.NewImmutableRegistry([]vmcp.Backend{}), nil, ) require.Error(t, err) @@ -1005,11 +971,8 @@ func TestNewIgnoresVestigialAuthzMiddleware(t *testing.T) { t.Cleanup(ctrl.Finish) mockRouter := routerMocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) mockBackendRegistry := mocks.NewMockBackendRegistry(ctrl) mockBackendRegistry.EXPECT().List(gomock.Any()).Return(nil).AnyTimes() - mockDiscoveryMgr.EXPECT().Discover(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() - mockDiscoveryMgr.EXPECT().Stop().AnyTimes() // If this authz middleware is ever applied it records the fact and short-circuits // with the sentinel status. @@ -1033,7 +996,7 @@ func TestNewIgnoresVestigialAuthzMiddleware(t *testing.T) { AuthzMiddleware: authz, Authz: newTestAuthzConfig(t), } - srv, err := server.New(t.Context(), cfg, mockRouter, mockBackendClient, mockDiscoveryMgr, mockBackendRegistry, nil) + srv, err := server.New(t.Context(), cfg, mockRouter, mockBackendClient, mockBackendRegistry, nil) require.NoError(t, err) t.Cleanup(func() { _ = srv.Stop(context.Background()) }) diff --git a/pkg/vmcp/server/session_management_integration_test.go b/pkg/vmcp/server/session_management_integration_test.go index 836294078c..96f247f652 100644 --- a/pkg/vmcp/server/session_management_integration_test.go +++ b/pkg/vmcp/server/session_management_integration_test.go @@ -26,7 +26,6 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/aggregator" "github.com/stacklok/toolhive/pkg/vmcp/composer" - discoveryMocks "github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks" "github.com/stacklok/toolhive/pkg/vmcp/mocks" "github.com/stacklok/toolhive/pkg/vmcp/optimizer" "github.com/stacklok/toolhive/pkg/vmcp/router" @@ -204,17 +203,11 @@ func buildTestServerWithOptions( t.Cleanup(ctrl.Finish) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) mockBackendRegistry := mocks.NewMockBackendRegistry(ctrl) - // List() is consumed by the core's on-demand aggregation; the discovery middleware - // is guarded off on the Serve path (s.core != nil), so Discover() is no longer called - // (the AnyTimes expectations tolerate zero calls). Return an empty (non-nil) result. - emptyAggCaps := &aggregator.AggregatedCapabilities{} + // List() is consumed by the core's on-demand aggregation. mockBackendRegistry.EXPECT().List(gomock.Any()).Return(nil).AnyTimes() - mockDiscoveryMgr.EXPECT().Discover(gomock.Any(), gomock.Any()).Return(emptyAggCaps, nil).AnyTimes() // Stop is called when the server is stopped (not via httptest but via session manager cleanup). - mockDiscoveryMgr.EXPECT().Stop().AnyTimes() // tools/call routes through core.CallTool → BackendClient.CallTool (the session factory's // own CallTool is bypassed on the Serve path). Return a deterministic result so call // tests can assert on it. @@ -223,7 +216,7 @@ func buildTestServerWithOptions( Return(&vmcp.ToolCallResult{Content: []vmcp.Content{{Type: "text", Text: "fake result"}}}, nil). AnyTimes() - rt := router.NewDefaultRouter() + rt := router.NewSessionRouter(&vmcp.RoutingTable{}) srv, err := server.New( context.Background(), @@ -237,7 +230,6 @@ func buildTestServerWithOptions( }, rt, mockBackendClient, - mockDiscoveryMgr, mockBackendRegistry, opts.workflowDefs, ) diff --git a/pkg/vmcp/server/session_management_realbackend_integration_test.go b/pkg/vmcp/server/session_management_realbackend_integration_test.go index c5ec5b9afd..f05a7907fa 100644 --- a/pkg/vmcp/server/session_management_realbackend_integration_test.go +++ b/pkg/vmcp/server/session_management_realbackend_integration_test.go @@ -23,7 +23,6 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp/auth/strategies" authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" vmcpclient "github.com/stacklok/toolhive/pkg/vmcp/client" - discoveryMocks "github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks" "github.com/stacklok/toolhive/pkg/vmcp/mocks" "github.com/stacklok/toolhive/pkg/vmcp/router" "github.com/stacklok/toolhive/pkg/vmcp/server" @@ -45,7 +44,6 @@ func newRealTestHandler(t *testing.T, backendURL string) http.Handler { ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) mockBackendRegistry := mocks.NewMockBackendRegistry(ctrl) backend := vmcp.Backend{ @@ -61,9 +59,6 @@ func newRealTestHandler(t *testing.T, backendURL string) http.Handler { // off on the Serve path); the AnyTimes expectation tolerates zero calls. mockBackendRegistry.EXPECT().List(gomock.Any()).Return([]vmcp.Backend{backend}).AnyTimes() mockBackendRegistry.EXPECT().Get(gomock.Any(), gomock.Any()).Return(&backend).AnyTimes() - mockDiscoveryMgr.EXPECT().Discover(gomock.Any(), gomock.Any()). - Return(&aggregator.AggregatedCapabilities{}, nil).AnyTimes() - mockDiscoveryMgr.EXPECT().Stop().AnyTimes() authReg := vmcpauth.NewDefaultOutgoingAuthRegistry() require.NoError(t, authReg.RegisterStrategy( @@ -82,7 +77,7 @@ func newRealTestHandler(t *testing.T, backendURL string) http.Handler { require.NoError(t, err) agg := aggregator.NewDefaultAggregator(backendClient, resolver, nil, nil) - rt := router.NewDefaultRouter() + rt := router.NewSessionRouter(&vmcp.RoutingTable{}) srv, err := server.New( context.Background(), &server.Config{ @@ -94,7 +89,6 @@ func newRealTestHandler(t *testing.T, backendURL string) http.Handler { }, rt, backendClient, - mockDiscoveryMgr, mockBackendRegistry, nil, ) diff --git a/pkg/vmcp/server/status_test.go b/pkg/vmcp/server/status_test.go index 92db210ba3..4a14d655ce 100644 --- a/pkg/vmcp/server/status_test.go +++ b/pkg/vmcp/server/status_test.go @@ -17,9 +17,7 @@ import ( "github.com/stacklok/toolhive/pkg/networking" "github.com/stacklok/toolhive/pkg/vmcp" - "github.com/stacklok/toolhive/pkg/vmcp/aggregator" authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" - discoveryMocks "github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks" "github.com/stacklok/toolhive/pkg/vmcp/health" "github.com/stacklok/toolhive/pkg/vmcp/mocks" "github.com/stacklok/toolhive/pkg/vmcp/router" @@ -212,8 +210,7 @@ func createTestServerWithHealthMonitor( t.Cleanup(ctrl.Finish) mockBackendClient := mocks.NewMockBackendClient(ctrl) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) - rt := router.NewDefaultRouter() + rt := router.NewSessionRouter(&vmcp.RoutingTable{}) if setupMock != nil { setupMock(mockBackendClient) @@ -222,22 +219,6 @@ func createTestServerWithHealthMonitor( port := networking.FindAvailable() require.NotZero(t, port, "Failed to find available port") - mockDiscoveryMgr.EXPECT(). - Discover(gomock.Any(), gomock.Any()). - Return(&aggregator.AggregatedCapabilities{ - Tools: []vmcp.Tool{}, - Resources: []vmcp.Resource{}, - Prompts: []vmcp.Prompt{}, - RoutingTable: &vmcp.RoutingTable{ - Tools: make(map[string]*vmcp.BackendTarget), - Resources: make(map[string]*vmcp.BackendTarget), - Prompts: make(map[string]*vmcp.BackendTarget), - }, - Metadata: &aggregator.AggregationMetadata{}, - }, nil). - AnyTimes() - mockDiscoveryMgr.EXPECT().Stop().AnyTimes() - ctx, cancel := context.WithCancel(t.Context()) var healthMonCfg *health.MonitorConfig @@ -252,7 +233,7 @@ func createTestServerWithHealthMonitor( GroupRef: groupRef, HealthMonitorConfig: healthMonCfg, SessionFactory: newNoopMockFactory(t), Aggregator: newStubAggregator(nil), - }, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry(backends), nil) + }, rt, mockBackendClient, vmcp.NewImmutableRegistry(backends), nil) require.NoError(t, err) type startResult struct { diff --git a/pkg/vmcp/server/telemetry_integration_test.go b/pkg/vmcp/server/telemetry_integration_test.go index 30b00263ef..62ad095cbb 100644 --- a/pkg/vmcp/server/telemetry_integration_test.go +++ b/pkg/vmcp/server/telemetry_integration_test.go @@ -22,7 +22,6 @@ import ( transportsession "github.com/stacklok/toolhive/pkg/transport/session" "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/aggregator" - discoveryMocks "github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks" "github.com/stacklok/toolhive/pkg/vmcp/mocks" "github.com/stacklok/toolhive/pkg/vmcp/router" vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" @@ -192,19 +191,9 @@ func TestIntegration_TelemetryMiddleware(t *testing.T) { } // Create discovery manager (follows same pattern as TestIntegration_AuditLogging) - mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) - mockDiscoveryMgr.EXPECT(). - Discover(gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, _ []vmcp.Backend) (*aggregator.AggregatedCapabilities, error) { - resolver := aggregator.NewPrefixConflictResolver("{workload}_") - agg := aggregator.NewDefaultAggregator(mockBackendClient, resolver, nil, nil) - return agg.AggregateCapabilities(ctx, backends) - }). - AnyTimes() - mockDiscoveryMgr.EXPECT().Stop().AnyTimes() // Create router - rt := router.NewDefaultRouter() + rt := router.NewSessionRouter(&vmcp.RoutingTable{}) // Build the tools and routing table. The aggregator prefixes tool names with // "{workload}_", so "search" becomes "search-svc_search". @@ -246,7 +235,7 @@ func TestIntegration_TelemetryMiddleware(t *testing.T) { TelemetryProvider: telemetryProvider, SessionFactory: telemetryFactory, Aggregator: telemetryAgg, - }, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry(backends), nil) + }, rt, mockBackendClient, vmcp.NewImmutableRegistry(backends), nil) require.NoError(t, err) // Start server diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_circuit_breaker_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_circuit_breaker_test.go index 7086ac9f90..0d1455fc9c 100644 --- a/test/e2e/thv-operator/virtualmcp/virtualmcp_circuit_breaker_test.go +++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_circuit_breaker_test.go @@ -340,25 +340,25 @@ var _ = Describe("VirtualMCPServer Circuit Breaker Lifecycle", Ordered, func() { return nil }, timeout, pollingInterval).Should(Succeed()) - By("Note: Tools from unhealthy backends excluded by discovery middleware") + By("Note: Tools from unhealthy backends excluded during core capability aggregation") // NOTE: This e2e test verifies the circuit breaker state changes (above assertions). - // The capability filtering itself is thoroughly unit tested in the discovery middleware. + // The capability filtering itself is thoroughly unit tested in the vMCP core. // // Full end-to-end verification of tools/list filtering would require: // 1. Making an HTTP request to the vMCP server // 2. Implementing MCP protocol initialize handshake // 3. Calling tools/list and parsing the response // - // The filtering logic is implemented in pkg/vmcp/discovery/middleware.go:filterHealthyBackends() - // and covered by unit tests in middleware_test.go (TestFilterHealthyBackends, - // TestFilterHealthyBackends_WithHealthMonitor). + // The filtering logic is implemented in pkg/vmcp/core/core_vmcp.go:filterHealthyBackends() + // and covered by unit tests in core_vmcp_test.go (TestFilterHealthyBackends, + // TestFilterHealthyBackends_Empty). // // How it works: // - When backend circuit breaker opens → health monitor marks backend unhealthy - // - Discovery middleware queries health monitor via StatusProvider interface - // - handleInitializeRequest filters unhealthy backends before aggregation + // - The core queries the health monitor via the StatusProvider interface + // - filterHealthyBackends excludes unhealthy backends before capability aggregation // - Only healthy/degraded backends' tools appear in tools/list response - GinkgoWriter.Printf("ℹ️ Backend health filtering is unit tested in pkg/vmcp/discovery/middleware_test.go\n") + GinkgoWriter.Printf("ℹ️ Backend health filtering is unit tested in pkg/vmcp/core/core_vmcp_test.go\n") GinkgoWriter.Printf(" Circuit breaker state verified above; capability filtering covered by unit tests\n") }) diff --git a/test/integration/vmcp/helpers/vmcp_server.go b/test/integration/vmcp/helpers/vmcp_server.go index a32c1d47b0..aecb2f4cad 100644 --- a/test/integration/vmcp/helpers/vmcp_server.go +++ b/test/integration/vmcp/helpers/vmcp_server.go @@ -163,7 +163,7 @@ func NewVMCPServer( } agg := aggregator.NewDefaultAggregator(backendClient, conflictResolver, nil, nil) - rtr := router.NewDefaultRouter() + rtr := router.NewSessionRouter(&vmcptypes.RoutingTable{}) backendRegistry := vmcptypes.NewImmutableRegistry(backends) // Build the core VMCP — the single authoritative aggregation on the Serve path.