Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions pkg/mcp/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package mcp

// RequestError is implemented by domain errors that should fail the MCP request
// instead of being converted to a successful tool result with IsError=true.
//
// The mcp-go tool-handler seam maps returned errors to JSON-RPC INTERNAL_ERROR,
// so this is a control-flow marker, not a custom JSON-RPC code hook.
type RequestError interface {
	// error is embedded so every RequestError is usable wherever a plain
	// error is expected.
	error
	// MCPRequestError is a marker method: it takes no arguments and returns
	// nothing, and exists only so a type can opt in to this interface.
	MCPRequestError()
}
34 changes: 34 additions & 0 deletions pkg/ratelimit/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package ratelimit

import (
"time"

thvmcp "github.com/stacklok/toolhive/pkg/mcp"
)

const (
	// CodeRateLimited is the JSON-RPC error code for rate-limited requests.
	// Per RFC THV-0057: implementation-defined code in the -32000 to -32099 range
	// (the band JSON-RPC 2.0 reserves for implementation-defined server errors).
	CodeRateLimited int64 = -32029

	// MessageRateLimited is the error message for rate-limited requests.
	// It is also the text returned by (*RateLimitedError).Error.
	MessageRateLimited = "Rate limit exceeded"
)

// RateLimitedError reports that a request exceeded its configured rate limit.
type RateLimitedError struct {
	// RetryAfter is the duration attached to the denial; presumably how long
	// the caller should wait before retrying (confirm against the limiter
	// that produces the decision).
	RetryAfter time.Duration
}

// Compile-time assertion that *RateLimitedError satisfies thvmcp.RequestError.
var _ thvmcp.RequestError = (*RateLimitedError)(nil)

// Error implements the error interface. It always returns the fixed
// MessageRateLimited text; the RetryAfter value is not included in the message.
func (*RateLimitedError) Error() string {
	return MessageRateLimited
}

// MCPRequestError marks rate-limit denials as request-level failures rather
// than tool execution errors. The method is intentionally empty: it exists
// only so *RateLimitedError satisfies thvmcp.RequestError.
func (*RateLimitedError) MCPRequestError() {}
25 changes: 25 additions & 0 deletions pkg/ratelimit/limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/redis/go-redis/v9"

v1beta1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1beta1"
"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/ratelimit/internal/bucket"
)

Expand All @@ -36,6 +37,30 @@ type Decision struct {
RetryAfter time.Duration
}

// Allow asks limiter whether identity may invoke toolName.
//
// It returns nil when the call is permitted or when limiter is nil (rate
// limiting disabled), a *RateLimitedError carrying the decision's RetryAfter
// when the call is denied, and the limiter's own error unchanged when the
// check itself fails.
func Allow(ctx context.Context, limiter Limiter, identity *auth.Identity, toolName string) error {
	if limiter == nil {
		return nil
	}

	// An unauthenticated call has no identity, so the subject stays empty and
	// per-user buckets are skipped — only shared limits apply. CEL validation
	// ensures perUser rate limits require auth to be enabled.
	subject := ""
	if identity != nil {
		subject = identity.Subject
	}

	verdict, err := limiter.Allow(ctx, toolName, subject)
	switch {
	case err != nil:
		return err
	case !verdict.Allowed:
		return &RateLimitedError{RetryAfter: verdict.RetryAfter}
	}
	return nil
}

// NewLimiter constructs a Limiter from CRD configuration.
// Returns a no-op limiter (always allows) when crd is nil.
// namespace and name identify the MCP server for Redis key derivation.
Expand Down
63 changes: 63 additions & 0 deletions pkg/ratelimit/limiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package ratelimit

import (
"errors"
"testing"
"time"

Expand All @@ -14,6 +15,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

v1beta1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1beta1"
"github.com/stacklok/toolhive/pkg/auth"
)

func newTestClient(t *testing.T) (*redis.Client, *miniredis.Miniredis) {
Expand All @@ -38,6 +40,67 @@ func TestNewLimiter_NilCRDReturnsNoop(t *testing.T) {
assert.True(t, d.Allowed)
}

func TestAllowNilLimiterAllows(t *testing.T) {
t.Parallel()

err := Allow(t.Context(), nil, &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "alice"}}, "echo")

