diff --git a/pkg/registry/builder.go b/pkg/registry/builder.go index 8680f5c3d..813712064 100644 --- a/pkg/registry/builder.go +++ b/pkg/registry/builder.go @@ -1,10 +1,15 @@ package registry import ( + "context" "sort" "strings" ) +// ToolFilter is a function that determines if a tool should be included. +// Returns true if the tool should be included, false to exclude it. +type ToolFilter func(ctx context.Context, tool *ServerTool) (bool, error) + // Builder builds a Registry with the specified configuration. // Use NewBuilder to create a builder, chain configuration methods, // then call Build() to create the final Registry. @@ -19,6 +24,7 @@ import ( // WithReadOnly(true). // WithToolsets([]string{"repos", "issues"}). // WithFeatureChecker(checker). +// WithFilter(myFilter). // Build() type Builder struct { tools []ServerTool @@ -32,6 +38,7 @@ type Builder struct { toolsetIDsIsNil bool // tracks if nil was passed (nil = defaults) additionalTools []string // raw input, processed at Build() featureChecker FeatureFlagChecker + filters []ToolFilter // filters to apply to all tools } // NewBuilder creates a new Builder. @@ -111,6 +118,15 @@ func (b *Builder) WithFeatureChecker(checker FeatureFlagChecker) *Builder { return b } +// WithFilter adds a filter function that will be applied to all tools. +// Multiple filters can be added and are evaluated in order. +// If any filter returns false or an error, the tool is excluded. +// Returns self for chaining. +func (b *Builder) WithFilter(filter ToolFilter) *Builder { + b.filters = append(b.filters, filter) + return b +} + // Build creates the final Registry with all configuration applied. // This processes toolset filtering, tool name resolution, and sets up // the registry for use. The returned Registry is ready for use with @@ -123,6 +139,7 @@ func (b *Builder) Build() *Registry { deprecatedAliases: b.deprecatedAliases, readOnly: b.readOnly, featureChecker: b.featureChecker, + filters: b.filters, } // Process toolsets and pre-compute metadata in a single pass diff --git a/pkg/registry/filters.go b/pkg/registry/filters.go index 3b251e439..5bce63570 100644 --- a/pkg/registry/filters.go +++ b/pkg/registry/filters.go @@ -51,20 +51,48 @@ func (r *Registry) isFeatureFlagAllowed(ctx context.Context, enableFlag, disable } // isToolEnabled checks if a specific tool is enabled based on current filters. +// Filter evaluation order: +// 1. Tool.Enabled (tool self-filtering) +// 2. FeatureFlagEnable/FeatureFlagDisable +// 3. Read-only filter +// 4. Builder filters (via WithFilter) +// 5. Toolset/additional tools func (r *Registry) isToolEnabled(ctx context.Context, tool *ServerTool) bool { - // Check read-only filter first (applies to all tools) - if r.readOnly && !tool.IsReadOnly() { - return false + // 1. Check tool's own Enabled function first + if tool.Enabled != nil { + enabled, err := tool.Enabled(ctx) + if err != nil { + fmt.Fprintf(os.Stderr, "Tool.Enabled check error for %q: %v\n", tool.Tool.Name, err) + return false + } + if !enabled { + return false + } } - // Check feature flags + // 2. Check feature flags if !r.isFeatureFlagAllowed(ctx, tool.FeatureFlagEnable, tool.FeatureFlagDisable) { return false } - // Check if tool is in additionalTools (bypasses toolset filter) + // 3. Check read-only filter (applies to all tools) + if r.readOnly && !tool.IsReadOnly() { + return false + } + // 4. Apply builder filters + for _, filter := range r.filters { + allowed, err := filter(ctx, tool) + if err != nil { + fmt.Fprintf(os.Stderr, "Builder filter error for tool %q: %v\n", tool.Tool.Name, err) + return false + } + if !allowed { + return false + } + } + // 5. Check if tool is in additionalTools (bypasses toolset filter) if r.additionalTools != nil && r.additionalTools[tool.Tool.Name] { return true } - // Check toolset filter + // 5. Check toolset filter if !r.isToolsetEnabled(tool.Toolset.ID) { return false } @@ -245,3 +273,17 @@ func (r *Registry) EnabledToolsetIDs() []ToolsetID { sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) return ids } + +// FilteredTools returns tools filtered by the Enabled function and builder filters. +// This provides an explicit API for accessing filtered tools, currently implemented +// as an alias for AvailableTools. +// +// The error return is currently always nil but is included for future extensibility. +// Library consumers (e.g., remote server implementations) may need to surface +// recoverable filter errors rather than silently logging them. Having the error +// return in the API now avoids breaking changes later. +// +// The context is used for Enabled function evaluation and builder filter checks. +func (r *Registry) FilteredTools(ctx context.Context) ([]ServerTool, error) { + return r.AvailableTools(ctx), nil +} diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go index 56b2fdee5..7376a4c9c 100644 --- a/pkg/registry/registry.go +++ b/pkg/registry/registry.go @@ -53,6 +53,9 @@ type Registry struct { // Takes context and flag name, returns (enabled, error). If error, log and treat as false. // If checker is nil, all flag checks return false. featureChecker FeatureFlagChecker + // filters are functions that will be applied to all tools during filtering. + // If any filter returns false or an error, the tool is excluded. + filters []ToolFilter // unrecognizedToolsets holds toolset IDs that were requested but don't match any registered toolsets unrecognizedToolsets []string } @@ -107,6 +110,7 @@ func (r *Registry) ForMCPRequest(method string, itemName string) *Registry { enabledToolsets: r.enabledToolsets, // shared, not modified additionalTools: r.additionalTools, // shared, not modified featureChecker: r.featureChecker, + filters: r.filters, // shared, not modified unrecognizedToolsets: r.unrecognizedToolsets, } diff --git a/pkg/registry/registry_test.go b/pkg/registry/registry_test.go index 4518ed8bf..44ed3d773 100644 --- a/pkg/registry/registry_test.go +++ b/pkg/registry/registry_test.go @@ -1174,3 +1174,472 @@ func TestServerToolHandlerPanicOnNil(t *testing.T) { tool.Handler(nil) } + +// Tests for Enabled function on ServerTool +func TestServerToolEnabled(t *testing.T) { + tests := []struct { + name string + enabledFunc func(ctx context.Context) (bool, error) + expectedCount int + expectInResult bool + }{ + { + name: "nil Enabled function - tool included", + enabledFunc: nil, + expectedCount: 1, + expectInResult: true, + }, + { + name: "Enabled returns true - tool included", + enabledFunc: func(_ context.Context) (bool, error) { + return true, nil + }, + expectedCount: 1, + expectInResult: true, + }, + { + name: "Enabled returns false - tool excluded", + enabledFunc: func(_ context.Context) (bool, error) { + return false, nil + }, + expectedCount: 0, + expectInResult: false, + }, + { + name: "Enabled returns error - tool excluded", + enabledFunc: func(_ context.Context) (bool, error) { + return false, fmt.Errorf("simulated error") + }, + expectedCount: 0, + expectInResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tool := mockTool("test_tool", "toolset1", true) + tool.Enabled = tt.enabledFunc + + reg := NewBuilder().SetTools([]ServerTool{tool}).WithToolsets([]string{"all"}).Build() + available := reg.AvailableTools(context.Background()) + + if len(available) != tt.expectedCount { + t.Errorf("Expected %d tools, got %d", tt.expectedCount, len(available)) + } + + found := false + for _, t := range available { + if t.Tool.Name == "test_tool" { + found = true + break + } + } + if found != tt.expectInResult { + t.Errorf("Expected tool in result: %v, got: %v", tt.expectInResult, found) + } + }) + } +} + +func TestServerToolEnabledWithContext(t *testing.T) { + type contextKey string + const userKey contextKey = "user" + + // Tool that checks context for user + tool := mockTool("context_aware_tool", "toolset1", true) + tool.Enabled = func(ctx context.Context) (bool, error) { + user := ctx.Value(userKey) + return user != nil && user.(string) == "authorized", nil + } + + reg := NewBuilder().SetTools([]ServerTool{tool}).WithToolsets([]string{"all"}).Build() + + // Without user in context - tool should be excluded + available := reg.AvailableTools(context.Background()) + if len(available) != 0 { + t.Errorf("Expected 0 tools without user, got %d", len(available)) + } + + // With authorized user - tool should be included + ctxWithUser := context.WithValue(context.Background(), userKey, "authorized") + availableWithUser := reg.AvailableTools(ctxWithUser) + if len(availableWithUser) != 1 { + t.Errorf("Expected 1 tool with authorized user, got %d", len(availableWithUser)) + } + + // With unauthorized user - tool should be excluded + ctxWithBadUser := context.WithValue(context.Background(), userKey, "unauthorized") + availableWithBadUser := reg.AvailableTools(ctxWithBadUser) + if len(availableWithBadUser) != 0 { + t.Errorf("Expected 0 tools with unauthorized user, got %d", len(availableWithBadUser)) + } +} + +// Tests for WithFilter builder method +func TestBuilderWithFilter(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", true), + mockTool("tool3", "toolset1", true), + } + + // Filter that excludes tool2 + filter := func(_ context.Context, tool *ServerTool) (bool, error) { + return tool.Tool.Name != "tool2", nil + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 2 { + t.Fatalf("Expected 2 tools after filter, got %d", len(available)) + } + + for _, tool := range available { + if tool.Tool.Name == "tool2" { + t.Error("tool2 should have been filtered out") + } + } +} + +func TestBuilderWithMultipleFilters(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", true), + mockTool("tool3", "toolset1", true), + mockTool("tool4", "toolset1", true), + } + + // First filter excludes tool2 + filter1 := func(_ context.Context, tool *ServerTool) (bool, error) { + return tool.Tool.Name != "tool2", nil + } + + // Second filter excludes tool3 + filter2 := func(_ context.Context, tool *ServerTool) (bool, error) { + return tool.Tool.Name != "tool3", nil + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter1). + WithFilter(filter2). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 2 { + t.Fatalf("Expected 2 tools after multiple filters, got %d", len(available)) + } + + toolNames := make(map[string]bool) + for _, tool := range available { + toolNames[tool.Tool.Name] = true + } + + if !toolNames["tool1"] || !toolNames["tool4"] { + t.Error("Expected tool1 and tool4 to be available") + } + if toolNames["tool2"] || toolNames["tool3"] { + t.Error("tool2 and tool3 should have been filtered out") + } +} + +func TestBuilderFilterError(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + } + + // Filter that returns an error + filter := func(_ context.Context, _ *ServerTool) (bool, error) { + return false, fmt.Errorf("filter error") + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 0 { + t.Errorf("Expected 0 tools when filter returns error, got %d", len(available)) + } +} + +func TestBuilderFilterWithContext(t *testing.T) { + type contextKey string + const scopeKey contextKey = "scope" + + tools := []ServerTool{ + mockTool("public_tool", "toolset1", true), + mockTool("private_tool", "toolset1", true), + } + + // Filter that checks context for scope + filter := func(ctx context.Context, tool *ServerTool) (bool, error) { + scope := ctx.Value(scopeKey) + if scope == "public" && tool.Tool.Name == "private_tool" { + return false, nil + } + return true, nil + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + // With public scope - private_tool should be excluded + ctxPublic := context.WithValue(context.Background(), scopeKey, "public") + availablePublic := reg.AvailableTools(ctxPublic) + if len(availablePublic) != 1 { + t.Fatalf("Expected 1 tool with public scope, got %d", len(availablePublic)) + } + if availablePublic[0].Tool.Name != "public_tool" { + t.Error("Expected only public_tool to be available") + } + + // With private scope - both tools should be available + ctxPrivate := context.WithValue(context.Background(), scopeKey, "private") + availablePrivate := reg.AvailableTools(ctxPrivate) + if len(availablePrivate) != 2 { + t.Errorf("Expected 2 tools with private scope, got %d", len(availablePrivate)) + } +} + +// Tests for interaction between Enabled, feature flags, and filters +func TestEnabledAndFeatureFlagInteraction(t *testing.T) { + // Tool with both Enabled function and feature flag + tool := mockToolWithFlags("complex_tool", "toolset1", true, "my_feature", "") + tool.Enabled = func(_ context.Context) (bool, error) { + return true, nil + } + + // Feature flag not enabled - tool should be excluded despite Enabled returning true + reg1 := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + Build() + available1 := reg1.AvailableTools(context.Background()) + if len(available1) != 0 { + t.Error("Tool should be excluded when feature flag is not enabled") + } + + // Feature flag enabled - tool should be included + checker := func(_ context.Context, flag string) (bool, error) { + return flag == "my_feature", nil + } + reg2 := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFeatureChecker(checker). + Build() + available2 := reg2.AvailableTools(context.Background()) + if len(available2) != 1 { + t.Error("Tool should be included when both Enabled and feature flag pass") + } + + // Enabled returns false - tool should be excluded despite feature flag + tool.Enabled = func(_ context.Context) (bool, error) { + return false, nil + } + reg3 := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFeatureChecker(checker). + Build() + available3 := reg3.AvailableTools(context.Background()) + if len(available3) != 0 { + t.Error("Tool should be excluded when Enabled returns false") + } +} + +func TestEnabledAndBuilderFilterInteraction(t *testing.T) { + tool := mockTool("test_tool", "toolset1", true) + tool.Enabled = func(_ context.Context) (bool, error) { + return true, nil + } + + // Filter that excludes the tool + filter := func(_ context.Context, _ *ServerTool) (bool, error) { + return false, nil + } + + reg := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 0 { + t.Error("Tool should be excluded when filter returns false, despite Enabled returning true") + } +} + +func TestAllFiltersInteraction(t *testing.T) { + // Tool with Enabled, feature flag, and subject to builder filter + tool := mockToolWithFlags("complex_tool", "toolset1", true, "my_feature", "") + tool.Enabled = func(_ context.Context) (bool, error) { + return true, nil + } + + filter := func(_ context.Context, _ *ServerTool) (bool, error) { + return true, nil + } + + checker := func(_ context.Context, flag string) (bool, error) { + return flag == "my_feature", nil + } + + // All conditions pass - tool should be included + reg := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFeatureChecker(checker). + WithFilter(filter). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 1 { + t.Error("Tool should be included when all filters pass") + } + + // Change filter to return false - tool should be excluded + filterFalse := func(_ context.Context, _ *ServerTool) (bool, error) { + return false, nil + } + + reg2 := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFeatureChecker(checker). + WithFilter(filterFalse). + Build() + + available2 := reg2.AvailableTools(context.Background()) + if len(available2) != 0 { + t.Error("Tool should be excluded when any filter fails") + } +} + +// Test FilteredTools method +func TestFilteredTools(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", true), + } + + filter := func(_ context.Context, tool *ServerTool) (bool, error) { + return tool.Tool.Name == "tool1", nil + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + filtered, err := reg.FilteredTools(context.Background()) + if err != nil { + t.Fatalf("FilteredTools returned error: %v", err) + } + + if len(filtered) != 1 { + t.Fatalf("Expected 1 filtered tool, got %d", len(filtered)) + } + + if filtered[0].Tool.Name != "tool1" { + t.Errorf("Expected tool1, got %s", filtered[0].Tool.Name) + } +} + +func TestFilteredToolsMatchesAvailableTools(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", false), + mockTool("tool3", "toolset2", true), + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"toolset1"}). + WithReadOnly(true). + Build() + + ctx := context.Background() + filtered, err := reg.FilteredTools(ctx) + if err != nil { + t.Fatalf("FilteredTools returned error: %v", err) + } + + available := reg.AvailableTools(ctx) + + // Both methods should return the same results + if len(filtered) != len(available) { + t.Errorf("FilteredTools and AvailableTools returned different counts: %d vs %d", + len(filtered), len(available)) + } + + for i := range filtered { + if filtered[i].Tool.Name != available[i].Tool.Name { + t.Errorf("Tool at index %d differs: FilteredTools=%s, AvailableTools=%s", + i, filtered[i].Tool.Name, available[i].Tool.Name) + } + } +} + +func TestFilteringOrder(t *testing.T) { + // Test that filters are applied in the correct order: + // 1. Tool.Enabled + // 2. Feature flags + // 3. Read-only + // 4. Builder filters + // 5. Toolset/additional tools + + callOrder := []string{} + + tool := mockToolWithFlags("test_tool", "toolset1", false, "my_feature", "") + tool.Enabled = func(_ context.Context) (bool, error) { + callOrder = append(callOrder, "Enabled") + return true, nil + } + + filter := func(_ context.Context, _ *ServerTool) (bool, error) { + callOrder = append(callOrder, "Filter") + return true, nil + } + + checker := func(_ context.Context, _ string) (bool, error) { + callOrder = append(callOrder, "FeatureFlag") + return true, nil + } + + reg := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithReadOnly(true). // This will exclude the tool (it's not read-only) + WithFeatureChecker(checker). + WithFilter(filter). + Build() + + _ = reg.AvailableTools(context.Background()) + + // Expected order: Enabled, FeatureFlag, ReadOnly (stops here because it's write tool) + expectedOrder := []string{"Enabled", "FeatureFlag"} + if len(callOrder) != len(expectedOrder) { + t.Errorf("Expected %d checks, got %d: %v", len(expectedOrder), len(callOrder), callOrder) + } + + for i, expected := range expectedOrder { + if i >= len(callOrder) || callOrder[i] != expected { + t.Errorf("At position %d: expected %s, got %v", i, expected, callOrder) + } + } +} diff --git a/pkg/registry/server_tool.go b/pkg/registry/server_tool.go index af00a2bed..3145b693d 100644 --- a/pkg/registry/server_tool.go +++ b/pkg/registry/server_tool.go @@ -52,6 +52,13 @@ type ServerTool struct { // FeatureFlagDisable specifies a feature flag that, when enabled, causes this tool // to be omitted. Used to disable tools when a feature flag is on. FeatureFlagDisable string + + // Enabled is an optional function called at build/filter time to determine + // if this tool should be available. If nil, the tool is considered enabled + // (subject to FeatureFlagEnable/FeatureFlagDisable checks). + // The context carries request-scoped information for the consumer to use. + // Returns (enabled, error). On error, the tool should be treated as disabled. + Enabled func(ctx context.Context) (bool, error) } // IsReadOnly returns true if this tool is marked as read-only via annotations.