require.NoError(t, err)
}

// TestAllowPassesIdentitySubject verifies that Allow forwards the tool name
// and the identity's Subject to the limiter. recordingLimiter is defined
// elsewhere in this package; presumably it records the arguments of its Allow
// call in the toolName and userID fields read below.
func TestAllowPassesIdentitySubject(t *testing.T) {
	t.Parallel()

	rec := &recordingLimiter{}
	caller := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "alice"}}

	require.NoError(t, Allow(t.Context(), rec, caller, "echo"))

	assert.Equal(t, "echo", rec.toolName)
	assert.Equal(t, "alice", rec.userID)
}

// TestAllowNilIdentityUsesEmptyUserID verifies that a nil identity results in
// an empty user ID being handed to the limiter while the tool name is still
// forwarded.
func TestAllowNilIdentityUsesEmptyUserID(t *testing.T) {
	t.Parallel()

	rec := &recordingLimiter{}

	require.NoError(t, Allow(t.Context(), rec, nil, "echo"))

	assert.Equal(t, "echo", rec.toolName)
	assert.Empty(t, rec.userID)
}

// TestAllowRateLimitedReturnsTypedError verifies that a denying decision is
// surfaced as a *RateLimitedError whose RetryAfter matches the decision.
func TestAllowRateLimitedReturnsTypedError(t *testing.T) {
	t.Parallel()

	const retryAfter = 3 * time.Second
	denying := &dummyLimiter{
		decision: &Decision{Allowed: false, RetryAfter: retryAfter},
	}

	err := Allow(t.Context(), denying, nil, "echo")

	var got *RateLimitedError
	require.ErrorAs(t, err, &got)
	assert.Equal(t, retryAfter, got.RetryAfter)
}

func TestAllowPropagatesLimiterError(t *testing.T) {
t.Parallel()

expected := errors.New("redis unavailable")
limiter := &dummyLimiter{err: expected}

err := Allow(t.Context(), limiter, nil, "echo")

require.ErrorIs(t, err, expected)
}

func TestNewLimiter_ZeroMaxTokens(t *testing.T) {
t.Parallel()
client, _ := newTestClient(t)
Expand Down
54 changes: 27 additions & 27 deletions pkg/ratelimit/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ package ratelimit
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"math"
"net/http"
Expand All @@ -25,13 +27,6 @@ const (
// MiddlewareType is the type constant for the rate limit middleware.
MiddlewareType = "ratelimit"

// CodeRateLimited is the JSON-RPC error code for rate-limited requests.
// Per RFC THV-0057: implementation-defined code in the -32000 to -32099 range.
CodeRateLimited int64 = -32029

// MessageRateLimited is the JSON-RPC error message for rate-limited requests.
MessageRateLimited = "Rate limit exceeded"

// redisPasswordEnvVar is the environment variable containing the Redis password.
// Shared with session storage — the operator injects it from the same Secret.
redisPasswordEnvVar = "THV_SESSION_REDIS_PASSWORD" //nolint:gosec // G101: env var name, not a credential
Expand All @@ -49,7 +44,7 @@ type MiddlewareParams struct {
// rateLimitMiddleware wraps rate limiting functionality for the factory pattern.
type rateLimitMiddleware struct {
handler types.MiddlewareFunction
client redis.UniversalClient
closer io.Closer
}

// Handler returns the middleware function used by the proxy.
Expand All @@ -59,16 +54,16 @@ func (m *rateLimitMiddleware) Handler() types.MiddlewareFunction {

// Close cleans up the Redis client.
func (m *rateLimitMiddleware) Close() error {
if m.client != nil {
return m.client.Close()
if m.closer != nil {
return m.closer.Close()
}
return nil
}

// NewMiddleware creates a Redis-backed rate limit middleware from typed params.
func NewMiddleware(params MiddlewareParams) (types.Middleware, error) {
// NewRedisLimiter creates a Redis-backed rate limiter from typed params.
func NewRedisLimiter(params MiddlewareParams) (Limiter, io.Closer, error) {
if params.RedisAddr == "" {
return nil, fmt.Errorf("rate limit middleware requires a Redis address")
return nil, nil, fmt.Errorf("rate limit middleware requires a Redis address")
}

// TODO: share a Redis client builder with session storage to get TLS,
Expand All @@ -84,18 +79,28 @@ func NewMiddleware(params MiddlewareParams) (types.Middleware, error) {
defer pingCancel()
if err := client.Ping(pingCtx).Err(); err != nil {
_ = client.Close()
return nil, fmt.Errorf("rate limit middleware: failed to connect to Redis at %s: %w", params.RedisAddr, err)
return nil, nil, fmt.Errorf("rate limit middleware: failed to connect to Redis at %s: %w", params.RedisAddr, err)
}

limiter, err := NewLimiter(client, params.Namespace, params.ServerName, params.Config)
if err != nil {
_ = client.Close()
return nil, fmt.Errorf("failed to create rate limiter: %w", err)
return nil, nil, fmt.Errorf("failed to create rate limiter: %w", err)
}

return limiter, client, nil
}

// NewMiddleware creates a Redis-backed rate limit middleware from typed params.
func NewMiddleware(params MiddlewareParams) (types.Middleware, error) {
limiter, closer, err := NewRedisLimiter(params)
if err != nil {
return nil, err
}

return &rateLimitMiddleware{
handler: rateLimitHandler(limiter),
client: client,
closer: closer,
}, nil
}

Expand Down Expand Up @@ -128,23 +133,18 @@ func rateLimitHandler(limiter Limiter) types.MiddlewareFunction {
return
}

// When no identity is present (unauthenticated), userID stays empty
// and per-user buckets are skipped — only shared limits apply. CEL
// validation ensures perUser rate limits require auth to be enabled.
var userID string
if identity, ok := auth.IdentityFromContext(r.Context()); ok {
userID = identity.Subject
identity, _ := auth.IdentityFromContext(r.Context())
err := Allow(r.Context(), limiter, identity, parsed.ResourceID)
var limited *RateLimitedError
if errors.As(err, &limited) {
writeRateLimited(w, parsed.ID, limited.RetryAfter)
return
}
decision, err := limiter.Allow(r.Context(), parsed.ResourceID, userID)
if err != nil {
slog.Warn("rate limit check failed, allowing request", "error", err)
next.ServeHTTP(w, r)
return
}
if !decision.Allowed {
writeRateLimited(w, parsed.ID, decision.RetryAfter)
return
}
next.ServeHTTP(w, r)
})
}
Expand Down
12 changes: 6 additions & 6 deletions pkg/vmcp/cli/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,8 @@ func Serve(ctx context.Context, cfg ServeConfig) error {
// aggregation (agg feeds it via Config.Aggregator below).
sessionFactory := vmcpsession.NewSessionFactory(outgoingRegistry)

// When the optimizer is enabled, its meta-tools must pass through the authz
// response filter so they appear in tools/list.
// When the optimizer is enabled, its meta-tools are pass-through tools.
// Authz uses this for optimizer-aware authorization/filtering.
var passThroughTools map[string]struct{}
if optCfg != nil {
passThroughTools = map[string]struct{}{
Expand Down Expand Up @@ -389,19 +389,19 @@ func Serve(ctx context.Context, cfg ServeConfig) error {
slog.Info(fmt.Sprintf("Incoming authentication configured: %s", vmcpCfg.IncomingAuth.Type))

namespace := vmcpNamespace()
rateLimitMiddleware, rateLimitCleanup, err := ratelimitfactory.NewMiddleware(ctx, ratelimitfactory.Config{
rateLimiter, rateLimitCleanup, err := ratelimitfactory.NewLimiter(ctx, ratelimitfactory.Config{
Namespace: namespace,
ServerName: vmcpCfg.Name,
RateLimiting: vmcpCfg.RateLimiting,
SessionStorage: vmcpCfg.SessionStorage,
})
if err != nil {
return fmt.Errorf("failed to create rate limit middleware: %w", err)
return fmt.Errorf("failed to create rate limiter: %w", err)
}
if rateLimitCleanup != nil {
defer func() {
if closeErr := rateLimitCleanup(context.Background()); closeErr != nil {
slog.Error(fmt.Sprintf("failed to close rate limit middleware: %v", closeErr))
slog.Error(fmt.Sprintf("failed to close rate limiter: %v", closeErr))
}
}()
}
Expand All @@ -422,7 +422,7 @@ func Serve(ctx context.Context, cfg ServeConfig) error {
AuthzMiddleware: authzMiddleware,
AuthInfoHandler: authInfoHandler,
PassthroughHeaders: vmcpCfg.PassthroughHeaders,
RateLimitMiddleware: rateLimitMiddleware,
RateLimiter: rateLimiter,
AuthServer: embeddedAuthServer,
TelemetryProvider: telemetryProvider,
AuditConfig: vmcpCfg.Audit,
Expand Down
61 changes: 61 additions & 0 deletions pkg/vmcp/ratelimit/decorator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

// Package ratelimit applies rate limiting at the vMCP domain boundary.
package ratelimit

import (
"context"
"errors"
"log/slog"

"github.com/stacklok/toolhive/pkg/auth"
baseratelimit "github.com/stacklok/toolhive/pkg/ratelimit"
"github.com/stacklok/toolhive/pkg/vmcp"
"github.com/stacklok/toolhive/pkg/vmcp/core"
)

// decorator wraps a [core.VMCP] to rate-limit tool calls. Every method except
// CallTool is promoted from the embedded inner core unchanged. The decorator sits
// below the session optimizer, so name is already the resolved backend tool name.
type decorator struct {
	// VMCP is the wrapped inner core; embedding it promotes its methods.
	core.VMCP
	// limiter is consulted by CallTool before delegating to the inner core.
	limiter baseratelimit.Limiter
}

// Compile-time assertion that *decorator satisfies core.VMCP.
var _ core.VMCP = (*decorator)(nil)

// NewDecorator returns inner wrapped with vMCP rate limiting.
//
// A nil inner is treated as a composition-root wiring bug: the constructor
// panics immediately instead of letting the failure surface on the first
// promoted method call. A nil limiter means rate limiting is disabled, in
// which case inner is handed back as-is with no wrapper.
func NewDecorator(inner core.VMCP, limiter baseratelimit.Limiter) core.VMCP {
	switch {
	case inner == nil:
		panic("ratelimit: NewDecorator requires a non-nil inner VMCP")
	case limiter == nil:
		return inner
	}
	return &decorator{VMCP: inner, limiter: limiter}
}

// CallTool checks the rate limit for name before delegating to inner. At this
// seam name is already resolved by any outer optimizer layer, so per-tool bucket
// keys match the real backend tool instead of the optimizer call_tool meta-tool.
func (d *decorator) CallTool(
ctx context.Context, identity *auth.Identity, name string,
args map[string]any, meta map[string]any,
) (*vmcp.ToolCallResult, error) {
if err := baseratelimit.Allow(ctx, d.limiter, identity, name); err != nil {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: Allow returns three outcomes — nil, *RateLimitedError, or the raw limiter error (e.g. Redis unreachable). This returns the raw error too, so the decorator fails closed: a Redis blip fails every vMCP tool call. The HTTP middleware path consuming the same Allow fails open (slog.Warn("rate limit check failed, allowing request") → proceeds). Worth matching that posture here — block only on *RateLimitedError, log-and-allow on infra errors — so the two consumers of Allow behave consistently and a limiter outage doesn't take down tool calls. If fail-closed is actually desired for vMCP, a one-line comment saying so would settle it.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. The vMCP rate-limit decorator now matches the HTTP middleware posture: it blocks only on *RateLimitedError and logs/allows raw limiter infrastructure errors like Redis failures.

That keeps limiter outages from taking down all vMCP tool calls while still enforcing actual rate-limit denials.

var limited *baseratelimit.RateLimitedError
if errors.As(err, &limited) {
return nil, err
}
slog.WarnContext(ctx, "rate limit check failed, allowing tool call", "tool", name, "error", err)
}
return d.VMCP.CallTool(ctx, identity, name, args, meta)
}
Loading
Loading