From ff237aa73152341fa207e1782d21eae4b80cbe97 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Sat, 13 Dec 2025 23:52:36 +0100 Subject: [PATCH 01/27] Add --features CLI flag for feature flag support Add CLI flag and config support for feature flags in the local server: - Add --features flag to main.go (StringSlice, comma-separated) - Add EnabledFeatures field to StdioServerConfig and MCPServerConfig - Create createFeatureChecker() that builds a set from enabled features - Wire WithFeatureChecker() into the toolset group filter chain This enables tools/resources/prompts that have FeatureFlagEnable set to a flag name that is passed via --features. The checker uses a simple set membership test for O(1) lookup. Usage: github-mcp-server stdio --features=my_feature,another_feature GITHUB_FEATURES=my_feature github-mcp-server stdio --- README.md | 7 +- cmd/github-mcp-server/generate_docs.go | 147 ++-- cmd/github-mcp-server/main.go | 10 +- docs/deprecated-tool-aliases.md | 31 + docs/remote-server.md | 1 - internal/ghmcp/server.go | 117 +-- pkg/github/actions.go | 14 + pkg/github/code_scanning.go | 2 + pkg/github/context_tools.go | 3 + pkg/github/dependabot.go | 2 + pkg/github/dependencies.go | 12 +- pkg/github/discussions.go | 4 + pkg/github/dynamic_tools.go | 294 ++++--- pkg/github/gists.go | 4 + pkg/github/git.go | 1 + pkg/github/issues.go | 20 +- pkg/github/labels.go | 3 + pkg/github/notifications.go | 6 + pkg/github/projects.go | 9 + pkg/github/prompts.go | 16 + pkg/github/pullrequests.go | 10 + pkg/github/repositories.go | 18 + pkg/github/repository_resource.go | 56 +- pkg/github/repository_resource_test.go | 30 +- pkg/github/resources.go | 20 + pkg/github/search.go | 4 + pkg/github/secret_scanning.go | 2 + pkg/github/security_advisories.go | 4 + pkg/github/server.go | 14 +- pkg/github/tools.go | 501 ++++------- pkg/github/toolset_group.go | 20 + pkg/github/workflow_prompts.go | 10 +- pkg/toolsets/server_tool.go | 59 +- pkg/toolsets/toolsets.go | 801 ++++++++++++----- pkg/toolsets/toolsets_test.go | 1109 +++++++++++++++++++----- 35 files changed, 2290 insertions(+), 1071 deletions(-) create mode 100644 docs/deprecated-tool-aliases.md create mode 100644 pkg/github/prompts.go create mode 100644 pkg/github/resources.go create mode 100644 pkg/github/toolset_group.go diff --git a/README.md b/README.md index bcd9f85c8..117bacacd 100644 --- a/README.md +++ b/README.md @@ -384,6 +384,7 @@ You can also configure specific tools using the `--tools` flag. Tools can be use - Tools, toolsets, and dynamic toolsets can all be used together - Read-only mode takes priority: write tools are skipped if `--read-only` is set, even if explicitly requested via `--tools` - Tool names must match exactly (e.g., `get_file_contents`, not `getFileContents`). Invalid tool names will cause the server to fail at startup with an error message +- When tools are renamed, old names are preserved as aliases for backward compatibility. See [Deprecated Tool Aliases](docs/deprecated-tool-aliases.md) for details. ### Using Toolsets With Docker @@ -459,7 +460,6 @@ The following sets of tools are available: | `code_security` | Code security related tools, such as GitHub Code Scanning | | `dependabot` | Dependabot tools | | `discussions` | GitHub Discussions related tools | -| `experiments` | Experimental features that are not considered stable yet | | `gists` | GitHub Gist related tools | | `git` | GitHub Git API related tools for low-level Git operations | | `issues` | GitHub Issues related tools | @@ -718,11 +718,6 @@ The following sets of tools are available: - `owner`: Repository owner (string, required) - `repo`: Repository name (string, required) -- **get_label** - Get a specific label from a repository. - - `name`: Label name. (string, required) - - `owner`: Repository owner (username or organization name) (string, required) - - `repo`: Repository name (string, required) - - **issue_read** - Get issue details - `issue_number`: The number of the issue (number, required) - `method`: The read operation to perform on a single issue. diff --git a/cmd/github-mcp-server/generate_docs.go b/cmd/github-mcp-server/generate_docs.go index 61459d7f0..785bd3ff0 100644 --- a/cmd/github-mcp-server/generate_docs.go +++ b/cmd/github-mcp-server/generate_docs.go @@ -10,14 +10,12 @@ import ( "strings" "github.com/github/github-mcp-server/pkg/github" - "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/raw" "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" gogithub "github.com/google/go-github/v79/github" "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/mcp" - "github.com/shurcooL/githubv4" "github.com/spf13/cobra" ) @@ -39,11 +37,6 @@ func mockGetClient(_ context.Context) (*gogithub.Client, error) { return gogithub.NewClient(nil), nil } -// mockGetGQLClient returns a mock GraphQL client for documentation generation -func mockGetGQLClient(_ context.Context) (*githubv4.Client, error) { - return githubv4.NewClient(nil), nil -} - // mockGetRawClient returns a mock raw client for documentation generation func mockGetRawClient(_ context.Context) (*raw.Client, error) { return nil, nil @@ -58,6 +51,10 @@ func generateAllDocs() error { return fmt.Errorf("failed to generate remote-server docs: %w", err) } + if err := generateDeprecatedAliasesDocs("docs/deprecated-tool-aliases.md"); err != nil { + return fmt.Errorf("failed to generate deprecated aliases docs: %w", err) + } + return nil } @@ -65,9 +62,8 @@ func generateReadmeDocs(readmePath string) error { // Create translation helper t, _ := translations.TranslationHelper() - // Create toolset group with mock clients - repoAccessCache := lockdown.GetInstance(nil) - tsg := github.DefaultToolsetGroup(false, mockGetClient, mockGetGQLClient, mockGetRawClient, t, 5000, github.FeatureFlags{}, repoAccessCache) + // Create toolset group with mock clients (no deps needed for doc generation) + tsg := github.NewToolsetGroup(t, mockGetClient, mockGetRawClient) // Generate toolsets documentation toolsetsDoc := generateToolsetsDoc(tsg) @@ -133,20 +129,16 @@ func generateToolsetsDoc(tsg *toolsets.ToolsetGroup) string { // Add the context toolset row (handled separately in README) lines = append(lines, "| `context` | **Strongly recommended**: Tools that provide context about the current user and GitHub context you are operating in |") - // Get all toolsets except context (which is handled separately above) - var toolsetNames []string - for name := range tsg.Toolsets { - if name != "context" && name != "dynamic" { // Skip context and dynamic toolsets as they're handled separately - toolsetNames = append(toolsetNames, name) - } - } - - // Sort toolset names for consistent output - sort.Strings(toolsetNames) + // Get toolset IDs and descriptions + toolsetIDs := tsg.ToolsetIDs() + descriptions := tsg.ToolsetDescriptions() - for _, name := range toolsetNames { - toolset := tsg.Toolsets[name] - lines = append(lines, fmt.Sprintf("| `%s` | %s |", name, toolset.Description)) + // Filter out context and dynamic toolsets (handled separately) + for _, id := range toolsetIDs { + if id != "context" && id != "dynamic" { + description := descriptions[id] + lines = append(lines, fmt.Sprintf("| `%s` | %s |", id, description)) + } } return strings.Join(lines, "\n") @@ -155,30 +147,22 @@ func generateToolsetsDoc(tsg *toolsets.ToolsetGroup) string { func generateToolsDoc(tsg *toolsets.ToolsetGroup) string { var sections []string - // Get all toolset names and sort them alphabetically for deterministic order - var toolsetNames []string - for name := range tsg.Toolsets { - if name != "dynamic" { // Skip dynamic toolset as it's handled separately - toolsetNames = append(toolsetNames, name) - } - } - sort.Strings(toolsetNames) + // Get toolset IDs (already sorted deterministically) + toolsetIDs := tsg.ToolsetIDs() - for _, toolsetName := range toolsetNames { - toolset := tsg.Toolsets[toolsetName] + for _, toolsetID := range toolsetIDs { + if toolsetID == "dynamic" { // Skip dynamic toolset as it's handled separately + continue + } - tools := toolset.GetAvailableTools() + // Get tools for this toolset (already sorted deterministically) + tools := tsg.ToolsForToolset(toolsetID) if len(tools) == 0 { continue } - // Sort tools by name for deterministic order - sort.Slice(tools, func(i, j int) bool { - return tools[i].Tool.Name < tools[j].Tool.Name - }) - // Generate section header - capitalize first letter and replace underscores - sectionName := formatToolsetName(toolsetName) + sectionName := formatToolsetName(string(toolsetID)) var toolDocs []string for _, serverTool := range tools { @@ -322,33 +306,30 @@ func generateRemoteToolsetsDoc() string { t, _ := translations.TranslationHelper() // Create toolset group with mock clients - repoAccessCache := lockdown.GetInstance(nil) - tsg := github.DefaultToolsetGroup(false, mockGetClient, mockGetGQLClient, mockGetRawClient, t, 5000, github.FeatureFlags{}, repoAccessCache) + tsg := github.NewToolsetGroup(t, mockGetClient, mockGetRawClient) // Generate table header buf.WriteString("| Name | Description | API URL | 1-Click Install (VS Code) | Read-only Link | 1-Click Read-only Install (VS Code) |\n") buf.WriteString("|----------------|--------------------------------------------------|-------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n") - // Get all toolsets - toolsetNames := make([]string, 0, len(tsg.Toolsets)) - for name := range tsg.Toolsets { - if name != "context" && name != "dynamic" { // Skip context and dynamic toolsets as they're handled separately - toolsetNames = append(toolsetNames, name) - } - } - sort.Strings(toolsetNames) + // Get toolset IDs and descriptions + toolsetIDs := tsg.ToolsetIDs() + descriptions := tsg.ToolsetDescriptions() // Add "all" toolset first (special case) buf.WriteString("| all | All available GitHub MCP tools | https://api.githubcopilot.com/mcp/ | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=github&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2F%22%7D) | [read-only](https://api.githubcopilot.com/mcp/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=github&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Freadonly%22%7D) |\n") // Add individual toolsets - for _, name := range toolsetNames { - toolset := tsg.Toolsets[name] + for _, id := range toolsetIDs { + idStr := string(id) + if idStr == "context" || idStr == "dynamic" { // Skip context and dynamic toolsets as they're handled separately + continue + } - formattedName := formatToolsetName(name) - description := toolset.Description - apiURL := fmt.Sprintf("https://api.githubcopilot.com/mcp/x/%s", name) - readonlyURL := fmt.Sprintf("https://api.githubcopilot.com/mcp/x/%s/readonly", name) + description := descriptions[id] + formattedName := formatToolsetName(idStr) + apiURL := fmt.Sprintf("https://api.githubcopilot.com/mcp/x/%s", idStr) + readonlyURL := fmt.Sprintf("https://api.githubcopilot.com/mcp/x/%s/readonly", idStr) // Create install config JSON (URL encoded) installConfig := url.QueryEscape(fmt.Sprintf(`{"type": "http","url": "%s"}`, apiURL)) @@ -358,8 +339,8 @@ func generateRemoteToolsetsDoc() string { installConfig = strings.ReplaceAll(installConfig, "+", "%20") readonlyConfig = strings.ReplaceAll(readonlyConfig, "+", "%20") - installLink := fmt.Sprintf("[Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-%s&config=%s)", name, installConfig) - readonlyInstallLink := fmt.Sprintf("[Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-%s&config=%s)", name, readonlyConfig) + installLink := fmt.Sprintf("[Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-%s&config=%s)", idStr, installConfig) + readonlyInstallLink := fmt.Sprintf("[Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-%s&config=%s)", idStr, readonlyConfig) buf.WriteString(fmt.Sprintf("| %-14s | %-48s | %-53s | %-218s | %-110s | %-288s |\n", formattedName, @@ -373,3 +354,53 @@ func generateRemoteToolsetsDoc() string { return buf.String() } + +func generateDeprecatedAliasesDocs(docsPath string) error { + // Read the current file + content, err := os.ReadFile(docsPath) //#nosec G304 + if err != nil { + return fmt.Errorf("failed to read docs file: %w", err) + } + + // Generate the table + aliasesDoc := generateDeprecatedAliasesTable() + + // Replace content between markers + updatedContent := replaceSection(string(content), "START AUTOMATED ALIASES", "END AUTOMATED ALIASES", aliasesDoc) + + // Write back to file + err = os.WriteFile(docsPath, []byte(updatedContent), 0600) + if err != nil { + return fmt.Errorf("failed to write deprecated aliases docs: %w", err) + } + + fmt.Println("Successfully updated docs/deprecated-tool-aliases.md with automated documentation") + return nil +} + +func generateDeprecatedAliasesTable() string { + var lines []string + + // Add table header + lines = append(lines, "| Old Name | New Name |") + lines = append(lines, "|----------|----------|") + + aliases := github.DeprecatedToolAliases + if len(aliases) == 0 { + lines = append(lines, "| *(none currently)* | |") + } else { + // Sort keys for deterministic output + var oldNames []string + for oldName := range aliases { + oldNames = append(oldNames, oldName) + } + sort.Strings(oldNames) + + for _, oldName := range oldNames { + newName := aliases[oldName] + lines = append(lines, fmt.Sprintf("| `%s` | `%s` |", oldName, newName)) + } + } + + return strings.Join(lines, "\n") +} diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 87eeedd2e..84c974dad 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -52,9 +52,10 @@ var ( return fmt.Errorf("failed to unmarshal tools: %w", err) } - // If neither toolset config nor tools config is passed we enable the default toolset - if len(enabledToolsets) == 0 && len(enabledTools) == 0 { - enabledToolsets = []string{github.ToolsetMetadataDefault.ID} + // Parse enabled features (similar to toolsets) + var enabledFeatures []string + if err := viper.UnmarshalKey("features", &enabledFeatures); err != nil { + return fmt.Errorf("failed to unmarshal features: %w", err) } ttl := viper.GetDuration("repo-access-cache-ttl") @@ -64,6 +65,7 @@ var ( Token: token, EnabledToolsets: enabledToolsets, EnabledTools: enabledTools, + EnabledFeatures: enabledFeatures, DynamicToolsets: viper.GetBool("dynamic_toolsets"), ReadOnly: viper.GetBool("read-only"), ExportTranslations: viper.GetBool("export-translations"), @@ -87,6 +89,7 @@ func init() { // Add global flags that will be shared by all commands rootCmd.PersistentFlags().StringSlice("toolsets", nil, github.GenerateToolsetsHelp()) rootCmd.PersistentFlags().StringSlice("tools", nil, "Comma-separated list of specific tools to enable") + rootCmd.PersistentFlags().StringSlice("features", nil, "Comma-separated list of feature flags to enable") rootCmd.PersistentFlags().Bool("dynamic-toolsets", false, "Enable dynamic toolsets") rootCmd.PersistentFlags().Bool("read-only", false, "Restrict the server to read-only operations") rootCmd.PersistentFlags().String("log-file", "", "Path to log file") @@ -100,6 +103,7 @@ func init() { // Bind flag to viper _ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets")) _ = viper.BindPFlag("tools", rootCmd.PersistentFlags().Lookup("tools")) + _ = viper.BindPFlag("features", rootCmd.PersistentFlags().Lookup("features")) _ = viper.BindPFlag("dynamic_toolsets", rootCmd.PersistentFlags().Lookup("dynamic-toolsets")) _ = viper.BindPFlag("read-only", rootCmd.PersistentFlags().Lookup("read-only")) _ = viper.BindPFlag("log-file", rootCmd.PersistentFlags().Lookup("log-file")) diff --git a/docs/deprecated-tool-aliases.md b/docs/deprecated-tool-aliases.md new file mode 100644 index 000000000..6ea2ba1de --- /dev/null +++ b/docs/deprecated-tool-aliases.md @@ -0,0 +1,31 @@ +# Deprecated Tool Aliases + +This document tracks tool renames in the GitHub MCP Server. When tools are renamed, the old names are preserved as aliases for backward compatibility. Using a deprecated alias will still work, but clients should migrate to the new canonical name. + +## Current Deprecations + + +| Old Name | New Name | +|----------|----------| +| *(none currently)* | | + + +## How It Works + +When a tool is renamed: + +1. The old name is added to `DeprecatedToolAliases` in [pkg/github/deprecated_tool_aliases.go](../pkg/github/deprecated_tool_aliases.go) +2. Clients using the old name will receive the new tool +3. A deprecation notice is logged when the alias is used + +## For Developers + +To deprecate a tool name when renaming: + +```go +var DeprecatedToolAliases = map[string]string{ + "old_tool_name": "new_tool_name", +} +``` + +The alias resolution happens at server startup, ensuring backward compatibility for existing client configurations. diff --git a/docs/remote-server.md b/docs/remote-server.md index e06d41a75..ffdf526a4 100644 --- a/docs/remote-server.md +++ b/docs/remote-server.md @@ -24,7 +24,6 @@ Below is a table of available toolsets for the remote GitHub MCP Server. Each to | Code Security | Code security related tools, such as GitHub Code Scanning | https://api.githubcopilot.com/mcp/x/code_security | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-code_security&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fcode_security%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/code_security/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-code_security&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fcode_security%2Freadonly%22%7D) | | Dependabot | Dependabot tools | https://api.githubcopilot.com/mcp/x/dependabot | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-dependabot&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fdependabot%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/dependabot/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-dependabot&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fdependabot%2Freadonly%22%7D) | | Discussions | GitHub Discussions related tools | https://api.githubcopilot.com/mcp/x/discussions | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-discussions&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fdiscussions%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/discussions/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-discussions&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fdiscussions%2Freadonly%22%7D) | -| Experiments | Experimental features that are not considered stable yet | https://api.githubcopilot.com/mcp/x/experiments | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-experiments&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fexperiments%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/experiments/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-experiments&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fexperiments%2Freadonly%22%7D) | | Gists | GitHub Gist related tools | https://api.githubcopilot.com/mcp/x/gists | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-gists&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgists%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/gists/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-gists&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgists%2Freadonly%22%7D) | | Git | GitHub Git API related tools for low-level Git operations | https://api.githubcopilot.com/mcp/x/git | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-git&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgit%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/git/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-git&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgit%2Freadonly%22%7D) | | Issues | GitHub Issues related tools | https://api.githubcopilot.com/mcp/x/issues | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-issues&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fissues%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/issues/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-issues&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fissues%2Freadonly%22%7D) | diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index c0f4e25e7..0edca88ed 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -18,6 +18,7 @@ import ( "github.com/github/github-mcp-server/pkg/lockdown" mcplog "github.com/github/github-mcp-server/pkg/log" "github.com/github/github-mcp-server/pkg/raw" + "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" gogithub "github.com/google/go-github/v79/github" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -42,6 +43,10 @@ type MCPServerConfig struct { // When specified, these tools are registered in addition to any specified toolset tools EnabledTools []string + // EnabledFeatures is a list of feature flags that are enabled + // Items with FeatureFlagEnable matching an entry in this list will be available + EnabledFeatures []string + // Whether to enable dynamic toolsets // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#dynamic-tool-discovery DynamicToolsets bool @@ -100,24 +105,9 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { enabledToolsets := cfg.EnabledToolsets - // If dynamic toolsets are enabled, remove "all" and "default" from the enabled toolsets - if cfg.DynamicToolsets { - enabledToolsets = github.RemoveToolset(enabledToolsets, github.ToolsetMetadataAll.ID) - enabledToolsets = github.RemoveToolset(enabledToolsets, github.ToolsetMetadataDefault.ID) - } - - // Clean up the passed toolsets + // Clean up the passed toolsets (removes duplicates, whitespace) enabledToolsets, invalidToolsets := github.CleanToolsets(enabledToolsets) - // If "all" is present, override all other toolsets - if github.ContainsToolset(enabledToolsets, github.ToolsetMetadataAll.ID) { - enabledToolsets = []string{github.ToolsetMetadataAll.ID} - } - // If "default" is present, expand to real toolset IDs - if github.ContainsToolset(enabledToolsets, github.ToolsetMetadataDefault.ID) { - enabledToolsets = github.AddDefaultToolset(enabledToolsets) - } - if len(invalidToolsets) > 0 { fmt.Fprintf(os.Stderr, "Invalid toolsets ignored: %s\n", strings.Join(invalidToolsets, ", ")) } @@ -162,51 +152,73 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { ContentWindowSize: cfg.ContentWindowSize, } - // Create default toolsets - tsg := github.DefaultToolsetGroup( - cfg.ReadOnly, - getClient, - getGQLClient, - getRawClient, - cfg.Translator, - cfg.ContentWindowSize, - github.FeatureFlags{LockdownMode: cfg.LockdownMode}, - repoAccessCache, - ) - - // Enable and register toolsets if configured - // This always happens if toolsets are specified, regardless of whether tools are also specified - if len(enabledToolsets) > 0 { - err = tsg.EnableToolsets(enabledToolsets, nil) - if err != nil { - return nil, fmt.Errorf("failed to enable toolsets: %w", err) - } + // Create toolset group with all tools, resources, and prompts + tsg := github.NewToolsetGroup(cfg.Translator, getClient, getRawClient) - // Register all mcp functionality with the server - tsg.RegisterAll(ghServer) - } + // Add deprecated tool aliases for backward compatibility + // See docs/deprecated-tool-aliases.md for the full list of renames + tsg.AddDeprecatedToolAliases(github.DeprecatedToolAliases) - // Register specific tools if configured - if len(cfg.EnabledTools) > 0 { - enabledTools := github.CleanTools(cfg.EnabledTools) - enabledTools, _ = tsg.ResolveToolAliases(enabledTools) + // Clean tool names (WithTools will resolve any deprecated aliases) + enabledTools := github.CleanTools(cfg.EnabledTools) - // Register the specified tools (additive to any toolsets already enabled) - err = tsg.RegisterSpecificTools(ghServer, enabledTools, cfg.ReadOnly, deps) - if err != nil { - return nil, fmt.Errorf("failed to register tools: %w", err) - } + // For dynamic toolsets mode: + // - If toolsets are explicitly provided (including "default"), honor them + // - If no toolsets are specified (nil), start with no toolsets enabled (empty slice) + // so users can enable them on demand via the dynamic tools + if cfg.DynamicToolsets && cfg.EnabledToolsets == nil { + enabledToolsets = []string{} } - // Register dynamic toolsets if configured (additive to toolsets and tools) + // Apply filters based on configuration + // - WithReadOnly: filters out write tools when true + // - WithToolsets: nil=defaults, empty=none, handles "all"/"default" keywords + // - WithTools: additional tools that bypass toolset filtering (additive, resolves aliases) + // - WithFeatureChecker: filters based on feature flags + filteredTsg := tsg. + WithReadOnly(cfg.ReadOnly). + WithToolsets(enabledToolsets). + WithTools(enabledTools). + WithFeatureChecker(createFeatureChecker(cfg.EnabledFeatures)) + + // Register all mcp functionality with the server + // Use background context for local server (no per-request actor context) + filteredTsg.RegisterAll(context.Background(), ghServer, deps) + + // Register dynamic toolset management if configured + // Dynamic tools get access to the filtered toolset group which tracks enabled state. + // ToolsForToolset() returns all tools for a toolset regardless of enabled status, + // so dynamic tools can enable any toolset at runtime. if cfg.DynamicToolsets { - dynamic := github.InitDynamicToolset(ghServer, tsg, cfg.Translator) - dynamic.RegisterTools(ghServer) + dynamicDeps := github.DynamicToolDependencies{ + Server: ghServer, + ToolsetGroup: filteredTsg, + ToolDeps: deps, + T: cfg.Translator, + } + dynamicTools := github.DynamicTools() + for _, tool := range dynamicTools { + tool.RegisterFunc(ghServer, dynamicDeps) + } } return ghServer, nil } +// createFeatureChecker returns a FeatureFlagChecker that checks if a flag name +// is present in the provided list of enabled features. For the local server, +// this is populated from the --features CLI flag. +func createFeatureChecker(enabledFeatures []string) toolsets.FeatureFlagChecker { + // Build a set for O(1) lookup + featureSet := make(map[string]bool, len(enabledFeatures)) + for _, f := range enabledFeatures { + featureSet[f] = true + } + return func(_ context.Context, flagName string) (bool, error) { + return featureSet[flagName], nil + } +} + type StdioServerConfig struct { // Version of the server Version string @@ -225,6 +237,10 @@ type StdioServerConfig struct { // When specified, these tools are registered in addition to any specified toolset tools EnabledTools []string + // EnabledFeatures is a list of feature flags that are enabled + // Items with FeatureFlagEnable matching an entry in this list will be available + EnabledFeatures []string + // Whether to enable dynamic toolsets // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#dynamic-tool-discovery DynamicToolsets bool @@ -282,6 +298,7 @@ func RunStdioServer(cfg StdioServerConfig) error { Token: cfg.Token, EnabledToolsets: cfg.EnabledToolsets, EnabledTools: cfg.EnabledTools, + EnabledFeatures: cfg.EnabledFeatures, DynamicToolsets: cfg.DynamicToolsets, ReadOnly: cfg.ReadOnly, Translator: t, diff --git a/pkg/github/actions.go b/pkg/github/actions.go index e9c7c11a8..f29f75e99 100644 --- a/pkg/github/actions.go +++ b/pkg/github/actions.go @@ -27,6 +27,7 @@ const ( // ListWorkflows creates a tool to list workflows in a repository func ListWorkflows(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "list_workflows", Description: t("TOOL_LIST_WORKFLOWS_DESCRIPTION", "List workflows in a repository"), @@ -97,6 +98,7 @@ func ListWorkflows(t translations.TranslationHelperFunc) toolsets.ServerTool { // ListWorkflowRuns creates a tool to list workflow runs for a specific workflow func ListWorkflowRuns(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "list_workflow_runs", Description: t("TOOL_LIST_WORKFLOW_RUNS_DESCRIPTION", "List workflow runs for a specific workflow"), @@ -250,6 +252,7 @@ func ListWorkflowRuns(t translations.TranslationHelperFunc) toolsets.ServerTool // RunWorkflow creates a tool to run an Actions workflow func RunWorkflow(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "run_workflow", Description: t("TOOL_RUN_WORKFLOW_DESCRIPTION", "Run an Actions workflow by workflow ID or filename"), @@ -361,6 +364,7 @@ func RunWorkflow(t translations.TranslationHelperFunc) toolsets.ServerTool { // GetWorkflowRun creates a tool to get details of a specific workflow run func GetWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "get_workflow_run", Description: t("TOOL_GET_WORKFLOW_RUN_DESCRIPTION", "Get details of a specific workflow run"), @@ -428,6 +432,7 @@ func GetWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool { // GetWorkflowRunLogs creates a tool to download logs for a specific workflow run func GetWorkflowRunLogs(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "get_workflow_run_logs", Description: t("TOOL_GET_WORKFLOW_RUN_LOGS_DESCRIPTION", "Download logs for a specific workflow run (EXPENSIVE: downloads ALL logs as ZIP. Consider using get_job_logs with failed_only=true for debugging failed jobs)"), @@ -505,6 +510,7 @@ func GetWorkflowRunLogs(t translations.TranslationHelperFunc) toolsets.ServerToo // ListWorkflowJobs creates a tool to list jobs for a specific workflow run func ListWorkflowJobs(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "list_workflow_jobs", Description: t("TOOL_LIST_WORKFLOW_JOBS_DESCRIPTION", "List jobs for a specific workflow run"), @@ -604,6 +610,7 @@ func ListWorkflowJobs(t translations.TranslationHelperFunc) toolsets.ServerTool // GetJobLogs creates a tool to download logs for a specific workflow job or efficiently get all failed job logs for a workflow run func GetJobLogs(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "get_job_logs", Description: t("TOOL_GET_JOB_LOGS_DESCRIPTION", "Download logs for a specific workflow job or efficiently get all failed job logs for a workflow run"), @@ -868,6 +875,7 @@ func downloadLogContent(ctx context.Context, logURL string, tailLines int, maxLi // RerunWorkflowRun creates a tool to re-run an entire workflow run func RerunWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "rerun_workflow_run", Description: t("TOOL_RERUN_WORKFLOW_RUN_DESCRIPTION", "Re-run an entire workflow run"), @@ -942,6 +950,7 @@ func RerunWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool // RerunFailedJobs creates a tool to re-run only the failed jobs in a workflow run func RerunFailedJobs(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "rerun_failed_jobs", Description: t("TOOL_RERUN_FAILED_JOBS_DESCRIPTION", "Re-run only the failed jobs in a workflow run"), @@ -1016,6 +1025,7 @@ func RerunFailedJobs(t translations.TranslationHelperFunc) toolsets.ServerTool { // CancelWorkflowRun creates a tool to cancel a workflow run func CancelWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "cancel_workflow_run", Description: t("TOOL_CANCEL_WORKFLOW_RUN_DESCRIPTION", "Cancel a workflow run"), @@ -1092,6 +1102,7 @@ func CancelWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool // ListWorkflowRunArtifacts creates a tool to list artifacts for a workflow run func ListWorkflowRunArtifacts(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "list_workflow_run_artifacts", Description: t("TOOL_LIST_WORKFLOW_RUN_ARTIFACTS_DESCRIPTION", "List artifacts for a workflow run"), @@ -1171,6 +1182,7 @@ func ListWorkflowRunArtifacts(t translations.TranslationHelperFunc) toolsets.Ser // DownloadWorkflowRunArtifact creates a tool to download a workflow run artifact func DownloadWorkflowRunArtifact(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "download_workflow_run_artifact", Description: t("TOOL_DOWNLOAD_WORKFLOW_RUN_ARTIFACT_DESCRIPTION", "Get download URL for a workflow run artifact"), @@ -1247,6 +1259,7 @@ func DownloadWorkflowRunArtifact(t translations.TranslationHelperFunc) toolsets. // DeleteWorkflowRunLogs creates a tool to delete logs for a workflow run func DeleteWorkflowRunLogs(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "delete_workflow_run_logs", Description: t("TOOL_DELETE_WORKFLOW_RUN_LOGS_DESCRIPTION", "Delete logs for a workflow run"), @@ -1322,6 +1335,7 @@ func DeleteWorkflowRunLogs(t translations.TranslationHelperFunc) toolsets.Server // GetWorkflowRunUsage creates a tool to get usage metrics for a workflow run func GetWorkflowRunUsage(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataActions, mcp.Tool{ Name: "get_workflow_run_usage", Description: t("TOOL_GET_WORKFLOW_RUN_USAGE_DESCRIPTION", "Get usage metrics for a workflow run"), diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index 518855a59..888ad4fd2 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -18,6 +18,7 @@ import ( func GetCodeScanningAlert(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataCodeSecurity, mcp.Tool{ Name: "get_code_scanning_alert", Description: t("TOOL_GET_CODE_SCANNING_ALERT_DESCRIPTION", "Get details of a specific code scanning alert in a GitHub repository."), @@ -95,6 +96,7 @@ func GetCodeScanningAlert(t translations.TranslationHelperFunc) toolsets.ServerT func ListCodeScanningAlerts(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataCodeSecurity, mcp.Tool{ Name: "list_code_scanning_alerts", Description: t("TOOL_LIST_CODE_SCANNING_ALERTS_DESCRIPTION", "List code scanning alerts in a GitHub repository."), diff --git a/pkg/github/context_tools.go b/pkg/github/context_tools.go index e8043731a..d5e0cfee9 100644 --- a/pkg/github/context_tools.go +++ b/pkg/github/context_tools.go @@ -39,6 +39,7 @@ type UserDetails struct { // GetMe creates a tool to get details of the authenticated user. func GetMe(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataContext, mcp.Tool{ Name: "get_me", Description: t("TOOL_GET_ME_DESCRIPTION", "Get details of the authenticated GitHub user. Use this when a request is about the user's own profile for GitHub. Or when information is missing to build other tool calls."), @@ -112,6 +113,7 @@ type OrganizationTeams struct { func GetTeams(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataContext, mcp.Tool{ Name: "get_teams", Description: t("TOOL_GET_TEAMS_DESCRIPTION", "Get details of the teams the user is a member of. Limited to organizations accessible with current credentials"), @@ -210,6 +212,7 @@ func GetTeams(t translations.TranslationHelperFunc) toolsets.ServerTool { func GetTeamMembers(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataContext, mcp.Tool{ Name: "get_team_members", Description: t("TOOL_GET_TEAM_MEMBERS_DESCRIPTION", "Get member usernames of a specific team in an organization. Limited to organizations accessible with current credentials"), diff --git a/pkg/github/dependabot.go b/pkg/github/dependabot.go index b80fd0aa0..1508d1382 100644 --- a/pkg/github/dependabot.go +++ b/pkg/github/dependabot.go @@ -18,6 +18,7 @@ import ( func GetDependabotAlert(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataDependabot, mcp.Tool{ Name: "get_dependabot_alert", Description: t("TOOL_GET_DEPENDABOT_ALERT_DESCRIPTION", "Get details of a specific dependabot alert in a GitHub repository."), @@ -95,6 +96,7 @@ func GetDependabotAlert(t translations.TranslationHelperFunc) toolsets.ServerToo func ListDependabotAlerts(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataDependabot, mcp.Tool{ Name: "list_dependabot_alerts", Description: t("TOOL_LIST_DEPENDABOT_ALERTS_DESCRIPTION", "List dependabot alerts in a GitHub repository."), diff --git a/pkg/github/dependencies.go b/pkg/github/dependencies.go index 3124f6bd0..7dcc33f75 100644 --- a/pkg/github/dependencies.go +++ b/pkg/github/dependencies.go @@ -35,19 +35,19 @@ type ToolDependencies struct { ContentWindowSize int } -// NewTool creates a ServerTool with fully-typed ToolDependencies. +// NewTool creates a ServerTool with fully-typed ToolDependencies and toolset metadata. // This helper isolates the type assertion from `any` to `ToolDependencies`, // so tool implementations remain fully typed without assertions scattered throughout. -func NewTool[In, Out any](tool mcp.Tool, handler func(deps ToolDependencies) mcp.ToolHandlerFor[In, Out]) toolsets.ServerTool { - return toolsets.NewServerTool(tool, func(d any) mcp.ToolHandlerFor[In, Out] { +func NewTool[In, Out any](toolset toolsets.ToolsetMetadata, tool mcp.Tool, handler func(deps ToolDependencies) mcp.ToolHandlerFor[In, Out]) toolsets.ServerTool { + return toolsets.NewServerTool(tool, toolset, func(d any) mcp.ToolHandlerFor[In, Out] { return handler(d.(ToolDependencies)) }) } -// NewToolFromHandler creates a ServerTool with fully-typed ToolDependencies +// NewToolFromHandler creates a ServerTool with fully-typed ToolDependencies and toolset metadata // for handlers that conform to mcp.ToolHandler directly. -func NewToolFromHandler(tool mcp.Tool, handler func(deps ToolDependencies) mcp.ToolHandler) toolsets.ServerTool { - return toolsets.NewServerToolFromHandler(tool, func(d any) mcp.ToolHandler { +func NewToolFromHandler(toolset toolsets.ToolsetMetadata, tool mcp.Tool, handler func(deps ToolDependencies) mcp.ToolHandler) toolsets.ServerTool { + return toolsets.NewServerToolFromHandler(tool, toolset, func(d any) mcp.ToolHandler { return handler(d.(ToolDependencies)) }) } diff --git a/pkg/github/discussions.go b/pkg/github/discussions.go index 94f7f6f1b..5bbdb2b5f 100644 --- a/pkg/github/discussions.go +++ b/pkg/github/discussions.go @@ -124,6 +124,7 @@ func getQueryType(useOrdering bool, categoryID *githubv4.ID) any { func ListDiscussions(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataDiscussions, mcp.Tool{ Name: "list_discussions", Description: t("TOOL_LIST_DISCUSSIONS_DESCRIPTION", "List discussions for a repository or organisation."), @@ -277,6 +278,7 @@ func ListDiscussions(t translations.TranslationHelperFunc) toolsets.ServerTool { func GetDiscussion(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataDiscussions, mcp.Tool{ Name: "get_discussion", Description: t("TOOL_GET_DISCUSSION_DESCRIPTION", "Get a specific discussion by ID"), @@ -381,6 +383,7 @@ func GetDiscussion(t translations.TranslationHelperFunc) toolsets.ServerTool { func GetDiscussionComments(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataDiscussions, mcp.Tool{ Name: "get_discussion_comments", Description: t("TOOL_GET_DISCUSSION_COMMENTS_DESCRIPTION", "Get comments from a discussion"), @@ -508,6 +511,7 @@ func GetDiscussionComments(t translations.TranslationHelperFunc) toolsets.Server func ListDiscussionCategories(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataDiscussions, mcp.Tool{ Name: "list_discussion_categories", Description: t("TOOL_LIST_DISCUSSION_CATEGORIES_DESCRIPTION", "List discussion categories with their id and name, for a repository or organisation."), diff --git a/pkg/github/dynamic_tools.go b/pkg/github/dynamic_tools.go index 9118c1c45..cc44e85f5 100644 --- a/pkg/github/dynamic_tools.go +++ b/pkg/github/dynamic_tools.go @@ -12,147 +12,213 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) +// DynamicToolDependencies contains dependencies for dynamic toolset management tools. +// It includes the managed ToolsetGroup, the server for registration, and the deps +// that will be passed to tools when they are dynamically enabled. +type DynamicToolDependencies struct { + // Server is the MCP server to register tools with + Server *mcp.Server + // ToolsetGroup contains all available tools that can be enabled dynamically + ToolsetGroup *toolsets.ToolsetGroup + // ToolDeps are the dependencies passed to tools when they are registered + ToolDeps any + // T is the translation helper function + T translations.TranslationHelperFunc +} + +// NewDynamicTool creates a ServerTool with fully-typed DynamicToolDependencies. +func NewDynamicTool(toolset toolsets.ToolsetMetadata, tool mcp.Tool, handler func(deps DynamicToolDependencies) mcp.ToolHandlerFor[map[string]any, any]) toolsets.ServerTool { + return toolsets.NewServerTool(tool, toolset, func(d any) mcp.ToolHandlerFor[map[string]any, any] { + return handler(d.(DynamicToolDependencies)) + }) +} + +// AllToolsetIDsEnum returns all available toolset IDs as an enum for JSON Schema. +func AllToolsetIDsEnum() []any { + toolsets := AvailableToolsets() + result := make([]any, len(toolsets)) + for i, ts := range toolsets { + result[i] = ts.ID + } + return result +} + +// ToolsetEnum returns the list of toolset IDs as an enum for JSON Schema. +// Deprecated: Use AllToolsetIDsEnum() instead. func ToolsetEnum(toolsetGroup *toolsets.ToolsetGroup) []any { - toolsetNames := make([]any, 0, len(toolsetGroup.Toolsets)) - for name := range toolsetGroup.Toolsets { - toolsetNames = append(toolsetNames, name) + toolsetIDs := toolsetGroup.ToolsetIDs() + result := make([]any, len(toolsetIDs)) + for i, id := range toolsetIDs { + result[i] = id } - return toolsetNames + return result } -func EnableToolset(s *mcp.Server, toolsetGroup *toolsets.ToolsetGroup, t translations.TranslationHelperFunc) toolsets.ServerTool { - return toolsets.NewServerToolLegacy(mcp.Tool{ - Name: "enable_toolset", - Description: t("TOOL_ENABLE_TOOLSET_DESCRIPTION", "Enable one of the sets of tools the GitHub MCP server provides, use get_toolset_tools and list_available_toolsets first to see what this will enable"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_ENABLE_TOOLSET_USER_TITLE", "Enable a toolset"), - // Not modifying GitHub data so no need to show a warning - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "toolset": { - Type: "string", - Description: "The name of the toolset to enable", - Enum: ToolsetEnum(toolsetGroup), +// DynamicTools returns the tools for dynamic toolset management. +// These tools allow runtime discovery and enablement of toolsets. +func DynamicTools() []toolsets.ServerTool { + return []toolsets.ServerTool{ + ListAvailableToolsets(), + GetToolsetsTools(), + EnableToolset(), + } +} + +// EnableToolset creates a tool that enables a toolset at runtime. +func EnableToolset() toolsets.ServerTool { + return NewDynamicTool( + ToolsetMetadataDynamic, + mcp.Tool{ + Name: "enable_toolset", + Description: "Enable one of the sets of tools the GitHub MCP server provides, use get_toolset_tools and list_available_toolsets first to see what this will enable", + Annotations: &mcp.ToolAnnotations{ + Title: "Enable a toolset", + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "toolset": { + Type: "string", + Description: "The name of the toolset to enable", + Enum: AllToolsetIDsEnum(), + }, }, + Required: []string{"toolset"}, }, - Required: []string{"toolset"}, }, - }, - mcp.ToolHandlerFor[map[string]any, any](func(_ context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - // We need to convert the toolsets back to a map for JSON serialization - toolsetName, err := RequiredParam[string](args, "toolset") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - toolset := toolsetGroup.Toolsets[toolsetName] - if toolset == nil { - return utils.NewToolResultError(fmt.Sprintf("Toolset %s not found", toolsetName)), nil, nil - } - if toolset.Enabled { - return utils.NewToolResultText(fmt.Sprintf("Toolset %s is already enabled", toolsetName)), nil, nil - } + func(deps DynamicToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(_ context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + toolsetName, err := RequiredParam[string](args, "toolset") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - toolset.Enabled = true + toolsetID := toolsets.ToolsetID(toolsetName) - // caution: this currently affects the global tools and notifies all clients: - // - // Send notification to all initialized sessions - // s.sendNotificationToAllClients("notifications/tools/list_changed", nil) - toolset.RegisterTools(s) + if !deps.ToolsetGroup.HasToolset(toolsetID) { + return utils.NewToolResultError(fmt.Sprintf("Toolset %s not found", toolsetName)), nil, nil + } - return utils.NewToolResultText(fmt.Sprintf("Toolset %s enabled", toolsetName)), nil, nil - })) -} + if deps.ToolsetGroup.IsToolsetEnabled(toolsetID) { + return utils.NewToolResultText(fmt.Sprintf("Toolset %s is already enabled", toolsetName)), nil, nil + } -func ListAvailableToolsets(toolsetGroup *toolsets.ToolsetGroup, t translations.TranslationHelperFunc) toolsets.ServerTool { - return toolsets.NewServerToolLegacy(mcp.Tool{ - Name: "list_available_toolsets", - Description: t("TOOL_LIST_AVAILABLE_TOOLSETS_DESCRIPTION", "List all available toolsets this GitHub MCP server can offer, providing the enabled status of each. Use this when a task could be achieved with a GitHub tool and the currently available tools aren't enough. Call get_toolset_tools with these toolset names to discover specific tools you can call"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LIST_AVAILABLE_TOOLSETS_USER_TITLE", "List available toolsets"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{}, + // Mark the toolset as enabled so IsToolsetEnabled returns true + deps.ToolsetGroup.EnableToolset(toolsetID) + + // Get tools for this toolset and register them with the managed deps + toolsForToolset := deps.ToolsetGroup.ToolsForToolset(toolsetID) + for _, st := range toolsForToolset { + st.RegisterFunc(deps.Server, deps.ToolDeps) + } + + return utils.NewToolResultText(fmt.Sprintf("Toolset %s enabled with %d tools", toolsetName, len(toolsForToolset))), nil, nil + } }, - }, - mcp.ToolHandlerFor[map[string]any, any](func(_ context.Context, _ *mcp.CallToolRequest, _ map[string]any) (*mcp.CallToolResult, any, error) { - // We need to convert the toolsetGroup back to a map for JSON serialization + ) +} - payload := []map[string]string{} +// ListAvailableToolsets creates a tool that lists all available toolsets. +func ListAvailableToolsets() toolsets.ServerTool { + return NewDynamicTool( + ToolsetMetadataDynamic, + mcp.Tool{ + Name: "list_available_toolsets", + Description: "List all available toolsets this GitHub MCP server can offer, providing the enabled status of each. Use this when a task could be achieved with a GitHub tool and the currently available tools aren't enough. Call get_toolset_tools with these toolset names to discover specific tools you can call", + Annotations: &mcp.ToolAnnotations{ + Title: "List available toolsets", + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{}, + }, + }, + func(deps DynamicToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(_ context.Context, _ *mcp.CallToolRequest, _ map[string]any) (*mcp.CallToolResult, any, error) { + toolsetIDs := deps.ToolsetGroup.ToolsetIDs() + descriptions := deps.ToolsetGroup.ToolsetDescriptions() - for name, ts := range toolsetGroup.Toolsets { - { + payload := make([]map[string]string, 0, len(toolsetIDs)) + for _, id := range toolsetIDs { t := map[string]string{ - "name": name, - "description": ts.Description, + "name": string(id), + "description": descriptions[id], "can_enable": "true", - "currently_enabled": fmt.Sprintf("%t", ts.Enabled), + "currently_enabled": fmt.Sprintf("%t", deps.ToolsetGroup.IsToolsetEnabled(id)), } payload = append(payload, t) } - } - r, err := json.Marshal(payload) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal features: %w", err) - } + r, err := json.Marshal(payload) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal features: %w", err) + } - return utils.NewToolResultText(string(r)), nil, nil - })) + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } -func GetToolsetsTools(toolsetGroup *toolsets.ToolsetGroup, t translations.TranslationHelperFunc) toolsets.ServerTool { - return toolsets.NewServerToolLegacy(mcp.Tool{ - Name: "get_toolset_tools", - Description: t("TOOL_GET_TOOLSET_TOOLS_DESCRIPTION", "Lists all the capabilities that are enabled with the specified toolset, use this to get clarity on whether enabling a toolset would help you to complete a task"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_GET_TOOLSET_TOOLS_USER_TITLE", "List all tools in a toolset"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "toolset": { - Type: "string", - Description: "The name of the toolset you want to get the tools for", - Enum: ToolsetEnum(toolsetGroup), +// GetToolsetsTools creates a tool that lists all tools in a specific toolset. +func GetToolsetsTools() toolsets.ServerTool { + return NewDynamicTool( + ToolsetMetadataDynamic, + mcp.Tool{ + Name: "get_toolset_tools", + Description: "Lists all the capabilities that are enabled with the specified toolset, use this to get clarity on whether enabling a toolset would help you to complete a task", + Annotations: &mcp.ToolAnnotations{ + Title: "List all tools in a toolset", + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "toolset": { + Type: "string", + Description: "The name of the toolset you want to get the tools for", + Enum: AllToolsetIDsEnum(), + }, }, + Required: []string{"toolset"}, }, - Required: []string{"toolset"}, }, - }, - mcp.ToolHandlerFor[map[string]any, any](func(_ context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - // We need to convert the toolsetGroup back to a map for JSON serialization - toolsetName, err := RequiredParam[string](args, "toolset") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - toolset := toolsetGroup.Toolsets[toolsetName] - if toolset == nil { - return utils.NewToolResultError(fmt.Sprintf("Toolset %s not found", toolsetName)), nil, nil - } - payload := []map[string]string{} - - for _, st := range toolset.GetAvailableTools() { - tool := map[string]string{ - "name": st.Tool.Name, - "description": st.Tool.Description, - "can_enable": "true", - "toolset": toolsetName, + func(deps DynamicToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(_ context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + toolsetName, err := RequiredParam[string](args, "toolset") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil } - payload = append(payload, tool) - } - r, err := json.Marshal(payload) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal features: %w", err) - } + toolsetID := toolsets.ToolsetID(toolsetName) + + if !deps.ToolsetGroup.HasToolset(toolsetID) { + return utils.NewToolResultError(fmt.Sprintf("Toolset %s not found", toolsetName)), nil, nil + } + + // Get all tools for this toolset (ignoring current filters for discovery) + toolsInToolset := deps.ToolsetGroup.ToolsForToolset(toolsetID) + payload := make([]map[string]string, 0, len(toolsInToolset)) + + for _, st := range toolsInToolset { + tool := map[string]string{ + "name": st.Tool.Name, + "description": st.Tool.Description, + "can_enable": "true", + "toolset": toolsetName, + } + payload = append(payload, tool) + } + + r, err := json.Marshal(payload) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal features: %w", err) + } - return utils.NewToolResultText(string(r)), nil, nil - })) + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } diff --git a/pkg/github/gists.go b/pkg/github/gists.go index baca42399..03e5e1bc8 100644 --- a/pkg/github/gists.go +++ b/pkg/github/gists.go @@ -18,6 +18,7 @@ import ( // ListGists creates a tool to list gists for a user func ListGists(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataGists, mcp.Tool{ Name: "list_gists", Description: t("TOOL_LIST_GISTS_DESCRIPTION", "List gists for a user"), @@ -105,6 +106,7 @@ func ListGists(t translations.TranslationHelperFunc) toolsets.ServerTool { // GetGist creates a tool to get the content of a gist func GetGist(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataGists, mcp.Tool{ Name: "get_gist", Description: t("TOOL_GET_GIST_DESCRIPTION", "Get gist content of a particular gist, by gist ID"), @@ -163,6 +165,7 @@ func GetGist(t translations.TranslationHelperFunc) toolsets.ServerTool { // CreateGist creates a tool to create a new gist func CreateGist(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataGists, mcp.Tool{ Name: "create_gist", Description: t("TOOL_CREATE_GIST_DESCRIPTION", "Create a new gist"), @@ -266,6 +269,7 @@ func CreateGist(t translations.TranslationHelperFunc) toolsets.ServerTool { // UpdateGist creates a tool to edit an existing gist func UpdateGist(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataGists, mcp.Tool{ Name: "update_gist", Description: t("TOOL_UPDATE_GIST_DESCRIPTION", "Update an existing gist"), diff --git a/pkg/github/git.go b/pkg/github/git.go index c5fdded7c..e619afc34 100644 --- a/pkg/github/git.go +++ b/pkg/github/git.go @@ -40,6 +40,7 @@ type TreeResponse struct { // GetRepositoryTree creates a tool to get the tree structure of a GitHub repository. func GetRepositoryTree(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataGit, mcp.Tool{ Name: "get_repository_tree", Description: t("TOOL_GET_REPOSITORY_TREE_DESCRIPTION", "Get the tree structure (files and directories) of a GitHub repository at a specific ref or SHA"), diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 142bdd421..1d0e3b2d5 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -263,6 +263,7 @@ Options are: WithPagination(schema) return NewTool( + ToolsetMetadataIssues, mcp.Tool{ Name: "issue_read", Description: t("TOOL_ISSUE_READ_DESCRIPTION", "Get information about a specific issue in a GitHub repository."), @@ -546,6 +547,7 @@ func GetIssueLabels(ctx context.Context, client *githubv4.Client, owner string, // ListIssueTypes creates a tool to list defined issue types for an organization. This can be used to understand supported issue type values for creating or updating issues. func ListIssueTypes(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataIssues, mcp.Tool{ Name: "list_issue_types", Description: t("TOOL_LIST_ISSUE_TYPES_FOR_ORG", "List supported issue types for repository owner (organization)."), @@ -602,6 +604,7 @@ func ListIssueTypes(t translations.TranslationHelperFunc) toolsets.ServerTool { // AddIssueComment creates a tool to add a comment to an issue. func AddIssueComment(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataIssues, mcp.Tool{ Name: "add_issue_comment", Description: t("TOOL_ADD_ISSUE_COMMENT_DESCRIPTION", "Add a comment to a specific issue in a GitHub repository. Use this tool to add comments to pull requests as well (in this case pass pull request number as issue_number), but only if user is not asking specifically to add review comments."), @@ -686,6 +689,7 @@ func AddIssueComment(t translations.TranslationHelperFunc) toolsets.ServerTool { // SubIssueWrite creates a tool to add a sub-issue to a parent issue. func SubIssueWrite(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataIssues, mcp.Tool{ Name: "sub_issue_write", Description: t("TOOL_SUB_ISSUE_WRITE_DESCRIPTION", "Add a sub-issue to a parent issue in a GitHub repository."), @@ -956,6 +960,7 @@ func SearchIssues(t translations.TranslationHelperFunc) toolsets.ServerTool { WithPagination(schema) return NewTool( + ToolsetMetadataIssues, mcp.Tool{ Name: "search_issues", Description: t("TOOL_SEARCH_ISSUES_DESCRIPTION", "Search for issues in GitHub repositories using issues search syntax already scoped to is:issue"), @@ -976,6 +981,7 @@ func SearchIssues(t translations.TranslationHelperFunc) toolsets.ServerTool { // IssueWrite creates a tool to create a new or update an existing issue in a GitHub repository. func IssueWrite(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataIssues, mcp.Tool{ Name: "issue_write", Description: t("TOOL_ISSUE_WRITE_DESCRIPTION", "Create a new or update an existing issue in a GitHub repository."), @@ -1376,6 +1382,7 @@ func ListIssues(t translations.TranslationHelperFunc) toolsets.ServerTool { WithCursorPagination(schema) return NewTool( + ToolsetMetadataIssues, mcp.Tool{ Name: "list_issues", Description: t("TOOL_LIST_ISSUES_DESCRIPTION", "List issues in a GitHub repository. For pagination, use the 'endCursor' from the previous response's 'pageInfo' in the 'after' parameter."), @@ -1611,6 +1618,7 @@ func AssignCopilotToIssue(t translations.TranslationHelperFunc) toolsets.ServerT } return NewTool( + ToolsetMetadataIssues, mcp.Tool{ Name: "assign_copilot_to_issue", Description: t("TOOL_ASSIGN_COPILOT_TO_ISSUE_DESCRIPTION", description.String()), @@ -1797,8 +1805,10 @@ func parseISOTimestamp(timestamp string) (time.Time, error) { return time.Time{}, fmt.Errorf("invalid ISO 8601 timestamp: %s (supported formats: YYYY-MM-DDThh:mm:ssZ or YYYY-MM-DD)", timestamp) } -func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) (mcp.Prompt, mcp.PromptHandler) { - return mcp.Prompt{ +func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) toolsets.ServerPrompt { + return toolsets.NewServerPrompt( + ToolsetMetadataIssues, + mcp.Prompt{ Name: "AssignCodingAgent", Description: t("PROMPT_ASSIGN_CODING_AGENT_DESCRIPTION", "Assign GitHub Coding Agent to multiple tasks in a GitHub repository."), Arguments: []*mcp.PromptArgument{ @@ -1808,7 +1818,8 @@ func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) (mcp.Prompt, Required: true, }, }, - }, func(_ context.Context, request *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + }, + func(_ context.Context, request *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { repo := request.Params.Arguments["repo"] messages := []*mcp.PromptMessage{ @@ -1852,5 +1863,6 @@ func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) (mcp.Prompt, return &mcp.GetPromptResult{ Messages: messages, }, nil - } + }, + ) } diff --git a/pkg/github/labels.go b/pkg/github/labels.go index 0c18c83e4..a98468fae 100644 --- a/pkg/github/labels.go +++ b/pkg/github/labels.go @@ -18,6 +18,7 @@ import ( // GetLabel retrieves a specific label by name from a GitHub repository func GetLabel(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetLabels, mcp.Tool{ Name: "get_label", Description: t("TOOL_GET_LABEL_DESCRIPTION", "Get a specific label from a repository."), @@ -112,6 +113,7 @@ func GetLabel(t translations.TranslationHelperFunc) toolsets.ServerTool { // ListLabels lists labels from a repository func ListLabels(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetLabels, mcp.Tool{ Name: "list_label", Description: t("TOOL_LIST_LABEL_DESCRIPTION", "List labels from a repository"), @@ -203,6 +205,7 @@ func ListLabels(t translations.TranslationHelperFunc) toolsets.ServerTool { // LabelWrite handles create, update, and delete operations for GitHub labels func LabelWrite(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetLabels, mcp.Tool{ Name: "label_write", Description: t("TOOL_LABEL_WRITE_DESCRIPTION", "Perform write operations on repository labels. To set labels on issues, use the 'update_issue' tool."), diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go index 23d63c946..4eb2d7b5b 100644 --- a/pkg/github/notifications.go +++ b/pkg/github/notifications.go @@ -27,6 +27,7 @@ const ( // ListNotifications creates a tool to list notifications for the current user. func ListNotifications(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataNotifications, mcp.Tool{ Name: "list_notifications", Description: t("TOOL_LIST_NOTIFICATIONS_DESCRIPTION", "Lists all GitHub notifications for the authenticated user, including unread notifications, mentions, review requests, assignments, and updates on issues or pull requests. Use this tool whenever the user asks what to work on next, requests a summary of their GitHub activity, wants to see pending reviews, or needs to check for new updates or tasks. This tool is the primary way to discover actionable items, reminders, and outstanding work on GitHub. Always call this tool when asked what to work on next, what is pending, or what needs attention in GitHub."), @@ -164,6 +165,7 @@ func ListNotifications(t translations.TranslationHelperFunc) toolsets.ServerTool // DismissNotification creates a tool to mark a notification as read/done. func DismissNotification(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataNotifications, mcp.Tool{ Name: "dismiss_notification", Description: t("TOOL_DISMISS_NOTIFICATION_DESCRIPTION", "Dismiss a notification by marking it as read or done"), @@ -246,6 +248,7 @@ func DismissNotification(t translations.TranslationHelperFunc) toolsets.ServerTo // MarkAllNotificationsRead creates a tool to mark all notifications as read. func MarkAllNotificationsRead(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataNotifications, mcp.Tool{ Name: "mark_all_notifications_read", Description: t("TOOL_MARK_ALL_NOTIFICATIONS_READ_DESCRIPTION", "Mark all notifications as read"), @@ -338,6 +341,7 @@ func MarkAllNotificationsRead(t translations.TranslationHelperFunc) toolsets.Ser // GetNotificationDetails creates a tool to get details for a specific notification. func GetNotificationDetails(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataNotifications, mcp.Tool{ Name: "get_notification_details", Description: t("TOOL_GET_NOTIFICATION_DETAILS_DESCRIPTION", "Get detailed information for a specific GitHub notification, always call this tool when the user asks for details about a specific notification, if you don't know the ID list notifications first."), @@ -407,6 +411,7 @@ const ( // ManageNotificationSubscription creates a tool to manage a notification subscription (ignore, watch, delete) func ManageNotificationSubscription(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataNotifications, mcp.Tool{ Name: "manage_notification_subscription", Description: t("TOOL_MANAGE_NOTIFICATION_SUBSCRIPTION_DESCRIPTION", "Manage a notification subscription: ignore, watch, or delete a notification thread subscription."), @@ -503,6 +508,7 @@ const ( // ManageRepositoryNotificationSubscription creates a tool to manage a repository notification subscription (ignore, watch, delete) func ManageRepositoryNotificationSubscription(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataNotifications, mcp.Tool{ Name: "manage_repository_notification_subscription", Description: t("TOOL_MANAGE_REPOSITORY_NOTIFICATION_SUBSCRIPTION_DESCRIPTION", "Manage a repository notification subscription: ignore, watch, or delete repository notifications subscription for the provided repository."), diff --git a/pkg/github/projects.go b/pkg/github/projects.go index ca26e2550..a12aca7be 100644 --- a/pkg/github/projects.go +++ b/pkg/github/projects.go @@ -27,6 +27,7 @@ const ( func ListProjects(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataProjects, mcp.Tool{ Name: "list_projects", Description: t("TOOL_LIST_PROJECTS_DESCRIPTION", `List Projects for a user or organization`), @@ -145,6 +146,7 @@ func ListProjects(t translations.TranslationHelperFunc) toolsets.ServerTool { func GetProject(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataProjects, mcp.Tool{ Name: "get_project", Description: t("TOOL_GET_PROJECT_DESCRIPTION", "Get Project for a user or org"), @@ -234,6 +236,7 @@ func GetProject(t translations.TranslationHelperFunc) toolsets.ServerTool { func ListProjectFields(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataProjects, mcp.Tool{ Name: "list_project_fields", Description: t("TOOL_LIST_PROJECT_FIELDS_DESCRIPTION", "List Project fields for a user or org"), @@ -341,6 +344,7 @@ func ListProjectFields(t translations.TranslationHelperFunc) toolsets.ServerTool func GetProjectField(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataProjects, mcp.Tool{ Name: "get_project_field", Description: t("TOOL_GET_PROJECT_FIELD_DESCRIPTION", "Get Project field for a user or org"), @@ -434,6 +438,7 @@ func GetProjectField(t translations.TranslationHelperFunc) toolsets.ServerTool { func ListProjectItems(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataProjects, mcp.Tool{ Name: "list_project_items", Description: t("TOOL_LIST_PROJECT_ITEMS_DESCRIPTION", `Search project items with advanced filtering`), @@ -571,6 +576,7 @@ func ListProjectItems(t translations.TranslationHelperFunc) toolsets.ServerTool func GetProjectItem(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataProjects, mcp.Tool{ Name: "get_project_item", Description: t("TOOL_GET_PROJECT_ITEM_DESCRIPTION", "Get a specific Project item for a user or org"), @@ -678,6 +684,7 @@ func GetProjectItem(t translations.TranslationHelperFunc) toolsets.ServerTool { func AddProjectItem(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataProjects, mcp.Tool{ Name: "add_project_item", Description: t("TOOL_ADD_PROJECT_ITEM_DESCRIPTION", "Add a specific Project item for a user or org"), @@ -790,6 +797,7 @@ func AddProjectItem(t translations.TranslationHelperFunc) toolsets.ServerTool { func UpdateProjectItem(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataProjects, mcp.Tool{ Name: "update_project_item", Description: t("TOOL_UPDATE_PROJECT_ITEM_DESCRIPTION", "Update a specific Project item for a user or org"), @@ -903,6 +911,7 @@ func UpdateProjectItem(t translations.TranslationHelperFunc) toolsets.ServerTool func DeleteProjectItem(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataProjects, mcp.Tool{ Name: "delete_project_item", Description: t("TOOL_DELETE_PROJECT_ITEM_DESCRIPTION", "Delete a specific Project item for a user or org"), diff --git a/pkg/github/prompts.go b/pkg/github/prompts.go new file mode 100644 index 000000000..82d7bf514 --- /dev/null +++ b/pkg/github/prompts.go @@ -0,0 +1,16 @@ +package github + +import ( + "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/translations" +) + +// AllPrompts returns all prompts with their embedded toolset metadata. +// Prompt functions return ServerPrompt directly with toolset info. +func AllPrompts(t translations.TranslationHelperFunc) []toolsets.ServerPrompt { + return []toolsets.ServerPrompt{ + // Issue prompts + AssignCodingAgentPrompt(t), + IssueToFixWorkflowPrompt(t), + } +} diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index bfe870775..229e20e57 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -58,6 +58,7 @@ Possible options: WithPagination(schema) return NewTool( + ToolsetMetadataPullRequests, mcp.Tool{ Name: "pull_request_read", Description: t("TOOL_PULL_REQUEST_READ_DESCRIPTION", "Get information on a specific pull request in GitHub repository."), @@ -430,6 +431,7 @@ func CreatePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool } return NewTool( + ToolsetMetadataPullRequests, mcp.Tool{ Name: "create_pull_request", Description: t("TOOL_CREATE_PULL_REQUEST_DESCRIPTION", "Create a new pull request in a GitHub repository."), @@ -582,6 +584,7 @@ func UpdatePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool } return NewTool( + ToolsetMetadataPullRequests, mcp.Tool{ Name: "update_pull_request", Description: t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository."), @@ -864,6 +867,7 @@ func ListPullRequests(t translations.TranslationHelperFunc) toolsets.ServerTool WithPagination(schema) return NewTool( + ToolsetMetadataPullRequests, mcp.Tool{ Name: "list_pull_requests", Description: t("TOOL_LIST_PULL_REQUESTS_DESCRIPTION", "List pull requests in a GitHub repository. If the user specifies an author, then DO NOT use this tool and use the search_pull_requests tool instead."), @@ -1000,6 +1004,7 @@ func MergePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool } return NewTool( + ToolsetMetadataPullRequests, mcp.Tool{ Name: "merge_pull_request", Description: t("TOOL_MERGE_PULL_REQUEST_DESCRIPTION", "Merge a pull request in a GitHub repository."), @@ -1118,6 +1123,7 @@ func SearchPullRequests(t translations.TranslationHelperFunc) toolsets.ServerToo WithPagination(schema) return NewTool( + ToolsetMetadataPullRequests, mcp.Tool{ Name: "search_pull_requests", Description: t("TOOL_SEARCH_PULL_REQUESTS_DESCRIPTION", "Search for pull requests in GitHub repositories using issues search syntax already scoped to is:pr"), @@ -1161,6 +1167,7 @@ func UpdatePullRequestBranch(t translations.TranslationHelperFunc) toolsets.Serv } return NewTool( + ToolsetMetadataPullRequests, mcp.Tool{ Name: "update_pull_request_branch", Description: t("TOOL_UPDATE_PULL_REQUEST_BRANCH_DESCRIPTION", "Update the branch of a pull request with the latest changes from the base branch."), @@ -1282,6 +1289,7 @@ func PullRequestReviewWrite(t translations.TranslationHelperFunc) toolsets.Serve } return NewTool( + ToolsetMetadataPullRequests, mcp.Tool{ Name: "pull_request_review_write", Description: t("TOOL_PULL_REQUEST_REVIEW_WRITE_DESCRIPTION", `Create and/or submit, delete review of a pull request. @@ -1612,6 +1620,7 @@ func AddCommentToPendingReview(t translations.TranslationHelperFunc) toolsets.Se } return NewTool( + ToolsetMetadataPullRequests, mcp.Tool{ Name: "add_comment_to_pending_review", Description: t("TOOL_ADD_COMMENT_TO_PENDING_REVIEW_DESCRIPTION", "Add review comment to the requester's latest pending pull request review. A pending review needs to already exist to call this (check with the user if not sure)."), @@ -1764,6 +1773,7 @@ func RequestCopilotReview(t translations.TranslationHelperFunc) toolsets.ServerT } return NewTool( + ToolsetMetadataPullRequests, mcp.Tool{ Name: "request_copilot_review", Description: t("TOOL_REQUEST_COPILOT_REVIEW_DESCRIPTION", "Request a GitHub Copilot code review for a pull request. Use this for automated feedback on pull requests, usually before requesting a human reviewer."), diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index de5aaea5e..81e5c3a8c 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -21,6 +21,7 @@ import ( func GetCommit(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "get_commit", Description: t("TOOL_GET_COMMITS_DESCRIPTION", "Get details for a commit from a GitHub repository"), @@ -119,6 +120,7 @@ func GetCommit(t translations.TranslationHelperFunc) toolsets.ServerTool { // ListCommits creates a tool to get commits of a branch in a repository. func ListCommits(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "list_commits", Description: t("TOOL_LIST_COMMITS_DESCRIPTION", "Get list of commits of a branch in a GitHub repository. Returns at least 30 results per page by default, but can return more if specified using the perPage parameter (up to 100)."), @@ -227,6 +229,7 @@ func ListCommits(t translations.TranslationHelperFunc) toolsets.ServerTool { // ListBranches creates a tool to list branches in a GitHub repository. func ListBranches(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "list_branches", Description: t("TOOL_LIST_BRANCHES_DESCRIPTION", "List branches in a GitHub repository"), @@ -314,6 +317,7 @@ func ListBranches(t translations.TranslationHelperFunc) toolsets.ServerTool { // CreateOrUpdateFile creates a tool to create or update a file in a GitHub repository. func CreateOrUpdateFile(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "create_or_update_file", Description: t("TOOL_CREATE_OR_UPDATE_FILE_DESCRIPTION", "Create or update a single file in a GitHub repository. If updating, you must provide the SHA of the file you want to update. Use this tool to create or update a file in a GitHub repository remotely; do not use it for local file operations."), @@ -441,6 +445,7 @@ func CreateOrUpdateFile(t translations.TranslationHelperFunc) toolsets.ServerToo // CreateRepository creates a tool to create a new GitHub repository. func CreateRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "create_repository", Description: t("TOOL_CREATE_REPOSITORY_DESCRIPTION", "Create a new GitHub repository in your account or specified organization"), @@ -547,6 +552,7 @@ func CreateRepository(t translations.TranslationHelperFunc) toolsets.ServerTool // GetFileContents creates a tool to get the contents of a file or directory from a GitHub repository. func GetFileContents(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "get_file_contents", Description: t("TOOL_GET_FILE_CONTENTS_DESCRIPTION", "Get the contents of a file or directory from a GitHub repository"), @@ -766,6 +772,7 @@ func GetFileContents(t translations.TranslationHelperFunc) toolsets.ServerTool { // ForkRepository creates a tool to fork a repository. func ForkRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "fork_repository", Description: t("TOOL_FORK_REPOSITORY_DESCRIPTION", "Fork a GitHub repository to your account or specified organization"), @@ -864,6 +871,7 @@ func ForkRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { // both of which suit an LLM well. func DeleteFile(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "delete_file", Description: t("TOOL_DELETE_FILE_DESCRIPTION", "Delete a file from a GitHub repository"), @@ -1049,6 +1057,7 @@ func DeleteFile(t translations.TranslationHelperFunc) toolsets.ServerTool { // CreateBranch creates a tool to create a new branch. func CreateBranch(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "create_branch", Description: t("TOOL_CREATE_BRANCH_DESCRIPTION", "Create a new branch in a GitHub repository"), @@ -1162,6 +1171,7 @@ func CreateBranch(t translations.TranslationHelperFunc) toolsets.ServerTool { // PushFiles creates a tool to push multiple files in a single commit to a GitHub repository. func PushFiles(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "push_files", Description: t("TOOL_PUSH_FILES_DESCRIPTION", "Push multiple files to a GitHub repository in a single commit"), @@ -1346,6 +1356,7 @@ func PushFiles(t translations.TranslationHelperFunc) toolsets.ServerTool { // ListTags creates a tool to list tags in a GitHub repository. func ListTags(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "list_tags", Description: t("TOOL_LIST_TAGS_DESCRIPTION", "List git tags in a GitHub repository"), @@ -1425,6 +1436,7 @@ func ListTags(t translations.TranslationHelperFunc) toolsets.ServerTool { // GetTag creates a tool to get details about a specific tag in a GitHub repository. func GetTag(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "get_tag", Description: t("TOOL_GET_TAG_DESCRIPTION", "Get details about a specific git tag in a GitHub repository"), @@ -1523,6 +1535,7 @@ func GetTag(t translations.TranslationHelperFunc) toolsets.ServerTool { // ListReleases creates a tool to list releases in a GitHub repository. func ListReleases(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "list_releases", Description: t("TOOL_LIST_RELEASES_DESCRIPTION", "List releases in a GitHub repository"), @@ -1598,6 +1611,7 @@ func ListReleases(t translations.TranslationHelperFunc) toolsets.ServerTool { // GetLatestRelease creates a tool to get the latest release in a GitHub repository. func GetLatestRelease(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "get_latest_release", Description: t("TOOL_GET_LATEST_RELEASE_DESCRIPTION", "Get the latest release in a GitHub repository"), @@ -1663,6 +1677,7 @@ func GetLatestRelease(t translations.TranslationHelperFunc) toolsets.ServerTool func GetReleaseByTag(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "get_release_by_tag", Description: t("TOOL_GET_RELEASE_BY_TAG_DESCRIPTION", "Get a specific release by its tag name in a GitHub repository"), @@ -1876,6 +1891,7 @@ func resolveGitReference(ctx context.Context, githubClient *github.Client, owner // ListStarredRepositories creates a tool to list starred repositories for the authenticated user or a specified user. func ListStarredRepositories(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataStargazers, mcp.Tool{ Name: "list_starred_repositories", Description: t("TOOL_LIST_STARRED_REPOSITORIES_DESCRIPTION", "List starred repositories"), @@ -2008,6 +2024,7 @@ func ListStarredRepositories(t translations.TranslationHelperFunc) toolsets.Serv // StarRepository creates a tool to star a repository. func StarRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataStargazers, mcp.Tool{ Name: "star_repository", Description: t("TOOL_STAR_REPOSITORY_DESCRIPTION", "Star a GitHub repository"), @@ -2073,6 +2090,7 @@ func StarRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { // UnstarRepository creates a tool to unstar a repository. func UnstarRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataStargazers, mcp.Tool{ Name: "unstar_repository", Description: t("TOOL_UNSTAR_REPOSITORY_DESCRIPTION", "Unstar a GitHub repository"), diff --git a/pkg/github/repository_resource.go b/pkg/github/repository_resource.go index 5dea9f4e9..d8fd13963 100644 --- a/pkg/github/repository_resource.go +++ b/pkg/github/repository_resource.go @@ -14,6 +14,7 @@ import ( "strings" "github.com/github/github-mcp-server/pkg/raw" + "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -29,53 +30,68 @@ var ( ) // GetRepositoryResourceContent defines the resource template and handler for getting repository content. -func GetRepositoryResourceContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, mcp.ResourceHandler) { - return mcp.ResourceTemplate{ +func GetRepositoryResourceContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { + return toolsets.NewServerResourceTemplate( + ToolsetMetadataRepos, + mcp.ResourceTemplate{ Name: "repository_content", - URITemplate: repositoryResourceContentURITemplate.Raw(), // Resource template + URITemplate: repositoryResourceContentURITemplate.Raw(), Description: t("RESOURCE_REPOSITORY_CONTENT_DESCRIPTION", "Repository Content"), }, - RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceContentURITemplate) + RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceContentURITemplate), + ) } // GetRepositoryResourceBranchContent defines the resource template and handler for getting repository content for a branch. -func GetRepositoryResourceBranchContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, mcp.ResourceHandler) { - return mcp.ResourceTemplate{ +func GetRepositoryResourceBranchContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { + return toolsets.NewServerResourceTemplate( + ToolsetMetadataRepos, + mcp.ResourceTemplate{ Name: "repository_content_branch", - URITemplate: repositoryResourceBranchContentURITemplate.Raw(), // Resource template + URITemplate: repositoryResourceBranchContentURITemplate.Raw(), Description: t("RESOURCE_REPOSITORY_CONTENT_BRANCH_DESCRIPTION", "Repository Content for specific branch"), }, - RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceBranchContentURITemplate) + RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceBranchContentURITemplate), + ) } // GetRepositoryResourceCommitContent defines the resource template and handler for getting repository content for a commit. -func GetRepositoryResourceCommitContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, mcp.ResourceHandler) { - return mcp.ResourceTemplate{ +func GetRepositoryResourceCommitContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { + return toolsets.NewServerResourceTemplate( + ToolsetMetadataRepos, + mcp.ResourceTemplate{ Name: "repository_content_commit", - URITemplate: repositoryResourceCommitContentURITemplate.Raw(), // Resource template + URITemplate: repositoryResourceCommitContentURITemplate.Raw(), Description: t("RESOURCE_REPOSITORY_CONTENT_COMMIT_DESCRIPTION", "Repository Content for specific commit"), }, - RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceCommitContentURITemplate) + RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceCommitContentURITemplate), + ) } // GetRepositoryResourceTagContent defines the resource template and handler for getting repository content for a tag. -func GetRepositoryResourceTagContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, mcp.ResourceHandler) { - return mcp.ResourceTemplate{ +func GetRepositoryResourceTagContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { + return toolsets.NewServerResourceTemplate( + ToolsetMetadataRepos, + mcp.ResourceTemplate{ Name: "repository_content_tag", - URITemplate: repositoryResourceTagContentURITemplate.Raw(), // Resource template + URITemplate: repositoryResourceTagContentURITemplate.Raw(), Description: t("RESOURCE_REPOSITORY_CONTENT_TAG_DESCRIPTION", "Repository Content for specific tag"), }, - RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceTagContentURITemplate) + RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceTagContentURITemplate), + ) } // GetRepositoryResourcePrContent defines the resource template and handler for getting repository content for a pull request. -func GetRepositoryResourcePrContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, mcp.ResourceHandler) { - return mcp.ResourceTemplate{ +func GetRepositoryResourcePrContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { + return toolsets.NewServerResourceTemplate( + ToolsetMetadataRepos, + mcp.ResourceTemplate{ Name: "repository_content_pr", - URITemplate: repositoryResourcePrContentURITemplate.Raw(), // Resource template + URITemplate: repositoryResourcePrContentURITemplate.Raw(), Description: t("RESOURCE_REPOSITORY_CONTENT_PR_DESCRIPTION", "Repository Content for specific pull request"), }, - RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourcePrContentURITemplate) + RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourcePrContentURITemplate), + ) } // RepositoryResourceContentsHandler returns a handler function for repository content requests. diff --git a/pkg/github/repository_resource_test.go b/pkg/github/repository_resource_test.go index 113f46d89..1b4120ff0 100644 --- a/pkg/github/repository_resource_test.go +++ b/pkg/github/repository_resource_test.go @@ -47,8 +47,7 @@ func Test_repositoryResourceContents(t *testing.T) { ), uri: "repo:///repo/contents/README.md", handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceContent(getClient, getRawClient, t) - return handler + return GetRepositoryResourceContent(getClient, getRawClient, t).Handler }, expectedResponseType: resourceResponseTypeText, // Ignored as error is expected expectError: "owner is required", @@ -67,8 +66,7 @@ func Test_repositoryResourceContents(t *testing.T) { ), uri: "repo://owner//refs/heads/main/contents/README.md", handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceBranchContent(getClient, getRawClient, t) - return handler + return GetRepositoryResourceBranchContent(getClient, getRawClient, t).Handler }, expectedResponseType: resourceResponseTypeText, // Ignored as error is expected expectError: "repo is required", @@ -87,8 +85,7 @@ func Test_repositoryResourceContents(t *testing.T) { ), uri: "repo://owner/repo/contents/data.png", handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceContent(getClient, getRawClient, t) - return handler + return GetRepositoryResourceContent(getClient, getRawClient, t).Handler }, expectedResponseType: resourceResponseTypeBlob, expectedResult: &mcp.ReadResourceResult{ @@ -112,8 +109,7 @@ func Test_repositoryResourceContents(t *testing.T) { ), uri: "repo://owner/repo/contents/README.md", handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceContent(getClient, getRawClient, t) - return handler + return GetRepositoryResourceContent(getClient, getRawClient, t).Handler }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -139,8 +135,7 @@ func Test_repositoryResourceContents(t *testing.T) { ), uri: "repo://owner/repo/contents/pkg/github/actions.go", handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceContent(getClient, getRawClient, t) - return handler + return GetRepositoryResourceContent(getClient, getRawClient, t).Handler }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -164,8 +159,7 @@ func Test_repositoryResourceContents(t *testing.T) { ), uri: "repo://owner/repo/refs/heads/main/contents/README.md", handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceBranchContent(getClient, getRawClient, t) - return handler + return GetRepositoryResourceBranchContent(getClient, getRawClient, t).Handler }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -189,8 +183,7 @@ func Test_repositoryResourceContents(t *testing.T) { ), uri: "repo://owner/repo/refs/tags/v1.0.0/contents/README.md", handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceTagContent(getClient, getRawClient, t) - return handler + return GetRepositoryResourceTagContent(getClient, getRawClient, t).Handler }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -214,8 +207,7 @@ func Test_repositoryResourceContents(t *testing.T) { ), uri: "repo://owner/repo/sha/abc123/contents/README.md", handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceCommitContent(getClient, getRawClient, t) - return handler + return GetRepositoryResourceCommitContent(getClient, getRawClient, t).Handler }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -247,8 +239,7 @@ func Test_repositoryResourceContents(t *testing.T) { ), uri: "repo://owner/repo/refs/pull/42/head/contents/README.md", handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourcePrContent(getClient, getRawClient, t) - return handler + return GetRepositoryResourcePrContent(getClient, getRawClient, t).Handler }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -271,8 +262,7 @@ func Test_repositoryResourceContents(t *testing.T) { ), uri: "repo://owner/repo/contents/nonexistent.md", handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceContent(getClient, getRawClient, t) - return handler + return GetRepositoryResourceContent(getClient, getRawClient, t).Handler }, expectedResponseType: resourceResponseTypeText, // Ignored as error is expected expectError: "404 Not Found", diff --git a/pkg/github/resources.go b/pkg/github/resources.go new file mode 100644 index 000000000..f0b07e831 --- /dev/null +++ b/pkg/github/resources.go @@ -0,0 +1,20 @@ +package github + +import ( + "github.com/github/github-mcp-server/pkg/raw" + "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/translations" +) + +// AllResources returns all resource templates with their embedded toolset metadata. +// Resource template functions return ServerResourceTemplate directly with toolset info. +func AllResources(t translations.TranslationHelperFunc, getClient GetClientFn, getRawClient raw.GetRawClientFn) []toolsets.ServerResourceTemplate { + return []toolsets.ServerResourceTemplate{ + // Repository resources + GetRepositoryResourceContent(getClient, getRawClient, t), + GetRepositoryResourceBranchContent(getClient, getRawClient, t), + GetRepositoryResourceCommitContent(getClient, getRawClient, t), + GetRepositoryResourceTagContent(getClient, getRawClient, t), + GetRepositoryResourcePrContent(getClient, getRawClient, t), + } +} diff --git a/pkg/github/search.go b/pkg/github/search.go index eaaf49369..730435eba 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -46,6 +46,7 @@ func SearchRepositories(t translations.TranslationHelperFunc) toolsets.ServerToo WithPagination(schema) return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "search_repositories", Description: t("TOOL_SEARCH_REPOSITORIES_DESCRIPTION", "Find GitHub repositories by name, description, readme, topics, or other metadata. Perfect for discovering projects, finding examples, or locating specific repositories across GitHub."), @@ -189,6 +190,7 @@ func SearchCode(t translations.TranslationHelperFunc) toolsets.ServerTool { WithPagination(schema) return NewTool( + ToolsetMetadataRepos, mcp.Tool{ Name: "search_code", Description: t("TOOL_SEARCH_CODE_DESCRIPTION", "Fast and precise code search across ALL GitHub repositories using GitHub's native search engine. Best for finding exact symbols, functions, classes, or specific code patterns."), @@ -373,6 +375,7 @@ func SearchUsers(t translations.TranslationHelperFunc) toolsets.ServerTool { WithPagination(schema) return NewTool( + ToolsetMetadataUsers, mcp.Tool{ Name: "search_users", Description: t("TOOL_SEARCH_USERS_DESCRIPTION", "Find GitHub users by username, real name, or other profile information. Useful for locating developers, contributors, or team members."), @@ -413,6 +416,7 @@ func SearchOrgs(t translations.TranslationHelperFunc) toolsets.ServerTool { WithPagination(schema) return NewTool( + ToolsetMetadataOrgs, mcp.Tool{ Name: "search_orgs", Description: t("TOOL_SEARCH_ORGS_DESCRIPTION", "Find GitHub organizations by name, location, or other organization metadata. Ideal for discovering companies, open source foundations, or teams."), diff --git a/pkg/github/secret_scanning.go b/pkg/github/secret_scanning.go index fa5618791..7e842ded1 100644 --- a/pkg/github/secret_scanning.go +++ b/pkg/github/secret_scanning.go @@ -18,6 +18,7 @@ import ( func GetSecretScanningAlert(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataSecretProtection, mcp.Tool{ Name: "get_secret_scanning_alert", Description: t("TOOL_GET_SECRET_SCANNING_ALERT_DESCRIPTION", "Get details of a specific secret scanning alert in a GitHub repository."), @@ -95,6 +96,7 @@ func GetSecretScanningAlert(t translations.TranslationHelperFunc) toolsets.Serve func ListSecretScanningAlerts(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataSecretProtection, mcp.Tool{ Name: "list_secret_scanning_alerts", Description: t("TOOL_LIST_SECRET_SCANNING_ALERTS_DESCRIPTION", "List secret scanning alerts in a GitHub repository."), diff --git a/pkg/github/security_advisories.go b/pkg/github/security_advisories.go index 8c7df4265..cf507d17a 100644 --- a/pkg/github/security_advisories.go +++ b/pkg/github/security_advisories.go @@ -17,6 +17,7 @@ import ( func ListGlobalSecurityAdvisories(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataSecurityAdvisories, mcp.Tool{ Name: "list_global_security_advisories", Description: t("TOOL_LIST_GLOBAL_SECURITY_ADVISORIES_DESCRIPTION", "List global security advisories from GitHub."), @@ -208,6 +209,7 @@ func ListGlobalSecurityAdvisories(t translations.TranslationHelperFunc) toolsets func ListRepositorySecurityAdvisories(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataSecurityAdvisories, mcp.Tool{ Name: "list_repository_security_advisories", Description: t("TOOL_LIST_REPOSITORY_SECURITY_ADVISORIES_DESCRIPTION", "List repository security advisories for a GitHub repository."), @@ -312,6 +314,7 @@ func ListRepositorySecurityAdvisories(t translations.TranslationHelperFunc) tool func GetGlobalSecurityAdvisory(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataSecurityAdvisories, mcp.Tool{ Name: "get_global_security_advisory", Description: t("TOOL_GET_GLOBAL_SECURITY_ADVISORY_DESCRIPTION", "Get a global security advisory"), @@ -369,6 +372,7 @@ func GetGlobalSecurityAdvisory(t translations.TranslationHelperFunc) toolsets.Se func ListOrgRepositorySecurityAdvisories(t translations.TranslationHelperFunc) toolsets.ServerTool { return NewTool( + ToolsetMetadataSecurityAdvisories, mcp.Tool{ Name: "list_org_repository_security_advisories", Description: t("TOOL_LIST_ORG_REPOSITORY_SECURITY_ADVISORIES_DESCRIPTION", "List repository security advisories for a GitHub organization."), diff --git a/pkg/github/server.go b/pkg/github/server.go index e74596906..7432466d1 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -18,14 +18,16 @@ import ( func NewServer(version string, opts *mcp.ServerOptions) *mcp.Server { if opts == nil { - // Add default options - opts = &mcp.ServerOptions{ - HasTools: true, - HasResources: true, - HasPrompts: true, - } + opts = &mcp.ServerOptions{} } + // Always advertise capabilities so clients know we support list_changed notifications. + // This is important for dynamic toolsets mode where we start with few tools + // and add more at runtime. + opts.HasTools = true + opts.HasResources = true + opts.HasPrompts = true + // Create a new MCP server s := mcp.NewServer(&mcp.Implementation{ Name: "github-mcp-server", diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 02ec66e8a..1fe23dfd2 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -5,117 +5,110 @@ import ( "fmt" "strings" - "github.com/github/github-mcp-server/pkg/lockdown" - "github.com/github/github-mcp-server/pkg/raw" "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" - "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/shurcooL/githubv4" ) type GetClientFn func(context.Context) (*github.Client, error) type GetGQLClientFn func(context.Context) (*githubv4.Client, error) -// ToolsetMetadata holds metadata for a toolset including its ID and description -type ToolsetMetadata struct { - ID string - Description string -} - +// Toolset metadata constants - these define all available toolsets and their descriptions. +// Tools use these constants to declare which toolset they belong to. var ( - ToolsetMetadataAll = ToolsetMetadata{ + ToolsetMetadataAll = toolsets.ToolsetMetadata{ ID: "all", Description: "Special toolset that enables all available toolsets", } - ToolsetMetadataDefault = ToolsetMetadata{ + ToolsetMetadataDefault = toolsets.ToolsetMetadata{ ID: "default", Description: "Special toolset that enables the default toolset configuration. When no toolsets are specified, this is the set that is enabled", } - ToolsetMetadataContext = ToolsetMetadata{ + ToolsetMetadataContext = toolsets.ToolsetMetadata{ ID: "context", Description: "Tools that provide context about the current user and GitHub context you are operating in", } - ToolsetMetadataRepos = ToolsetMetadata{ + ToolsetMetadataRepos = toolsets.ToolsetMetadata{ ID: "repos", Description: "GitHub Repository related tools", } - ToolsetMetadataGit = ToolsetMetadata{ + ToolsetMetadataGit = toolsets.ToolsetMetadata{ ID: "git", Description: "GitHub Git API related tools for low-level Git operations", } - ToolsetMetadataIssues = ToolsetMetadata{ + ToolsetMetadataIssues = toolsets.ToolsetMetadata{ ID: "issues", Description: "GitHub Issues related tools", } - ToolsetMetadataPullRequests = ToolsetMetadata{ + ToolsetMetadataPullRequests = toolsets.ToolsetMetadata{ ID: "pull_requests", Description: "GitHub Pull Request related tools", } - ToolsetMetadataUsers = ToolsetMetadata{ + ToolsetMetadataUsers = toolsets.ToolsetMetadata{ ID: "users", Description: "GitHub User related tools", } - ToolsetMetadataOrgs = ToolsetMetadata{ + ToolsetMetadataOrgs = toolsets.ToolsetMetadata{ ID: "orgs", Description: "GitHub Organization related tools", } - ToolsetMetadataActions = ToolsetMetadata{ + ToolsetMetadataActions = toolsets.ToolsetMetadata{ ID: "actions", Description: "GitHub Actions workflows and CI/CD operations", } - ToolsetMetadataCodeSecurity = ToolsetMetadata{ + ToolsetMetadataCodeSecurity = toolsets.ToolsetMetadata{ ID: "code_security", Description: "Code security related tools, such as GitHub Code Scanning", } - ToolsetMetadataSecretProtection = ToolsetMetadata{ + ToolsetMetadataSecretProtection = toolsets.ToolsetMetadata{ ID: "secret_protection", Description: "Secret protection related tools, such as GitHub Secret Scanning", } - ToolsetMetadataDependabot = ToolsetMetadata{ + ToolsetMetadataDependabot = toolsets.ToolsetMetadata{ ID: "dependabot", Description: "Dependabot tools", } - ToolsetMetadataNotifications = ToolsetMetadata{ + ToolsetMetadataNotifications = toolsets.ToolsetMetadata{ ID: "notifications", Description: "GitHub Notifications related tools", } - ToolsetMetadataExperiments = ToolsetMetadata{ + ToolsetMetadataExperiments = toolsets.ToolsetMetadata{ ID: "experiments", Description: "Experimental features that are not considered stable yet", } - ToolsetMetadataDiscussions = ToolsetMetadata{ + ToolsetMetadataDiscussions = toolsets.ToolsetMetadata{ ID: "discussions", Description: "GitHub Discussions related tools", } - ToolsetMetadataGists = ToolsetMetadata{ + ToolsetMetadataGists = toolsets.ToolsetMetadata{ ID: "gists", Description: "GitHub Gist related tools", } - ToolsetMetadataSecurityAdvisories = ToolsetMetadata{ + ToolsetMetadataSecurityAdvisories = toolsets.ToolsetMetadata{ ID: "security_advisories", Description: "Security advisories related tools", } - ToolsetMetadataProjects = ToolsetMetadata{ + ToolsetMetadataProjects = toolsets.ToolsetMetadata{ ID: "projects", Description: "GitHub Projects related tools", } - ToolsetMetadataStargazers = ToolsetMetadata{ + ToolsetMetadataStargazers = toolsets.ToolsetMetadata{ ID: "stargazers", Description: "GitHub Stargazers related tools", } - ToolsetMetadataDynamic = ToolsetMetadata{ + ToolsetMetadataDynamic = toolsets.ToolsetMetadata{ ID: "dynamic", Description: "Discover GitHub MCP tools that can help achieve tasks by enabling additional sets of tools, you can control the enablement of any toolset to access its tools when this toolset is enabled.", } - ToolsetLabels = ToolsetMetadata{ + ToolsetLabels = toolsets.ToolsetMetadata{ ID: "labels", Description: "GitHub Labels related tools", } ) -func AvailableTools() []ToolsetMetadata { - return []ToolsetMetadata{ +func AvailableToolsets() []toolsets.ToolsetMetadata { + return []toolsets.ToolsetMetadata{ ToolsetMetadataContext, ToolsetMetadataRepos, ToolsetMetadataIssues, @@ -139,10 +132,10 @@ func AvailableTools() []ToolsetMetadata { } // GetValidToolsetIDs returns a map of all valid toolset IDs for quick lookup -func GetValidToolsetIDs() map[string]bool { - validIDs := make(map[string]bool) - for _, tool := range AvailableTools() { - validIDs[tool.ID] = true +func GetValidToolsetIDs() map[toolsets.ToolsetID]bool { + validIDs := make(map[toolsets.ToolsetID]bool) + for _, toolset := range AvailableToolsets() { + validIDs[toolset.ID] = true } // Add special keywords validIDs[ToolsetMetadataAll.ID] = true @@ -150,8 +143,8 @@ func GetValidToolsetIDs() map[string]bool { return validIDs } -func GetDefaultToolsetIDs() []string { - return []string{ +func GetDefaultToolsetIDs() []toolsets.ToolsetID { + return []toolsets.ToolsetID{ ToolsetMetadataContext.ID, ToolsetMetadataRepos.ID, ToolsetMetadataIssues.ID, @@ -160,274 +153,138 @@ func GetDefaultToolsetIDs() []string { } } -func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetGQLClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc, contentWindowSize int, flags FeatureFlags, cache *lockdown.RepoAccessCache) *toolsets.ToolsetGroup { - tsg := toolsets.NewToolsetGroup(readOnly) - - // Create the dependencies struct that will be passed to all tool handlers - deps := ToolDependencies{ - GetClient: getClient, - GetGQLClient: getGQLClient, - GetRawClient: getRawClient, - RepoAccessCache: cache, - T: t, - Flags: flags, - ContentWindowSize: contentWindowSize, - } - - // Define all available features with their default state (disabled) - // Create toolsets - repos := toolsets.NewToolset(ToolsetMetadataRepos.ID, ToolsetMetadataRepos.Description). - SetDependencies(deps). - AddReadTools( - SearchRepositories(t), - GetFileContents(t), - ListCommits(t), - SearchCode(t), - GetCommit(t), - ListBranches(t), - ListTags(t), - GetTag(t), - ListReleases(t), - GetLatestRelease(t), - GetReleaseByTag(t), - ). - AddWriteTools( - CreateOrUpdateFile(t), - CreateRepository(t), - ForkRepository(t), - CreateBranch(t), - PushFiles(t), - DeleteFile(t), - ). - AddResourceTemplates( - toolsets.NewServerResourceTemplate(GetRepositoryResourceContent(getClient, getRawClient, t)), - toolsets.NewServerResourceTemplate(GetRepositoryResourceBranchContent(getClient, getRawClient, t)), - toolsets.NewServerResourceTemplate(GetRepositoryResourceCommitContent(getClient, getRawClient, t)), - toolsets.NewServerResourceTemplate(GetRepositoryResourceTagContent(getClient, getRawClient, t)), - toolsets.NewServerResourceTemplate(GetRepositoryResourcePrContent(getClient, getRawClient, t)), - ) - git := toolsets.NewToolset(ToolsetMetadataGit.ID, ToolsetMetadataGit.Description). - SetDependencies(deps). - AddReadTools( - GetRepositoryTree(t), - ) - issues := toolsets.NewToolset(ToolsetMetadataIssues.ID, ToolsetMetadataIssues.Description). - SetDependencies(deps). - AddReadTools( - IssueRead(t), - SearchIssues(t), - ListIssues(t), - ListIssueTypes(t), - GetLabel(t), - ). - AddWriteTools( - IssueWrite(t), - AddIssueComment(t), - AssignCopilotToIssue(t), - SubIssueWrite(t), - ).AddPrompts( - toolsets.NewServerPrompt(AssignCodingAgentPrompt(t)), - toolsets.NewServerPrompt(IssueToFixWorkflowPrompt(t)), - ) - users := toolsets.NewToolset(ToolsetMetadataUsers.ID, ToolsetMetadataUsers.Description). - SetDependencies(deps). - AddReadTools( - SearchUsers(t), - ) - orgs := toolsets.NewToolset(ToolsetMetadataOrgs.ID, ToolsetMetadataOrgs.Description). - SetDependencies(deps). - AddReadTools( - SearchOrgs(t), - ) - pullRequests := toolsets.NewToolset(ToolsetMetadataPullRequests.ID, ToolsetMetadataPullRequests.Description). - SetDependencies(deps). - AddReadTools( - PullRequestRead(t), - ListPullRequests(t), - SearchPullRequests(t), - ). - AddWriteTools( - MergePullRequest(t), - UpdatePullRequestBranch(t), - CreatePullRequest(t), - UpdatePullRequest(t), - RequestCopilotReview(t), - // Reviews - PullRequestReviewWrite(t), - AddCommentToPendingReview(t), - ) - codeSecurity := toolsets.NewToolset(ToolsetMetadataCodeSecurity.ID, ToolsetMetadataCodeSecurity.Description). - SetDependencies(deps). - AddReadTools( - GetCodeScanningAlert(t), - ListCodeScanningAlerts(t), - ) - secretProtection := toolsets.NewToolset(ToolsetMetadataSecretProtection.ID, ToolsetMetadataSecretProtection.Description). - SetDependencies(deps). - AddReadTools( - GetSecretScanningAlert(t), - ListSecretScanningAlerts(t), - ) - dependabot := toolsets.NewToolset(ToolsetMetadataDependabot.ID, ToolsetMetadataDependabot.Description). - SetDependencies(deps). - AddReadTools( - GetDependabotAlert(t), - ListDependabotAlerts(t), - ) - - notifications := toolsets.NewToolset(ToolsetMetadataNotifications.ID, ToolsetMetadataNotifications.Description). - SetDependencies(deps). - AddReadTools( - ListNotifications(t), - GetNotificationDetails(t), - ). - AddWriteTools( - DismissNotification(t), - MarkAllNotificationsRead(t), - ManageNotificationSubscription(t), - ManageRepositoryNotificationSubscription(t), - ) - - discussions := toolsets.NewToolset(ToolsetMetadataDiscussions.ID, ToolsetMetadataDiscussions.Description). - SetDependencies(deps). - AddReadTools( - ListDiscussions(t), - GetDiscussion(t), - GetDiscussionComments(t), - ListDiscussionCategories(t), - ) - - actions := toolsets.NewToolset(ToolsetMetadataActions.ID, ToolsetMetadataActions.Description). - SetDependencies(deps). - AddReadTools( - ListWorkflows(t), - ListWorkflowRuns(t), - GetWorkflowRun(t), - GetWorkflowRunLogs(t), - ListWorkflowJobs(t), - GetJobLogs(t), - ListWorkflowRunArtifacts(t), - DownloadWorkflowRunArtifact(t), - GetWorkflowRunUsage(t), - ). - AddWriteTools( - RunWorkflow(t), - RerunWorkflowRun(t), - RerunFailedJobs(t), - CancelWorkflowRun(t), - DeleteWorkflowRunLogs(t), - ) - - securityAdvisories := toolsets.NewToolset(ToolsetMetadataSecurityAdvisories.ID, ToolsetMetadataSecurityAdvisories.Description). - SetDependencies(deps). - AddReadTools( - ListGlobalSecurityAdvisories(t), - GetGlobalSecurityAdvisory(t), - ListRepositorySecurityAdvisories(t), - ListOrgRepositorySecurityAdvisories(t), - ) - - // // Keep experiments alive so the system doesn't error out when it's always enabled - experiments := toolsets.NewToolset(ToolsetMetadataExperiments.ID, ToolsetMetadataExperiments.Description). - SetDependencies(deps) - - contextTools := toolsets.NewToolset(ToolsetMetadataContext.ID, ToolsetMetadataContext.Description). - SetDependencies(deps). - AddReadTools( - GetMe(t), - GetTeams(t), - GetTeamMembers(t), - ) - - gists := toolsets.NewToolset(ToolsetMetadataGists.ID, ToolsetMetadataGists.Description). - SetDependencies(deps). - AddReadTools( - ListGists(t), - GetGist(t), - ). - AddWriteTools( - CreateGist(t), - UpdateGist(t), - ) - - projects := toolsets.NewToolset(ToolsetMetadataProjects.ID, ToolsetMetadataProjects.Description). - SetDependencies(deps). - AddReadTools( - ListProjects(t), - GetProject(t), - ListProjectFields(t), - GetProjectField(t), - ListProjectItems(t), - GetProjectItem(t), - ). - AddWriteTools( - AddProjectItem(t), - DeleteProjectItem(t), - UpdateProjectItem(t), - ) - stargazers := toolsets.NewToolset(ToolsetMetadataStargazers.ID, ToolsetMetadataStargazers.Description). - SetDependencies(deps). - AddReadTools( - ListStarredRepositories(t), - ). - AddWriteTools( - StarRepository(t), - UnstarRepository(t), - ) - labels := toolsets.NewToolset(ToolsetLabels.ID, ToolsetLabels.Description). - SetDependencies(deps). - AddReadTools( - // get - GetLabel(t), - // list labels on repo or issue - ListLabels(t), - ). - AddWriteTools( - // create or update - LabelWrite(t), - ) - - // Add toolsets to the group - tsg.AddToolset(contextTools) - tsg.AddToolset(repos) - tsg.AddToolset(git) - tsg.AddToolset(issues) - tsg.AddToolset(orgs) - tsg.AddToolset(users) - tsg.AddToolset(pullRequests) - tsg.AddToolset(actions) - tsg.AddToolset(codeSecurity) - tsg.AddToolset(dependabot) - tsg.AddToolset(secretProtection) - tsg.AddToolset(notifications) - tsg.AddToolset(experiments) - tsg.AddToolset(discussions) - tsg.AddToolset(gists) - tsg.AddToolset(securityAdvisories) - tsg.AddToolset(projects) - tsg.AddToolset(stargazers) - tsg.AddToolset(labels) - - tsg.AddDeprecatedToolAliases(DeprecatedToolAliases) - - return tsg -} - -// InitDynamicToolset creates a dynamic toolset that can be used to enable other toolsets, and so requires the server and toolset group as arguments -// -//nolint:unused -func InitDynamicToolset(s *mcp.Server, tsg *toolsets.ToolsetGroup, t translations.TranslationHelperFunc) *toolsets.Toolset { - // Create a new dynamic toolset - // Need to add the dynamic toolset last so it can be used to enable other toolsets - dynamicToolSelection := toolsets.NewToolset(ToolsetMetadataDynamic.ID, ToolsetMetadataDynamic.Description). - AddReadTools( - ListAvailableToolsets(tsg, t), - GetToolsetsTools(tsg, t), - EnableToolset(s, tsg, t), - ) - - dynamicToolSelection.Enabled = true - return dynamicToolSelection +// AllTools returns all tools with their embedded toolset metadata. +// Tool functions return ServerTool directly with toolset info. +func AllTools(t translations.TranslationHelperFunc) []toolsets.ServerTool { + return []toolsets.ServerTool{ + // Context tools + GetMe(t), + GetTeams(t), + GetTeamMembers(t), + + // Repository tools + SearchRepositories(t), + GetFileContents(t), + ListCommits(t), + SearchCode(t), + GetCommit(t), + ListBranches(t), + ListTags(t), + GetTag(t), + ListReleases(t), + GetLatestRelease(t), + GetReleaseByTag(t), + CreateOrUpdateFile(t), + CreateRepository(t), + ForkRepository(t), + CreateBranch(t), + PushFiles(t), + DeleteFile(t), + ListStarredRepositories(t), + StarRepository(t), + UnstarRepository(t), + + // Git tools + GetRepositoryTree(t), + + // Issue tools + IssueRead(t), + SearchIssues(t), + ListIssues(t), + ListIssueTypes(t), + IssueWrite(t), + AddIssueComment(t), + AssignCopilotToIssue(t), + SubIssueWrite(t), + + // User tools + SearchUsers(t), + + // Organization tools + SearchOrgs(t), + + // Pull request tools + PullRequestRead(t), + ListPullRequests(t), + SearchPullRequests(t), + MergePullRequest(t), + UpdatePullRequestBranch(t), + CreatePullRequest(t), + UpdatePullRequest(t), + RequestCopilotReview(t), + PullRequestReviewWrite(t), + AddCommentToPendingReview(t), + + // Code security tools + GetCodeScanningAlert(t), + ListCodeScanningAlerts(t), + + // Secret protection tools + GetSecretScanningAlert(t), + ListSecretScanningAlerts(t), + + // Dependabot tools + GetDependabotAlert(t), + ListDependabotAlerts(t), + + // Notification tools + ListNotifications(t), + GetNotificationDetails(t), + DismissNotification(t), + MarkAllNotificationsRead(t), + ManageNotificationSubscription(t), + ManageRepositoryNotificationSubscription(t), + + // Discussion tools + ListDiscussions(t), + GetDiscussion(t), + GetDiscussionComments(t), + ListDiscussionCategories(t), + + // Actions tools + ListWorkflows(t), + ListWorkflowRuns(t), + GetWorkflowRun(t), + GetWorkflowRunLogs(t), + ListWorkflowJobs(t), + GetJobLogs(t), + ListWorkflowRunArtifacts(t), + DownloadWorkflowRunArtifact(t), + GetWorkflowRunUsage(t), + RunWorkflow(t), + RerunWorkflowRun(t), + RerunFailedJobs(t), + CancelWorkflowRun(t), + DeleteWorkflowRunLogs(t), + + // Security advisories tools + ListGlobalSecurityAdvisories(t), + GetGlobalSecurityAdvisory(t), + ListRepositorySecurityAdvisories(t), + ListOrgRepositorySecurityAdvisories(t), + + // Gist tools + ListGists(t), + GetGist(t), + CreateGist(t), + UpdateGist(t), + + // Project tools + ListProjects(t), + GetProject(t), + ListProjectFields(t), + GetProjectField(t), + ListProjectItems(t), + GetProjectItem(t), + AddProjectItem(t), + DeleteProjectItem(t), + UpdateProjectItem(t), + + // Label tools + GetLabel(t), + ListLabels(t), + LabelWrite(t), + } } // ToBoolPtr converts a bool to a *bool pointer. @@ -447,23 +304,29 @@ func ToStringPtr(s string) *string { // GenerateToolsetsHelp generates the help text for the toolsets flag func GenerateToolsetsHelp() string { // Format default tools - defaultTools := strings.Join(GetDefaultToolsetIDs(), ", ") + defaultIDs := GetDefaultToolsetIDs() + defaultStrings := make([]string, len(defaultIDs)) + for i, id := range defaultIDs { + defaultStrings[i] = string(id) + } + defaultTools := strings.Join(defaultStrings, ", ") // Format available tools with line breaks for better readability - allTools := AvailableTools() + allToolsets := AvailableToolsets() var availableToolsLines []string const maxLineLength = 70 currentLine := "" - for i, tool := range allTools { + for i, toolset := range allToolsets { + id := string(toolset.ID) switch { case i == 0: - currentLine = tool.ID - case len(currentLine)+len(tool.ID)+2 <= maxLineLength: - currentLine += ", " + tool.ID + currentLine = id + case len(currentLine)+len(id)+2 <= maxLineLength: + currentLine += ", " + id default: availableToolsLines = append(availableToolsLines, currentLine) - currentLine = tool.ID + currentLine = id } } if currentLine != "" { @@ -491,7 +354,7 @@ func AddDefaultToolset(result []string) []string { seen := make(map[string]bool) for _, toolset := range result { seen[toolset] = true - if toolset == ToolsetMetadataDefault.ID { + if toolset == string(ToolsetMetadataDefault.ID) { hasDefault = true } } @@ -501,11 +364,11 @@ func AddDefaultToolset(result []string) []string { return result } - result = RemoveToolset(result, ToolsetMetadataDefault.ID) + result = RemoveToolset(result, string(ToolsetMetadataDefault.ID)) for _, defaultToolset := range GetDefaultToolsetIDs() { - if !seen[defaultToolset] { - result = append(result, defaultToolset) + if !seen[string(defaultToolset)] { + result = append(result, string(defaultToolset)) } } return result @@ -531,7 +394,7 @@ func CleanToolsets(enabledToolsets []string) ([]string, []string) { if !seen[trimmed] { seen[trimmed] = true result = append(result, trimmed) - if !validIDs[trimmed] { + if !validIDs[toolsets.ToolsetID(trimmed)] { invalid = append(invalid, trimmed) } } diff --git a/pkg/github/toolset_group.go b/pkg/github/toolset_group.go new file mode 100644 index 000000000..bca1f7ca4 --- /dev/null +++ b/pkg/github/toolset_group.go @@ -0,0 +1,20 @@ +package github + +import ( + "github.com/github/github-mcp-server/pkg/raw" + "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/translations" +) + +// NewToolsetGroup creates a ToolsetGroup with all available tools, resources, and prompts. +// Tools are self-describing with their toolset metadata embedded. +// The "default" keyword in WithToolsets will expand to GetDefaultToolsetIDs(). +func NewToolsetGroup(t translations.TranslationHelperFunc, getClient GetClientFn, getRawClient raw.GetRawClientFn) *toolsets.ToolsetGroup { + tsg := toolsets.NewToolsetGroup( + AllTools(t), + AllResources(t, getClient, getRawClient), + AllPrompts(t), + ) + tsg.SetDefaultToolsetIDs(GetDefaultToolsetIDs()) + return tsg +} diff --git a/pkg/github/workflow_prompts.go b/pkg/github/workflow_prompts.go index bc7c7581f..cf972020d 100644 --- a/pkg/github/workflow_prompts.go +++ b/pkg/github/workflow_prompts.go @@ -4,13 +4,16 @@ import ( "context" "fmt" + "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/modelcontextprotocol/go-sdk/mcp" ) // IssueToFixWorkflowPrompt provides a guided workflow for creating an issue and then generating a PR to fix it -func IssueToFixWorkflowPrompt(t translations.TranslationHelperFunc) (tool mcp.Prompt, handler mcp.PromptHandler) { - return mcp.Prompt{ +func IssueToFixWorkflowPrompt(t translations.TranslationHelperFunc) toolsets.ServerPrompt { + return toolsets.NewServerPrompt( + ToolsetMetadataIssues, + mcp.Prompt{ Name: "issue_to_fix_workflow", Description: t("PROMPT_ISSUE_TO_FIX_WORKFLOW_DESCRIPTION", "Create an issue for a problem and then generate a pull request to fix it"), Arguments: []*mcp.PromptArgument{ @@ -102,5 +105,6 @@ func IssueToFixWorkflowPrompt(t translations.TranslationHelperFunc) (tool mcp.Pr return &mcp.GetPromptResult{ Messages: messages, }, nil - } + }, + ) } diff --git a/pkg/toolsets/server_tool.go b/pkg/toolsets/server_tool.go index 3e3e5d9f8..334492c74 100644 --- a/pkg/toolsets/server_tool.go +++ b/pkg/toolsets/server_tool.go @@ -14,17 +14,47 @@ import ( // should define their own typed dependencies struct and type-assert as needed. type HandlerFunc func(deps any) mcp.ToolHandler -// ServerTool represents an MCP tool with a handler generator function. +// ToolsetID is a unique identifier for a toolset. +// Using a distinct type provides compile-time type safety. +type ToolsetID string + +// ToolsetMetadata contains metadata about the toolset a tool belongs to. +type ToolsetMetadata struct { + // ID is the unique identifier for the toolset (e.g., "repos", "issues") + ID ToolsetID + // Description provides a human-readable description of the toolset + Description string +} + +// ServerTool represents an MCP tool with metadata and a handler generator function. // The tool definition is static, while the handler is generated on-demand // when the tool is registered with a server. +// Tools are now self-describing with their toolset membership and read-only status +// derived from the Tool.Annotations.ReadOnlyHint field. type ServerTool struct { // Tool is the MCP tool definition containing name, description, schema, etc. Tool mcp.Tool + // Toolset contains metadata about which toolset this tool belongs to. + Toolset ToolsetMetadata + // HandlerFunc generates the handler when given dependencies. // This allows tools to be passed around without handlers being set up, // and handlers are only created when needed. HandlerFunc HandlerFunc + + // FeatureFlagEnable specifies a feature flag that must be enabled for this tool + // to be available. If set and the flag is not enabled, the tool is omitted. + FeatureFlagEnable string + + // FeatureFlagDisable specifies a feature flag that, when enabled, causes this tool + // to be omitted. Used to disable tools when a feature flag is on. + FeatureFlagDisable string +} + +// IsReadOnly returns true if this tool is marked as read-only via annotations. +func (st *ServerTool) IsReadOnly() bool { + return st.Tool.Annotations != nil && st.Tool.Annotations.ReadOnlyHint } // Handler returns a tool handler by calling HandlerFunc with the given dependencies. @@ -41,12 +71,13 @@ func (st *ServerTool) RegisterFunc(s *mcp.Server, deps any) { s.AddTool(&st.Tool, handler) } -// NewServerTool creates a ServerTool from a tool definition and a typed handler function. +// NewServerTool creates a ServerTool from a tool definition, toolset metadata, and a typed handler function. // The handler function takes dependencies (as any) and returns a typed handler. // Callers should type-assert deps to their typed dependencies struct. -func NewServerTool[In any, Out any](tool mcp.Tool, handlerFn func(deps any) mcp.ToolHandlerFor[In, Out]) ServerTool { +func NewServerTool[In any, Out any](tool mcp.Tool, toolset ToolsetMetadata, handlerFn func(deps any) mcp.ToolHandlerFor[In, Out]) ServerTool { return ServerTool{ - Tool: tool, + Tool: tool, + Toolset: toolset, HandlerFunc: func(deps any) mcp.ToolHandler { typedHandler := handlerFn(deps) return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -61,18 +92,19 @@ func NewServerTool[In any, Out any](tool mcp.Tool, handlerFn func(deps any) mcp. } } -// NewServerToolFromHandler creates a ServerTool from a tool definition and a raw handler function. +// NewServerToolFromHandler creates a ServerTool from a tool definition, toolset metadata, and a raw handler function. // Use this when you have a handler that already conforms to mcp.ToolHandler. -func NewServerToolFromHandler(tool mcp.Tool, handlerFn func(deps any) mcp.ToolHandler) ServerTool { - return ServerTool{Tool: tool, HandlerFunc: handlerFn} +func NewServerToolFromHandler(tool mcp.Tool, toolset ToolsetMetadata, handlerFn func(deps any) mcp.ToolHandler) ServerTool { + return ServerTool{Tool: tool, Toolset: toolset, HandlerFunc: handlerFn} } -// NewServerToolLegacy creates a ServerTool from a tool definition and an already-bound typed handler. +// NewServerToolLegacy creates a ServerTool from a tool definition, toolset metadata, and an already-bound typed handler. // This is for backward compatibility during the refactor - the handler doesn't use dependencies. // Deprecated: Use NewServerTool instead for new code. -func NewServerToolLegacy[In any, Out any](tool mcp.Tool, handler mcp.ToolHandlerFor[In, Out]) ServerTool { +func NewServerToolLegacy[In any, Out any](tool mcp.Tool, toolset ToolsetMetadata, handler mcp.ToolHandlerFor[In, Out]) ServerTool { return ServerTool{ - Tool: tool, + Tool: tool, + Toolset: toolset, HandlerFunc: func(_ any) mcp.ToolHandler { return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { var arguments In @@ -86,12 +118,13 @@ func NewServerToolLegacy[In any, Out any](tool mcp.Tool, handler mcp.ToolHandler } } -// NewServerToolFromHandlerLegacy creates a ServerTool from a tool definition and an already-bound raw handler. +// NewServerToolFromHandlerLegacy creates a ServerTool from a tool definition, toolset metadata, and an already-bound raw handler. // This is for backward compatibility during the refactor - the handler doesn't use dependencies. // Deprecated: Use NewServerToolFromHandler instead for new code. -func NewServerToolFromHandlerLegacy(tool mcp.Tool, handler mcp.ToolHandler) ServerTool { +func NewServerToolFromHandlerLegacy(tool mcp.Tool, toolset ToolsetMetadata, handler mcp.ToolHandler) ServerTool { return ServerTool{ - Tool: tool, + Tool: tool, + Toolset: toolset, HandlerFunc: func(_ any) mcp.ToolHandler { return handler }, diff --git a/pkg/toolsets/toolsets.go b/pkg/toolsets/toolsets.go index 8502328d5..e42a9bcae 100644 --- a/pkg/toolsets/toolsets.go +++ b/pkg/toolsets/toolsets.go @@ -1,9 +1,11 @@ package toolsets import ( + "context" "fmt" "os" - "strings" + "slices" + "sort" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -30,276 +32,630 @@ func NewToolsetDoesNotExistError(name string) *ToolsetDoesNotExistError { return &ToolsetDoesNotExistError{Name: name} } +// ToolDoesNotExistError is returned when a tool is not found. +type ToolDoesNotExistError struct { + Name string +} + +func (e *ToolDoesNotExistError) Error() string { + return fmt.Sprintf("tool %s does not exist", e.Name) +} + +// NewToolDoesNotExistError creates a new ToolDoesNotExistError. +func NewToolDoesNotExistError(name string) *ToolDoesNotExistError { + return &ToolDoesNotExistError{Name: name} +} + // ServerTool is defined in server_tool.go +// ServerResourceTemplate pairs a resource template with its toolset metadata. type ServerResourceTemplate struct { Template mcp.ResourceTemplate Handler mcp.ResourceHandler -} - -func NewServerResourceTemplate(resourceTemplate mcp.ResourceTemplate, handler mcp.ResourceHandler) ServerResourceTemplate { + // Toolset identifies which toolset this resource belongs to + Toolset ToolsetMetadata + // FeatureFlagEnable specifies a feature flag that must be enabled for this resource + // to be available. If set and the flag is not enabled, the resource is omitted. + FeatureFlagEnable string + // FeatureFlagDisable specifies a feature flag that, when enabled, causes this resource + // to be omitted. Used to disable resources when a feature flag is on. + FeatureFlagDisable string +} + +// NewServerResourceTemplate creates a new ServerResourceTemplate with toolset metadata. +func NewServerResourceTemplate(toolset ToolsetMetadata, resourceTemplate mcp.ResourceTemplate, handler mcp.ResourceHandler) ServerResourceTemplate { return ServerResourceTemplate{ Template: resourceTemplate, Handler: handler, + Toolset: toolset, } } +// ServerPrompt pairs a prompt with its toolset metadata. type ServerPrompt struct { Prompt mcp.Prompt Handler mcp.PromptHandler -} - -func NewServerPrompt(prompt mcp.Prompt, handler mcp.PromptHandler) ServerPrompt { + // Toolset identifies which toolset this prompt belongs to + Toolset ToolsetMetadata + // FeatureFlagEnable specifies a feature flag that must be enabled for this prompt + // to be available. If set and the flag is not enabled, the prompt is omitted. + FeatureFlagEnable string + // FeatureFlagDisable specifies a feature flag that, when enabled, causes this prompt + // to be omitted. Used to disable prompts when a feature flag is on. + FeatureFlagDisable string +} + +// NewServerPrompt creates a new ServerPrompt with toolset metadata. +func NewServerPrompt(toolset ToolsetMetadata, prompt mcp.Prompt, handler mcp.PromptHandler) ServerPrompt { return ServerPrompt{ Prompt: prompt, Handler: handler, + Toolset: toolset, } } -// Toolset represents a collection of MCP functionality that can be enabled or disabled as a group. -type Toolset struct { - Name string - Description string - Enabled bool - readOnly bool - writeTools []ServerTool - readTools []ServerTool - // deps holds the dependencies for tool handlers (typed as any to avoid circular deps) - deps any - // resources are not tools, but the community seems to be moving towards namespaces as a broader concept - // and in order to have multiple servers running concurrently, we want to avoid overlapping resources too. +// ToolsetGroup holds a collection of tools, resources, and prompts. +// It supports immutable filtering operations that return new ToolsetGroups +// without modifying the original. This design allows for: +// - Building a full set of tools/resources/prompts once +// - Applying filters (read-only, feature flags, enabled toolsets) without mutation +// - Deterministic ordering for documentation generation +// - Lazy dependency injection only when registering with a server +type ToolsetGroup struct { + // tools holds all tools in this group + tools []ServerTool + // resourceTemplates holds all resource templates in this group resourceTemplates []ServerResourceTemplate - // prompts are also not tools but are namespaced similarly + // prompts holds all prompts in this group prompts []ServerPrompt -} - -func (t *Toolset) GetActiveTools() []ServerTool { - if t.Enabled { - if t.readOnly { - return t.readTools + // deprecatedAliases maps old tool names to new canonical names + deprecatedAliases map[string]string + // defaultToolsetIDs are the toolset IDs that "default" expands to + defaultToolsetIDs []ToolsetID + + // Filters - these control what's returned by Available* methods + // readOnly when true filters out write tools + readOnly bool + // enabledToolsets when non-nil, only include tools/resources/prompts from these toolsets + // when nil, all toolsets are enabled + enabledToolsets map[ToolsetID]bool + // additionalTools are specific tools that bypass toolset filtering (but still respect read-only) + // These are additive - a tool is included if it matches toolset filters OR is in this set + additionalTools map[string]bool + // featureChecker when non-nil, checks if a feature flag is enabled. + // Takes context and flag name, returns (enabled, error). If error, log and treat as false. + // If checker is nil, all flag checks return false. + featureChecker FeatureFlagChecker +} + +// FeatureFlagChecker is a function that checks if a feature flag is enabled. +// The context can be used to extract actor/user information for flag evaluation. +// Returns (enabled, error). If error occurs, the caller should log and treat as false. +type FeatureFlagChecker func(ctx context.Context, flagName string) (bool, error) + +// NewToolsetGroup creates a new ToolsetGroup from the provided tools, resources, and prompts. +// The group is created with no filters applied. +func NewToolsetGroup(tools []ServerTool, resources []ServerResourceTemplate, prompts []ServerPrompt) *ToolsetGroup { + return &ToolsetGroup{ + tools: tools, + resourceTemplates: resources, + prompts: prompts, + deprecatedAliases: make(map[string]string), + readOnly: false, + enabledToolsets: nil, + additionalTools: nil, + featureChecker: nil, + } +} + +// copy creates a shallow copy of the ToolsetGroup for immutable operations. +func (tg *ToolsetGroup) copy() *ToolsetGroup { + newTG := &ToolsetGroup{ + tools: tg.tools, // slices are shared (immutable) + resourceTemplates: tg.resourceTemplates, + prompts: tg.prompts, + deprecatedAliases: tg.deprecatedAliases, + defaultToolsetIDs: tg.defaultToolsetIDs, + readOnly: tg.readOnly, + featureChecker: tg.featureChecker, + } + + // Copy maps if they exist + if tg.enabledToolsets != nil { + newTG.enabledToolsets = make(map[ToolsetID]bool, len(tg.enabledToolsets)) + for k, v := range tg.enabledToolsets { + newTG.enabledToolsets[k] = v } - return append(t.readTools, t.writeTools...) - } - return nil -} - -func (t *Toolset) GetAvailableTools() []ServerTool { - if t.readOnly { - return t.readTools - } - return append(t.readTools, t.writeTools...) -} - -func (t *Toolset) RegisterTools(s *mcp.Server) { - if !t.Enabled { - return - } - for i := range t.readTools { - t.readTools[i].RegisterFunc(s, t.deps) } - if !t.readOnly { - for i := range t.writeTools { - t.writeTools[i].RegisterFunc(s, t.deps) + if tg.additionalTools != nil { + newTG.additionalTools = make(map[string]bool, len(tg.additionalTools)) + for k, v := range tg.additionalTools { + newTG.additionalTools[k] = v } } -} -// SetDependencies sets the dependencies for this toolset's tool handlers. -// The deps parameter is typed as `any` to avoid circular dependencies between packages. -func (t *Toolset) SetDependencies(deps any) *Toolset { - t.deps = deps - return t + return newTG } -func (t *Toolset) AddResourceTemplates(templates ...ServerResourceTemplate) *Toolset { - t.resourceTemplates = append(t.resourceTemplates, templates...) - return t +// WithReadOnly returns a new ToolsetGroup with read-only mode set. +// When true, write tools are filtered out from Available* methods. +func (tg *ToolsetGroup) WithReadOnly(readOnly bool) *ToolsetGroup { + newTG := tg.copy() + newTG.readOnly = readOnly + return newTG } -func (t *Toolset) AddPrompts(prompts ...ServerPrompt) *Toolset { - t.prompts = append(t.prompts, prompts...) - return t +// SetDefaultToolsetIDs configures which toolset IDs the "default" keyword expands to. +// This should be called before WithToolsets if you want "default" to be recognized. +func (tg *ToolsetGroup) SetDefaultToolsetIDs(ids []ToolsetID) *ToolsetGroup { + tg.defaultToolsetIDs = ids + return tg } -func (t *Toolset) GetActiveResourceTemplates() []ServerResourceTemplate { - if !t.Enabled { - return nil +// WithToolsets returns a new ToolsetGroup that only includes items from the specified toolsets. +// Special keywords: +// - "all": enables all toolsets +// - "default": expands to the default toolset IDs (set via SetDefaultToolsetIDs) +// +// Pass nil to use default toolsets. Pass an empty slice to disable all toolsets +// (useful for dynamic toolsets mode where tools are enabled on demand). +func (tg *ToolsetGroup) WithToolsets(toolsetIDs []string) *ToolsetGroup { + newTG := tg.copy() + + // Check for "all" keyword - enables all toolsets + for _, id := range toolsetIDs { + if id == "all" { + newTG.enabledToolsets = nil + return newTG + } } - return t.resourceTemplates -} -func (t *Toolset) GetAvailableResourceTemplates() []ServerResourceTemplate { - return t.resourceTemplates -} + // nil means use defaults, empty slice means no toolsets + if toolsetIDs == nil { + toolsetIDs = []string{"default"} + } -func (t *Toolset) RegisterResourcesTemplates(s *mcp.Server) { - if !t.Enabled { - return + // Expand "default" keyword and collect other IDs + seen := make(map[ToolsetID]bool) + expanded := make([]ToolsetID, 0, len(toolsetIDs)) + for _, id := range toolsetIDs { + if id == "default" { + for _, defaultID := range tg.defaultToolsetIDs { + if !seen[defaultID] { + seen[defaultID] = true + expanded = append(expanded, defaultID) + } + } + } else { + tsID := ToolsetID(id) + if !seen[tsID] { + seen[tsID] = true + expanded = append(expanded, tsID) + } + } } - for _, resource := range t.resourceTemplates { - s.AddResourceTemplate(&resource.Template, resource.Handler) + + if len(expanded) == 0 { + newTG.enabledToolsets = make(map[ToolsetID]bool) + return newTG } + + newTG.enabledToolsets = make(map[ToolsetID]bool, len(expanded)) + for _, id := range expanded { + newTG.enabledToolsets[id] = true + } + return newTG } -func (t *Toolset) RegisterPrompts(s *mcp.Server) { - if !t.Enabled { - return +// WithTools returns a new ToolsetGroup with additional tools that bypass toolset filtering. +// These tools are additive - they will be included even if their toolset is not enabled. +// Read-only filtering still applies to these tools. +// Deprecated tool aliases are automatically resolved to their canonical names. +// Pass nil or empty slice to clear additional tools. +func (tg *ToolsetGroup) WithTools(toolNames []string) *ToolsetGroup { + newTG := tg.copy() + if len(toolNames) == 0 { + newTG.additionalTools = nil + return newTG } - for _, prompt := range t.prompts { - s.AddPrompt(&prompt.Prompt, prompt.Handler) + newTG.additionalTools = make(map[string]bool, len(toolNames)) + for _, name := range toolNames { + // Resolve deprecated aliases to canonical names + if canonical, isAlias := tg.deprecatedAliases[name]; isAlias { + newTG.additionalTools[canonical] = true + } else { + newTG.additionalTools[name] = true + } } -} + return newTG +} + +// WithFeatureChecker returns a new ToolsetGroup with a feature checker function. +// The checker receives a context (for actor extraction) and feature flag name, returns (enabled, error). +// If error occurs, it will be logged and treated as false. +// If checker is nil, all feature flag checks return false (items with FeatureFlagEnable are excluded, +// items with FeatureFlagDisable are included). +func (tg *ToolsetGroup) WithFeatureChecker(checker FeatureFlagChecker) *ToolsetGroup { + newTG := tg.copy() + newTG.featureChecker = checker + return newTG +} + +// MCP method constants for use with ForMCPRequest. +const ( + MCPMethodInitialize = "initialize" + MCPMethodToolsList = "tools/list" + MCPMethodToolsCall = "tools/call" + MCPMethodResourcesList = "resources/list" + MCPMethodResourcesRead = "resources/read" + MCPMethodResourcesTemplatesList = "resources/templates/list" + MCPMethodPromptsList = "prompts/list" + MCPMethodPromptsGet = "prompts/get" +) -func (t *Toolset) SetReadOnly() { - // Set the toolset to read-only - t.readOnly = true -} +// ForMCPRequest returns a ToolsetGroup optimized for a specific MCP request. +// This is designed for servers that create a new instance per request (like the remote server), +// allowing them to only register the items needed for that specific request rather than all ~90 tools. +// +// Parameters: +// - method: The MCP method being called (use MCP* constants) +// - itemName: Name of specific item for call/get methods (tool name, resource URI, or prompt name) +// +// Returns a new ToolsetGroup containing only the items relevant to the request: +// - MCPMethodInitialize: Empty (capabilities are set via ServerOptions, not registration) +// - MCPMethodToolsList: All available tools (no resources/prompts) +// - MCPMethodToolsCall: Only the named tool +// - MCPMethodResourcesList, MCPMethodResourcesTemplatesList: All available resources (no tools/prompts) +// - MCPMethodResourcesRead: Only the named resource template +// - MCPMethodPromptsList: All available prompts (no tools/resources) +// - MCPMethodPromptsGet: Only the named prompt +// - Unknown methods: Empty (no items registered) +// +// All existing filters (read-only, toolsets, etc.) still apply to the returned items. +func (tg *ToolsetGroup) ForMCPRequest(method string, itemName string) *ToolsetGroup { + result := tg.copy() + + switch method { + case MCPMethodInitialize: + // Capabilities only - no items need to be registered + // The server capabilities (tools, resources, prompts support) are set via ServerOptions + result.tools = []ServerTool{} + result.resourceTemplates = []ServerResourceTemplate{} + result.prompts = []ServerPrompt{} + + case MCPMethodToolsList: + // All available tools, but no resources or prompts + result.resourceTemplates = []ServerResourceTemplate{} + result.prompts = []ServerPrompt{} + + case MCPMethodToolsCall: + // Only the specific tool (if found), no resources or prompts + result.resourceTemplates = []ServerResourceTemplate{} + result.prompts = []ServerPrompt{} + if itemName != "" { + result.tools = tg.filterToolsByName(itemName) + } -func (t *Toolset) AddWriteTools(tools ...ServerTool) *Toolset { - // Silently ignore if the toolset is read-only to avoid any breach of that contract - for _, tool := range tools { - if tool.Tool.Annotations.ReadOnlyHint { - panic(fmt.Sprintf("tool (%s) is incorrectly annotated as read-only", tool.Tool.Name)) + case MCPMethodResourcesList, MCPMethodResourcesTemplatesList: + // All available resources, but no tools or prompts + result.tools = []ServerTool{} + result.prompts = []ServerPrompt{} + + case MCPMethodResourcesRead: + // Only the specific resource template, no tools or prompts + result.tools = []ServerTool{} + result.prompts = []ServerPrompt{} + if itemName != "" { + result.resourceTemplates = tg.filterResourcesByURI(itemName) } + + case MCPMethodPromptsList: + // All available prompts, but no tools or resources + result.tools = []ServerTool{} + result.resourceTemplates = []ServerResourceTemplate{} + + case MCPMethodPromptsGet: + // Only the specific prompt, no tools or resources + result.tools = []ServerTool{} + result.resourceTemplates = []ServerResourceTemplate{} + if itemName != "" { + result.prompts = tg.filterPromptsByName(itemName) + } + + default: + // Unknown method - register nothing + result.tools = []ServerTool{} + result.resourceTemplates = []ServerResourceTemplate{} + result.prompts = []ServerPrompt{} } - if !t.readOnly { - t.writeTools = append(t.writeTools, tools...) - } - return t + + return result } -func (t *Toolset) AddReadTools(tools ...ServerTool) *Toolset { - for _, tool := range tools { - if !tool.Tool.Annotations.ReadOnlyHint { - panic(fmt.Sprintf("tool (%s) must be annotated as read-only", tool.Tool.Name)) +// filterToolsByName returns tools matching the given name, checking deprecated aliases. +// Returns from the current tools slice (respects existing filter chain). +func (tg *ToolsetGroup) filterToolsByName(name string) []ServerTool { + // First check for exact match + for i := range tg.tools { + if tg.tools[i].Tool.Name == name { + return []ServerTool{tg.tools[i]} + } + } + // Check if name is a deprecated alias + if canonical, isAlias := tg.deprecatedAliases[name]; isAlias { + for i := range tg.tools { + if tg.tools[i].Tool.Name == canonical { + return []ServerTool{tg.tools[i]} + } } } - t.readTools = append(t.readTools, tools...) - return t + return []ServerTool{} } -type ToolsetGroup struct { - Toolsets map[string]*Toolset - deprecatedAliases map[string]string - everythingOn bool - readOnly bool +// filterResourcesByURI returns resource templates matching the given URI pattern. +func (tg *ToolsetGroup) filterResourcesByURI(uri string) []ServerResourceTemplate { + for i := range tg.resourceTemplates { + // Check if URI matches the template pattern (exact match on URITemplate string) + if tg.resourceTemplates[i].Template.URITemplate == uri { + return []ServerResourceTemplate{tg.resourceTemplates[i]} + } + } + return []ServerResourceTemplate{} } -func NewToolsetGroup(readOnly bool) *ToolsetGroup { - return &ToolsetGroup{ - Toolsets: make(map[string]*Toolset), - deprecatedAliases: make(map[string]string), - everythingOn: false, - readOnly: readOnly, +// filterPromptsByName returns prompts matching the given name. +func (tg *ToolsetGroup) filterPromptsByName(name string) []ServerPrompt { + for i := range tg.prompts { + if tg.prompts[i].Prompt.Name == name { + return []ServerPrompt{tg.prompts[i]} + } } + return []ServerPrompt{} } -func (tg *ToolsetGroup) AddDeprecatedToolAliases(aliases map[string]string) { +// AddDeprecatedToolAliases adds mappings from old tool names to new canonical names. +func (tg *ToolsetGroup) AddDeprecatedToolAliases(aliases map[string]string) *ToolsetGroup { for oldName, newName := range aliases { tg.deprecatedAliases[oldName] = newName } + return tg } -func (tg *ToolsetGroup) AddToolset(ts *Toolset) { - if tg.readOnly { - ts.SetReadOnly() +// isToolsetEnabled checks if a toolset is enabled based on current filters. +func (tg *ToolsetGroup) isToolsetEnabled(toolsetID ToolsetID) bool { + // Check enabled toolsets filter + if tg.enabledToolsets != nil { + return tg.enabledToolsets[toolsetID] } - tg.Toolsets[ts.Name] = ts + return true } -func NewToolset(name string, description string) *Toolset { - return &Toolset{ - Name: name, - Description: description, - Enabled: false, - readOnly: false, +// checkFeatureFlag checks a feature flag using the feature checker. +// Returns false if checker is nil or returns an error (errors are logged). +func (tg *ToolsetGroup) checkFeatureFlag(ctx context.Context, flagName string) bool { + if tg.featureChecker == nil || flagName == "" { + return false } + enabled, err := tg.featureChecker(ctx, flagName) + if err != nil { + fmt.Fprintf(os.Stderr, "Feature flag check error for %q: %v\n", flagName, err) + return false + } + return enabled } -func (tg *ToolsetGroup) IsEnabled(name string) bool { - // If everythingOn is true, all features are enabled - if tg.everythingOn { - return true +// isFeatureFlagAllowed checks if an item passes feature flag filtering. +// - If FeatureFlagEnable is set, the item is only allowed if the flag is enabled +// - If FeatureFlagDisable is set, the item is excluded if the flag is enabled +func (tg *ToolsetGroup) isFeatureFlagAllowed(ctx context.Context, enableFlag, disableFlag string) bool { + // Check enable flag - item requires this flag to be on + if enableFlag != "" && !tg.checkFeatureFlag(ctx, enableFlag) { + return false } + // Check disable flag - item is excluded if this flag is on + if disableFlag != "" && tg.checkFeatureFlag(ctx, disableFlag) { + return false + } + return true +} - feature, exists := tg.Toolsets[name] - if !exists { +// isToolEnabled checks if a specific tool is enabled based on current filters. +func (tg *ToolsetGroup) isToolEnabled(ctx context.Context, tool *ServerTool) bool { + // Check read-only filter first (applies to all tools) + if tg.readOnly && !tool.IsReadOnly() { return false } - return feature.Enabled + // Check feature flags + if !tg.isFeatureFlagAllowed(ctx, tool.FeatureFlagEnable, tool.FeatureFlagDisable) { + return false + } + // Check if tool is in additionalTools (bypasses toolset filter) + if tg.additionalTools != nil && tg.additionalTools[tool.Tool.Name] { + return true + } + // Check toolset filter + if !tg.isToolsetEnabled(tool.Toolset.ID) { + return false + } + return true } -type EnableToolsetsOptions struct { - ErrorOnUnknown bool +// AvailableTools returns the tools that pass all current filters, +// sorted deterministically by toolset ID, then tool name. +// The context is used for feature flag evaluation. +func (tg *ToolsetGroup) AvailableTools(ctx context.Context) []ServerTool { + var result []ServerTool + for i := range tg.tools { + tool := &tg.tools[i] + if tg.isToolEnabled(ctx, tool) { + result = append(result, *tool) + } + } + + // Sort deterministically: by toolset ID, then by tool name + sort.Slice(result, func(i, j int) bool { + if result[i].Toolset.ID != result[j].Toolset.ID { + return result[i].Toolset.ID < result[j].Toolset.ID + } + return result[i].Tool.Name < result[j].Tool.Name + }) + + return result } -func (tg *ToolsetGroup) EnableToolsets(names []string, options *EnableToolsetsOptions) error { - if options == nil { - options = &EnableToolsetsOptions{ - ErrorOnUnknown: false, +// AvailableResourceTemplates returns resource templates that pass all current filters, +// sorted deterministically by toolset ID, then template name. +// The context is used for feature flag evaluation. +func (tg *ToolsetGroup) AvailableResourceTemplates(ctx context.Context) []ServerResourceTemplate { + var result []ServerResourceTemplate + for i := range tg.resourceTemplates { + res := &tg.resourceTemplates[i] + // Check feature flags + if !tg.isFeatureFlagAllowed(ctx, res.FeatureFlagEnable, res.FeatureFlagDisable) { + continue + } + if tg.isToolsetEnabled(res.Toolset.ID) { + result = append(result, *res) } } - // Special case for "all" - for _, name := range names { - if name == "all" { - tg.everythingOn = true - break + // Sort deterministically: by toolset ID, then by template name + sort.Slice(result, func(i, j int) bool { + if result[i].Toolset.ID != result[j].Toolset.ID { + return result[i].Toolset.ID < result[j].Toolset.ID + } + return result[i].Template.Name < result[j].Template.Name + }) + + return result +} + +// AvailablePrompts returns prompts that pass all current filters, +// sorted deterministically by toolset ID, then prompt name. +// The context is used for feature flag evaluation. +func (tg *ToolsetGroup) AvailablePrompts(ctx context.Context) []ServerPrompt { + var result []ServerPrompt + for i := range tg.prompts { + prompt := &tg.prompts[i] + // Check feature flags + if !tg.isFeatureFlagAllowed(ctx, prompt.FeatureFlagEnable, prompt.FeatureFlagDisable) { + continue } - err := tg.EnableToolset(name) - if err != nil && options.ErrorOnUnknown { - return err + if tg.isToolsetEnabled(prompt.Toolset.ID) { + result = append(result, *prompt) } } - // Do this after to ensure all toolsets are enabled if "all" is present anywhere in list - if tg.everythingOn { - for name := range tg.Toolsets { - err := tg.EnableToolset(name) - if err != nil && options.ErrorOnUnknown { - return err - } + + // Sort deterministically: by toolset ID, then by prompt name + sort.Slice(result, func(i, j int) bool { + if result[i].Toolset.ID != result[j].Toolset.ID { + return result[i].Toolset.ID < result[j].Toolset.ID } - return nil + return result[i].Prompt.Name < result[j].Prompt.Name + }) + + return result +} + +// ToolsetIDs returns a sorted list of unique toolset IDs from all tools in this group. +func (tg *ToolsetGroup) ToolsetIDs() []ToolsetID { + seen := make(map[ToolsetID]bool) + for i := range tg.tools { + seen[tg.tools[i].Toolset.ID] = true + } + for i := range tg.resourceTemplates { + seen[tg.resourceTemplates[i].Toolset.ID] = true + } + for i := range tg.prompts { + seen[tg.prompts[i].Toolset.ID] = true } - return nil + + ids := make([]ToolsetID, 0, len(seen)) + for id := range seen { + ids = append(ids, id) + } + sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) + return ids } -func (tg *ToolsetGroup) EnableToolset(name string) error { - toolset, exists := tg.Toolsets[name] - if !exists { - return NewToolsetDoesNotExistError(name) +// ToolsetDescriptions returns a map of toolset ID to description for all toolsets. +func (tg *ToolsetGroup) ToolsetDescriptions() map[ToolsetID]string { + descriptions := make(map[ToolsetID]string) + for i := range tg.tools { + t := &tg.tools[i] + if t.Toolset.Description != "" { + descriptions[t.Toolset.ID] = t.Toolset.Description + } + } + for i := range tg.resourceTemplates { + r := &tg.resourceTemplates[i] + if r.Toolset.Description != "" { + descriptions[r.Toolset.ID] = r.Toolset.Description + } } - toolset.Enabled = true - tg.Toolsets[name] = toolset - return nil + for i := range tg.prompts { + p := &tg.prompts[i] + if p.Toolset.Description != "" { + descriptions[p.Toolset.ID] = p.Toolset.Description + } + } + return descriptions } -func (tg *ToolsetGroup) RegisterAll(s *mcp.Server) { - for _, toolset := range tg.Toolsets { - toolset.RegisterTools(s) - toolset.RegisterResourcesTemplates(s) - toolset.RegisterPrompts(s) +// ToolsForToolset returns all tools belonging to a specific toolset. +// This method bypasses the toolset enabled filter (for dynamic toolset registration), +// but still respects the read-only filter. +func (tg *ToolsetGroup) ToolsForToolset(toolsetID ToolsetID) []ServerTool { + var result []ServerTool + for i := range tg.tools { + tool := &tg.tools[i] + // Only check read-only filter, not toolset enabled filter + if tool.Toolset.ID == toolsetID { + if tg.readOnly && !tool.IsReadOnly() { + continue + } + result = append(result, *tool) + } } + + // Sort by tool name for deterministic order + sort.Slice(result, func(i, j int) bool { + return result[i].Tool.Name < result[j].Tool.Name + }) + + return result } -func (tg *ToolsetGroup) GetToolset(name string) (*Toolset, error) { - toolset, exists := tg.Toolsets[name] - if !exists { - return nil, NewToolsetDoesNotExistError(name) +// RegisterTools registers all available tools with the server using the provided dependencies. +// The context is used for feature flag evaluation. +func (tg *ToolsetGroup) RegisterTools(ctx context.Context, s *mcp.Server, deps any) { + for _, tool := range tg.AvailableTools(ctx) { + tool.RegisterFunc(s, deps) } - return toolset, nil } -type ToolDoesNotExistError struct { - Name string +// RegisterResourceTemplates registers all available resource templates with the server. +// The context is used for feature flag evaluation. +func (tg *ToolsetGroup) RegisterResourceTemplates(ctx context.Context, s *mcp.Server) { + for _, res := range tg.AvailableResourceTemplates(ctx) { + s.AddResourceTemplate(&res.Template, res.Handler) + } } -func (e *ToolDoesNotExistError) Error() string { - return fmt.Sprintf("tool %s does not exist", e.Name) +// RegisterPrompts registers all available prompts with the server. +// The context is used for feature flag evaluation. +func (tg *ToolsetGroup) RegisterPrompts(ctx context.Context, s *mcp.Server) { + for _, prompt := range tg.AvailablePrompts(ctx) { + s.AddPrompt(&prompt.Prompt, prompt.Handler) + } } -func NewToolDoesNotExistError(name string) *ToolDoesNotExistError { - return &ToolDoesNotExistError{Name: name} +// RegisterAll registers all available tools, resources, and prompts with the server. +// The context is used for feature flag evaluation. +func (tg *ToolsetGroup) RegisterAll(ctx context.Context, s *mcp.Server, deps any) { + tg.RegisterTools(ctx, s, deps) + tg.RegisterResourceTemplates(ctx, s) + tg.RegisterPrompts(ctx, s) } // ResolveToolAliases resolves deprecated tool aliases to their canonical names. @@ -322,51 +678,82 @@ func (tg *ToolsetGroup) ResolveToolAliases(toolNames []string) (resolved []strin return resolved, aliasesUsed } -// FindToolByName searches all toolsets (enabled or disabled) for a tool by name. -// Returns the tool, its parent toolset name, and an error if not found. -func (tg *ToolsetGroup) FindToolByName(toolName string) (*ServerTool, string, error) { - for toolsetName, toolset := range tg.Toolsets { - // Check read tools - for _, tool := range toolset.readTools { - if tool.Tool.Name == toolName { - return &tool, toolsetName, nil - } - } - // Check write tools - for _, tool := range toolset.writeTools { - if tool.Tool.Name == toolName { - return &tool, toolsetName, nil - } +// FindToolByName searches all tools for one matching the given name. +// Returns the tool, its toolset ID, and an error if not found. +// This searches ALL tools regardless of filters. +func (tg *ToolsetGroup) FindToolByName(toolName string) (*ServerTool, ToolsetID, error) { + for i := range tg.tools { + tool := &tg.tools[i] + if tool.Tool.Name == toolName { + return tool, tool.Toolset.ID, nil } } return nil, "", NewToolDoesNotExistError(toolName) } -// RegisterSpecificTools registers only the specified tools. -// Respects read-only mode (skips write tools if readOnly=true). -// Returns error if any tool is not found. -func (tg *ToolsetGroup) RegisterSpecificTools(s *mcp.Server, toolNames []string, readOnly bool, deps any) error { - var skippedTools []string - for _, toolName := range toolNames { - tool, _, err := tg.FindToolByName(toolName) - if err != nil { - return fmt.Errorf("tool %s not found: %w", toolName, err) +// HasToolset checks if any tool/resource/prompt belongs to the given toolset. +func (tg *ToolsetGroup) HasToolset(toolsetID ToolsetID) bool { + for i := range tg.tools { + if tg.tools[i].Toolset.ID == toolsetID { + return true } - - if !tool.Tool.Annotations.ReadOnlyHint && readOnly { - // Skip write tools in read-only mode - skippedTools = append(skippedTools, toolName) - continue + } + for i := range tg.resourceTemplates { + if tg.resourceTemplates[i].Toolset.ID == toolsetID { + return true + } + } + for i := range tg.prompts { + if tg.prompts[i].Toolset.ID == toolsetID { + return true } + } + return false +} - // Register the tool - tool.RegisterFunc(s, deps) +// EnabledToolsetIDs returns the list of enabled toolset IDs based on current filters. +// Returns all toolset IDs if no filter is set. +func (tg *ToolsetGroup) EnabledToolsetIDs() []ToolsetID { + if tg.enabledToolsets == nil { + return tg.ToolsetIDs() + } + + ids := make([]ToolsetID, 0, len(tg.enabledToolsets)) + for id := range tg.enabledToolsets { + if tg.HasToolset(id) { + ids = append(ids, id) + } } + sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) + return ids +} - // Log skipped write tools if any - if len(skippedTools) > 0 { - fmt.Fprintf(os.Stderr, "Write tools skipped due to read-only mode: %s\n", strings.Join(skippedTools, ", ")) +// IsToolsetEnabled checks if a toolset is currently enabled based on filters. +func (tg *ToolsetGroup) IsToolsetEnabled(toolsetID ToolsetID) bool { + return tg.isToolsetEnabled(toolsetID) +} + +// EnableToolset marks a toolset as enabled in this group. +// This is used by dynamic toolset management to track which toolsets have been enabled. +func (tg *ToolsetGroup) EnableToolset(toolsetID ToolsetID) { + if tg.enabledToolsets == nil { + // nil means all enabled, so nothing to do + return } + tg.enabledToolsets[toolsetID] = true +} + +// AllTools returns all tools without any filtering, sorted deterministically. +func (tg *ToolsetGroup) AllTools() []ServerTool { + result := slices.Clone(tg.tools) + + // Sort deterministically: by toolset ID, then by tool name + sort.Slice(result, func(i, j int) bool { + if result[i].Toolset.ID != result[j].Toolset.ID { + return result[i].Toolset.ID < result[j].Toolset.ID + } + return result[i].Tool.Name < result[j].Tool.Name + }) - return nil + return result } diff --git a/pkg/toolsets/toolsets_test.go b/pkg/toolsets/toolsets_test.go index 66be5cba2..7bec55848 100644 --- a/pkg/toolsets/toolsets_test.go +++ b/pkg/toolsets/toolsets_test.go @@ -3,14 +3,22 @@ package toolsets import ( "context" "encoding/json" - "errors" + "fmt" "testing" "github.com/modelcontextprotocol/go-sdk/mcp" ) +// testToolsetMetadata returns a ToolsetMetadata for testing +func testToolsetMetadata(id string) ToolsetMetadata { + return ToolsetMetadata{ + ID: ToolsetID(id), + Description: "Test toolset: " + id, + } +} + // mockTool creates a minimal ServerTool for testing -func mockTool(name string, readOnly bool) ServerTool { +func mockTool(name string, toolsetID string, readOnly bool) ServerTool { return NewServerToolFromHandler( mcp.Tool{ Name: name, @@ -19,6 +27,7 @@ func mockTool(name string, readOnly bool) ServerTool { }, InputSchema: json.RawMessage(`{"type":"object","properties":{}}`), }, + testToolsetMetadata(toolsetID), func(_ any) mcp.ToolHandler { return func(_ context.Context, _ *mcp.CallToolRequest) (*mcp.CallToolResult, error) { return nil, nil @@ -27,382 +36,1000 @@ func mockTool(name string, readOnly bool) ServerTool { ) } -func TestNewToolsetGroupIsEmptyWithoutEverythingOn(t *testing.T) { - tsg := NewToolsetGroup(false) - if len(tsg.Toolsets) != 0 { - t.Fatalf("Expected Toolsets map to be empty, got %d items", len(tsg.Toolsets)) +func TestNewToolsetGroupEmpty(t *testing.T) { + tsg := NewToolsetGroup(nil, nil, nil) + if len(tsg.tools) != 0 { + t.Fatalf("Expected tools to be empty, got %d items", len(tsg.tools)) } - if tsg.everythingOn { - t.Fatal("Expected everythingOn to be initialized as false") + if len(tsg.resourceTemplates) != 0 { + t.Fatalf("Expected resourceTemplates to be empty, got %d items", len(tsg.resourceTemplates)) + } + if len(tsg.prompts) != 0 { + t.Fatalf("Expected prompts to be empty, got %d items", len(tsg.prompts)) } } -func TestAddToolset(t *testing.T) { - tsg := NewToolsetGroup(false) +func TestNewToolsetGroupWithTools(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", false), + mockTool("tool3", "toolset2", true), + } - // Test adding a toolset - toolset := NewToolset("test-toolset", "A test toolset") - toolset.Enabled = true - tsg.AddToolset(toolset) + tsg := NewToolsetGroup(tools, nil, nil) - // Verify toolset was added correctly - if len(tsg.Toolsets) != 1 { - t.Errorf("Expected 1 toolset, got %d", len(tsg.Toolsets)) + if len(tsg.tools) != 3 { + t.Errorf("Expected 3 tools, got %d", len(tsg.tools)) } +} - toolset, exists := tsg.Toolsets["test-toolset"] - if !exists { - t.Fatal("Feature was not added to the map") +func TestAvailableTools_NoFilters(t *testing.T) { + tools := []ServerTool{ + mockTool("tool_b", "toolset1", true), + mockTool("tool_a", "toolset1", false), + mockTool("tool_c", "toolset2", true), } - if toolset.Name != "test-toolset" { - t.Errorf("Expected toolset name to be 'test-toolset', got '%s'", toolset.Name) + tsg := NewToolsetGroup(tools, nil, nil) + available := tsg.AvailableTools(context.Background()) + + if len(available) != 3 { + t.Fatalf("Expected 3 available tools, got %d", len(available)) } - if toolset.Description != "A test toolset" { - t.Errorf("Expected toolset description to be 'A test toolset', got '%s'", toolset.Description) + // Verify deterministic sorting: by toolset ID, then tool name + expectedOrder := []string{"tool_a", "tool_b", "tool_c"} + for i, tool := range available { + if tool.Tool.Name != expectedOrder[i] { + t.Errorf("Tool at index %d: expected %s, got %s", i, expectedOrder[i], tool.Tool.Name) + } } +} - if !toolset.Enabled { - t.Error("Expected toolset to be enabled") +func TestWithReadOnly(t *testing.T) { + tools := []ServerTool{ + mockTool("read_tool", "toolset1", true), + mockTool("write_tool", "toolset1", false), } - // Test adding another toolset - anotherToolset := NewToolset("another-toolset", "Another test toolset") - tsg.AddToolset(anotherToolset) + tsg := NewToolsetGroup(tools, nil, nil) - if len(tsg.Toolsets) != 2 { - t.Errorf("Expected 2 toolsets, got %d", len(tsg.Toolsets)) + // Original should have both tools + allTools := tsg.AvailableTools(context.Background()) + if len(allTools) != 2 { + t.Fatalf("Expected 2 tools in original, got %d", len(allTools)) } - // Test overriding existing toolset - updatedToolset := NewToolset("test-toolset", "Updated description") - tsg.AddToolset(updatedToolset) - - toolset = tsg.Toolsets["test-toolset"] - if toolset.Description != "Updated description" { - t.Errorf("Expected toolset description to be updated to 'Updated description', got '%s'", toolset.Description) + // Read-only should filter out write tools + readOnlyTsg := tsg.WithReadOnly(true) + readOnlyTools := readOnlyTsg.AvailableTools(context.Background()) + if len(readOnlyTools) != 1 { + t.Fatalf("Expected 1 tool in read-only, got %d", len(readOnlyTools)) + } + if readOnlyTools[0].Tool.Name != "read_tool" { + t.Errorf("Expected read_tool, got %s", readOnlyTools[0].Tool.Name) } - if toolset.Enabled { - t.Error("Expected toolset to be disabled after update") + // Original should still have both (immutability test) + allTools = tsg.AvailableTools(context.Background()) + if len(allTools) != 2 { + t.Fatalf("Original was mutated! Expected 2 tools, got %d", len(allTools)) } } -func TestIsEnabled(t *testing.T) { - tsg := NewToolsetGroup(false) +func TestWithToolsets(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset2", true), + mockTool("tool3", "toolset3", true), + } + + tsg := NewToolsetGroup(tools, nil, nil) + + // Filter to specific toolsets + filteredTsg := tsg.WithToolsets([]string{"toolset1", "toolset3"}) + filteredTools := filteredTsg.AvailableTools(context.Background()) - // Test with non-existent toolset - if tsg.IsEnabled("non-existent") { - t.Error("Expected IsEnabled to return false for non-existent toolset") + if len(filteredTools) != 2 { + t.Fatalf("Expected 2 filtered tools, got %d", len(filteredTools)) } - // Test with disabled toolset - disabledToolset := NewToolset("disabled-toolset", "A disabled toolset") - tsg.AddToolset(disabledToolset) - if tsg.IsEnabled("disabled-toolset") { - t.Error("Expected IsEnabled to return false for disabled toolset") + // Verify correct tools are included + toolNames := make(map[string]bool) + for _, tool := range filteredTools { + toolNames[tool.Tool.Name] = true + } + if !toolNames["tool1"] || !toolNames["tool3"] { + t.Errorf("Expected tool1 and tool3, got %v", toolNames) } - // Test with enabled toolset - enabledToolset := NewToolset("enabled-toolset", "An enabled toolset") - enabledToolset.Enabled = true - tsg.AddToolset(enabledToolset) - if !tsg.IsEnabled("enabled-toolset") { - t.Error("Expected IsEnabled to return true for enabled toolset") + // Original should still have all 3 (immutability test) + allTools := tsg.AvailableTools(context.Background()) + if len(allTools) != 3 { + t.Fatalf("Original was mutated! Expected 3 tools, got %d", len(allTools)) } } -func TestEnableFeature(t *testing.T) { - tsg := NewToolsetGroup(false) +func TestWithTools(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", true), + mockTool("tool3", "toolset2", true), + } - // Test enabling non-existent toolset - err := tsg.EnableToolset("non-existent") - if err == nil { - t.Error("Expected error when enabling non-existent toolset") + tsg := NewToolsetGroup(tools, nil, nil) + + // WithTools adds additional tools that bypass toolset filtering + // When combined with WithToolsets([]), only the additional tools should be available + filteredTsg := tsg.WithToolsets([]string{}).WithTools([]string{"tool1", "tool3"}) + filteredTools := filteredTsg.AvailableTools(context.Background()) + + if len(filteredTools) != 2 { + t.Fatalf("Expected 2 filtered tools, got %d", len(filteredTools)) + } + + toolNames := make(map[string]bool) + for _, tool := range filteredTools { + toolNames[tool.Tool.Name] = true + } + if !toolNames["tool1"] || !toolNames["tool3"] { + t.Errorf("Expected tool1 and tool3, got %v", toolNames) + } +} + +func TestChainedFilters(t *testing.T) { + tools := []ServerTool{ + mockTool("read1", "toolset1", true), + mockTool("write1", "toolset1", false), + mockTool("read2", "toolset2", true), + mockTool("write2", "toolset2", false), } - // Test enabling toolset - testToolset := NewToolset("test-toolset", "A test toolset") - tsg.AddToolset(testToolset) + tsg := NewToolsetGroup(tools, nil, nil) - if tsg.IsEnabled("test-toolset") { - t.Error("Expected toolset to be disabled initially") + // Chain read-only and toolset filter + filtered := tsg.WithReadOnly(true).WithToolsets([]string{"toolset1"}) + result := filtered.AvailableTools(context.Background()) + + if len(result) != 1 { + t.Fatalf("Expected 1 tool after chained filters, got %d", len(result)) + } + if result[0].Tool.Name != "read1" { + t.Errorf("Expected read1, got %s", result[0].Tool.Name) } +} - err = tsg.EnableToolset("test-toolset") - if err != nil { - t.Errorf("Expected no error when enabling toolset, got: %v", err) +func TestToolsetIDs(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset_b", true), + mockTool("tool2", "toolset_a", true), + mockTool("tool3", "toolset_b", true), // duplicate toolset } - if !tsg.IsEnabled("test-toolset") { - t.Error("Expected toolset to be enabled after EnableFeature call") + tsg := NewToolsetGroup(tools, nil, nil) + ids := tsg.ToolsetIDs() + + if len(ids) != 2 { + t.Fatalf("Expected 2 unique toolset IDs, got %d", len(ids)) } - // Test enabling already enabled toolset - err = tsg.EnableToolset("test-toolset") - if err != nil { - t.Errorf("Expected no error when enabling already enabled toolset, got: %v", err) + // Should be sorted + if ids[0] != "toolset_a" || ids[1] != "toolset_b" { + t.Errorf("Expected sorted IDs [toolset_a, toolset_b], got %v", ids) } } -func TestEnableToolsets(t *testing.T) { - tsg := NewToolsetGroup(false) +func TestToolsetDescriptions(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset2", true), + } - // Prepare toolsets - toolset1 := NewToolset("toolset1", "Feature 1") - toolset2 := NewToolset("toolset2", "Feature 2") - tsg.AddToolset(toolset1) - tsg.AddToolset(toolset2) + tsg := NewToolsetGroup(tools, nil, nil) + descriptions := tsg.ToolsetDescriptions() - // Test enabling multiple toolsets - err := tsg.EnableToolsets([]string{"toolset1", "toolset2"}, &EnableToolsetsOptions{}) - if err != nil { - t.Errorf("Expected no error when enabling toolsets, got: %v", err) + if len(descriptions) != 2 { + t.Fatalf("Expected 2 descriptions, got %d", len(descriptions)) } - if !tsg.IsEnabled("toolset1") { - t.Error("Expected toolset1 to be enabled") + if descriptions["toolset1"] != "Test toolset: toolset1" { + t.Errorf("Wrong description for toolset1: %s", descriptions["toolset1"]) } +} - if !tsg.IsEnabled("toolset2") { - t.Error("Expected toolset2 to be enabled") +func TestToolsForToolset(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", true), + mockTool("tool3", "toolset2", true), } - // Test with non-existent toolset in the list - err = tsg.EnableToolsets([]string{"toolset1", "non-existent"}, nil) - if err != nil { - t.Errorf("Expected no error when ignoring unknown toolsets, got: %v", err) + tsg := NewToolsetGroup(tools, nil, nil) + toolset1Tools := tsg.ToolsForToolset("toolset1") + + if len(toolset1Tools) != 2 { + t.Fatalf("Expected 2 tools for toolset1, got %d", len(toolset1Tools)) } +} - err = tsg.EnableToolsets([]string{"toolset1", "non-existent"}, &EnableToolsetsOptions{ - ErrorOnUnknown: false, +func TestAddDeprecatedToolAliases(t *testing.T) { + tools := []ServerTool{ + mockTool("new_name", "toolset1", true), + } + + tsg := NewToolsetGroup(tools, nil, nil) + tsg.AddDeprecatedToolAliases(map[string]string{ + "old_name": "new_name", + "get_issue": "issue_read", }) + + if len(tsg.deprecatedAliases) != 2 { + t.Errorf("expected 2 aliases, got %d", len(tsg.deprecatedAliases)) + } + if tsg.deprecatedAliases["old_name"] != "new_name" { + t.Errorf("expected alias 'old_name' -> 'new_name', got '%s'", tsg.deprecatedAliases["old_name"]) + } +} + +func TestResolveToolAliases(t *testing.T) { + tools := []ServerTool{ + mockTool("issue_read", "toolset1", true), + mockTool("some_tool", "toolset1", true), + } + + tsg := NewToolsetGroup(tools, nil, nil) + tsg.AddDeprecatedToolAliases(map[string]string{ + "get_issue": "issue_read", + }) + + // Test resolving a mix of aliases and canonical names + input := []string{"get_issue", "some_tool"} + resolved, aliasesUsed := tsg.ResolveToolAliases(input) + + if len(resolved) != 2 { + t.Fatalf("expected 2 resolved names, got %d", len(resolved)) + } + if resolved[0] != "issue_read" { + t.Errorf("expected 'issue_read', got '%s'", resolved[0]) + } + if resolved[1] != "some_tool" { + t.Errorf("expected 'some_tool' (unchanged), got '%s'", resolved[1]) + } + + if len(aliasesUsed) != 1 { + t.Fatalf("expected 1 alias used, got %d", len(aliasesUsed)) + } + if aliasesUsed["get_issue"] != "issue_read" { + t.Errorf("expected aliasesUsed['get_issue'] = 'issue_read', got '%s'", aliasesUsed["get_issue"]) + } +} + +func TestFindToolByName(t *testing.T) { + tools := []ServerTool{ + mockTool("issue_read", "toolset1", true), + } + + tsg := NewToolsetGroup(tools, nil, nil) + + // Find by name + tool, toolsetID, err := tsg.FindToolByName("issue_read") if err != nil { - t.Errorf("Expected no error when ignoring unknown toolsets, got: %v", err) + t.Fatalf("expected no error, got %v", err) + } + if tool.Tool.Name != "issue_read" { + t.Errorf("expected tool name 'issue_read', got '%s'", tool.Tool.Name) + } + if toolsetID != "toolset1" { + t.Errorf("expected toolset ID 'toolset1', got '%s'", toolsetID) } - err = tsg.EnableToolsets([]string{"toolset1", "non-existent"}, &EnableToolsetsOptions{ErrorOnUnknown: true}) + // Non-existent tool + _, _, err = tsg.FindToolByName("nonexistent") if err == nil { - t.Error("Expected error when enabling list with non-existent toolset") + t.Error("expected error for non-existent tool") } - if !errors.Is(err, NewToolsetDoesNotExistError("non-existent")) { - t.Errorf("Expected ToolsetDoesNotExistError when enabling non-existent toolset, got: %v", err) +} + +func TestWithToolsAdditive(t *testing.T) { + tools := []ServerTool{ + mockTool("issue_read", "toolset1", true), + mockTool("issue_write", "toolset1", false), + mockTool("repo_read", "toolset2", true), } - // Test with empty list - err = tsg.EnableToolsets([]string{}, &EnableToolsetsOptions{}) - if err != nil { - t.Errorf("Expected no error with empty toolset list, got: %v", err) + tsg := NewToolsetGroup(tools, nil, nil) + + // Test WithTools bypasses toolset filtering + // Enable only toolset2, but add issue_read as additional tool + filtered := tsg.WithToolsets([]string{"toolset2"}).WithTools([]string{"issue_read"}) + + available := filtered.AvailableTools(context.Background()) + if len(available) != 2 { + t.Errorf("expected 2 tools (repo_read from toolset + issue_read additional), got %d", len(available)) } - // Test enabling everything through EnableToolsets - tsg = NewToolsetGroup(false) - err = tsg.EnableToolsets([]string{"all"}, &EnableToolsetsOptions{}) - if err != nil { - t.Errorf("Expected no error when enabling 'all', got: %v", err) + // Verify both tools are present + toolNames := make(map[string]bool) + for _, tool := range available { + toolNames[tool.Tool.Name] = true + } + if !toolNames["issue_read"] { + t.Error("expected issue_read to be included as additional tool") + } + if !toolNames["repo_read"] { + t.Error("expected repo_read to be included from toolset2") + } + + // Test WithTools respects read-only mode + readOnlyFiltered := tsg.WithReadOnly(true).WithTools([]string{"issue_write"}) + available = readOnlyFiltered.AvailableTools(context.Background()) + + // issue_write should be excluded because read-only applies to additional tools too + for _, tool := range available { + if tool.Tool.Name == "issue_write" { + t.Error("expected issue_write to be excluded in read-only mode") + } } - if !tsg.everythingOn { - t.Error("Expected everythingOn to be true after enabling 'all' via EnableToolsets") + // Test WithTools with non-existent tool (should not error, just won't match anything) + nonexistent := tsg.WithToolsets([]string{}).WithTools([]string{"nonexistent"}) + available = nonexistent.AvailableTools(context.Background()) + if len(available) != 0 { + t.Errorf("expected 0 tools for non-existent additional tool, got %d", len(available)) } } -func TestEnableEverything(t *testing.T) { - tsg := NewToolsetGroup(false) +func TestWithToolsResolvesAliases(t *testing.T) { + tools := []ServerTool{ + mockTool("issue_read", "toolset1", true), + } + + tsg := NewToolsetGroup(tools, nil, nil) + tsg.AddDeprecatedToolAliases(map[string]string{ + "get_issue": "issue_read", + }) - // Add a disabled toolset - testToolset := NewToolset("test-toolset", "A test toolset") - tsg.AddToolset(testToolset) + // Using deprecated alias should resolve to canonical name + filtered := tsg.WithToolsets([]string{}).WithTools([]string{"get_issue"}) + available := filtered.AvailableTools(context.Background()) - // Verify it's disabled - if tsg.IsEnabled("test-toolset") { - t.Error("Expected toolset to be disabled initially") + if len(available) != 1 { + t.Errorf("expected 1 tool, got %d", len(available)) } + if available[0].Tool.Name != "issue_read" { + t.Errorf("expected issue_read, got %s", available[0].Tool.Name) + } +} - // Enable "all" - err := tsg.EnableToolsets([]string{"all"}, &EnableToolsetsOptions{}) - if err != nil { - t.Errorf("Expected no error when enabling 'all', got: %v", err) +func TestHasToolset(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), } - // Verify everythingOn was set - if !tsg.everythingOn { - t.Error("Expected everythingOn to be true after enabling 'all'") + tsg := NewToolsetGroup(tools, nil, nil) + + if !tsg.HasToolset("toolset1") { + t.Error("expected HasToolset to return true for existing toolset") + } + if tsg.HasToolset("nonexistent") { + t.Error("expected HasToolset to return false for non-existent toolset") } +} - // Verify the previously disabled toolset is now enabled - if !tsg.IsEnabled("test-toolset") { - t.Error("Expected toolset to be enabled when everythingOn is true") +func TestEnabledToolsetIDs(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset2", true), } - // Verify a non-existent toolset is also enabled - if !tsg.IsEnabled("non-existent") { - t.Error("Expected non-existent toolset to be enabled when everythingOn is true") + tsg := NewToolsetGroup(tools, nil, nil) + + // Without filter, all toolsets are enabled + ids := tsg.EnabledToolsetIDs() + if len(ids) != 2 { + t.Fatalf("Expected 2 enabled toolset IDs, got %d", len(ids)) + } + + // With filter + filtered := tsg.WithToolsets([]string{"toolset1"}) + filteredIDs := filtered.EnabledToolsetIDs() + if len(filteredIDs) != 1 { + t.Fatalf("Expected 1 enabled toolset ID, got %d", len(filteredIDs)) + } + if filteredIDs[0] != "toolset1" { + t.Errorf("Expected toolset1, got %s", filteredIDs[0]) } } -func TestIsEnabledWithEverythingOn(t *testing.T) { - tsg := NewToolsetGroup(false) +func TestAllTools(t *testing.T) { + tools := []ServerTool{ + mockTool("read_tool", "toolset1", true), + mockTool("write_tool", "toolset1", false), + } - // Enable "all" - err := tsg.EnableToolsets([]string{"all"}, &EnableToolsetsOptions{}) - if err != nil { - t.Errorf("Expected no error when enabling 'all', got: %v", err) + tsg := NewToolsetGroup(tools, nil, nil) + + // Even with read-only filter, AllTools returns everything + readOnlyTsg := tsg.WithReadOnly(true) + + allTools := readOnlyTsg.AllTools() + if len(allTools) != 2 { + t.Fatalf("Expected 2 tools from AllTools, got %d", len(allTools)) } - // Test that any toolset name returns true with IsEnabled - if !tsg.IsEnabled("some-toolset") { - t.Error("Expected IsEnabled to return true for any toolset when everythingOn is true") + // But AvailableTools respects the filter + availableTools := readOnlyTsg.AvailableTools(context.Background()) + if len(availableTools) != 1 { + t.Fatalf("Expected 1 tool from AvailableTools, got %d", len(availableTools)) } +} + +func TestServerToolIsReadOnly(t *testing.T) { + readTool := mockTool("read_tool", "toolset1", true) + writeTool := mockTool("write_tool", "toolset1", false) - if !tsg.IsEnabled("another-toolset") { - t.Error("Expected IsEnabled to return true for any toolset when everythingOn is true") + if !readTool.IsReadOnly() { + t.Error("Expected read tool to be read-only") + } + if writeTool.IsReadOnly() { + t.Error("Expected write tool to not be read-only") } } -func TestToolsetGroup_GetToolset(t *testing.T) { - tsg := NewToolsetGroup(false) - toolset := NewToolset("my-toolset", "desc") - tsg.AddToolset(toolset) +// mockResource creates a minimal ServerResourceTemplate for testing +func mockResource(name string, toolsetID string, uriTemplate string) ServerResourceTemplate { + return NewServerResourceTemplate( + testToolsetMetadata(toolsetID), + mcp.ResourceTemplate{ + Name: name, + URITemplate: uriTemplate, + }, + func(_ context.Context, _ *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { + return nil, nil + }, + ) +} - // Should find the toolset - got, err := tsg.GetToolset("my-toolset") - if err != nil { - t.Fatalf("expected no error, got %v", err) +// mockPrompt creates a minimal ServerPrompt for testing +func mockPrompt(name string, toolsetID string) ServerPrompt { + return NewServerPrompt( + testToolsetMetadata(toolsetID), + mcp.Prompt{Name: name}, + func(_ context.Context, _ *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return nil, nil + }, + ) +} + +func TestForMCPRequest_Initialize(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "repos", true), + mockTool("tool2", "issues", false), } - if got != toolset { - t.Errorf("expected to get the same toolset instance") + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), + } + prompts := []ServerPrompt{ + mockPrompt("prompt1", "repos"), } - // Should not find a non-existent toolset - _, err = tsg.GetToolset("does-not-exist") - if err == nil { - t.Error("expected error for missing toolset, got nil") + tsg := NewToolsetGroup(tools, resources, prompts) + filtered := tsg.ForMCPRequest(MCPMethodInitialize, "") + + // Initialize should return empty - capabilities come from ServerOptions + if len(filtered.AvailableTools(context.Background())) != 0 { + t.Errorf("Expected 0 tools for initialize, got %d", len(filtered.AvailableTools(context.Background()))) + } + if len(filtered.AvailableResourceTemplates(context.Background())) != 0 { + t.Errorf("Expected 0 resources for initialize, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) } - if !errors.Is(err, NewToolsetDoesNotExistError("does-not-exist")) { - t.Errorf("expected error to be ToolsetDoesNotExistError, got %v", err) + if len(filtered.AvailablePrompts(context.Background())) != 0 { + t.Errorf("Expected 0 prompts for initialize, got %d", len(filtered.AvailablePrompts(context.Background()))) } } -func TestAddDeprecatedToolAliases(t *testing.T) { - tsg := NewToolsetGroup(false) +func TestForMCPRequest_ToolsList(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "repos", true), + mockTool("tool2", "issues", true), + } + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), + } + prompts := []ServerPrompt{ + mockPrompt("prompt1", "repos"), + } - // Test adding aliases - tsg.AddDeprecatedToolAliases(map[string]string{ - "old_name": "new_name", - "get_issue": "issue_read", - "create_pr": "pull_request_create", - }) + tsg := NewToolsetGroup(tools, resources, prompts) + filtered := tsg.ForMCPRequest(MCPMethodToolsList, "") - if len(tsg.deprecatedAliases) != 3 { - t.Errorf("expected 3 aliases, got %d", len(tsg.deprecatedAliases)) + // tools/list should return all tools, no resources or prompts + if len(filtered.AvailableTools(context.Background())) != 2 { + t.Errorf("Expected 2 tools for tools/list, got %d", len(filtered.AvailableTools(context.Background()))) } - if tsg.deprecatedAliases["old_name"] != "new_name" { - t.Errorf("expected alias 'old_name' -> 'new_name', got '%s'", tsg.deprecatedAliases["old_name"]) + if len(filtered.AvailableResourceTemplates(context.Background())) != 0 { + t.Errorf("Expected 0 resources for tools/list, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) + } + if len(filtered.AvailablePrompts(context.Background())) != 0 { + t.Errorf("Expected 0 prompts for tools/list, got %d", len(filtered.AvailablePrompts(context.Background()))) } - if tsg.deprecatedAliases["get_issue"] != "issue_read" { - t.Errorf("expected alias 'get_issue' -> 'issue_read'") +} + +func TestForMCPRequest_ToolsCall(t *testing.T) { + tools := []ServerTool{ + mockTool("get_me", "context", true), + mockTool("create_issue", "issues", false), + mockTool("list_repos", "repos", true), + } + + tsg := NewToolsetGroup(tools, nil, nil) + filtered := tsg.ForMCPRequest(MCPMethodToolsCall, "get_me") + + available := filtered.AvailableTools(context.Background()) + if len(available) != 1 { + t.Fatalf("Expected 1 tool for tools/call with name, got %d", len(available)) } - if tsg.deprecatedAliases["create_pr"] != "pull_request_create" { - t.Errorf("expected alias 'create_pr' -> 'pull_request_create'") + if available[0].Tool.Name != "get_me" { + t.Errorf("Expected tool name 'get_me', got %q", available[0].Tool.Name) } } -func TestResolveToolAliases(t *testing.T) { - tsg := NewToolsetGroup(false) +func TestForMCPRequest_ToolsCall_NotFound(t *testing.T) { + tools := []ServerTool{ + mockTool("get_me", "context", true), + } + + tsg := NewToolsetGroup(tools, nil, nil) + filtered := tsg.ForMCPRequest(MCPMethodToolsCall, "nonexistent") + + if len(filtered.AvailableTools(context.Background())) != 0 { + t.Errorf("Expected 0 tools for nonexistent tool, got %d", len(filtered.AvailableTools(context.Background()))) + } +} + +func TestForMCPRequest_ToolsCall_DeprecatedAlias(t *testing.T) { + tools := []ServerTool{ + mockTool("get_me", "context", true), + mockTool("list_commits", "repos", true), + } + + tsg := NewToolsetGroup(tools, nil, nil) tsg.AddDeprecatedToolAliases(map[string]string{ - "get_issue": "issue_read", - "create_pr": "pull_request_create", + "old_get_me": "get_me", }) - // Test resolving a mix of aliases and canonical names - input := []string{"get_issue", "some_tool", "create_pr"} - resolved, aliasesUsed := tsg.ResolveToolAliases(input) + // Request using the deprecated alias + filtered := tsg.ForMCPRequest(MCPMethodToolsCall, "old_get_me") - // Verify resolved names - if len(resolved) != 3 { - t.Fatalf("expected 3 resolved names, got %d", len(resolved)) + available := filtered.AvailableTools(context.Background()) + if len(available) != 1 { + t.Fatalf("Expected 1 tool when using deprecated alias, got %d", len(available)) } - if resolved[0] != "issue_read" { - t.Errorf("expected 'issue_read', got '%s'", resolved[0]) + if available[0].Tool.Name != "get_me" { + t.Errorf("Expected canonical name 'get_me', got %q", available[0].Tool.Name) } - if resolved[1] != "some_tool" { - t.Errorf("expected 'some_tool' (unchanged), got '%s'", resolved[1]) +} + +func TestForMCPRequest_ToolsCall_RespectsFilters(t *testing.T) { + tools := []ServerTool{ + mockTool("create_issue", "issues", false), // write tool } - if resolved[2] != "pull_request_create" { - t.Errorf("expected 'pull_request_create', got '%s'", resolved[2]) + + tsg := NewToolsetGroup(tools, nil, nil) + // Apply read-only filter, then ForMCPRequest + filtered := tsg.WithReadOnly(true).ForMCPRequest(MCPMethodToolsCall, "create_issue") + + // The tool exists in the filtered group, but AvailableTools respects read-only + available := filtered.AvailableTools(context.Background()) + if len(available) != 0 { + t.Errorf("Expected 0 tools - write tool should be filtered by read-only, got %d", len(available)) } +} - // Verify aliasesUsed map - if len(aliasesUsed) != 2 { - t.Fatalf("expected 2 aliases used, got %d", len(aliasesUsed)) +func TestForMCPRequest_ResourcesList(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "repos", true), } - if aliasesUsed["get_issue"] != "issue_read" { - t.Errorf("expected aliasesUsed['get_issue'] = 'issue_read', got '%s'", aliasesUsed["get_issue"]) + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), + mockResource("res2", "repos", "branch://{owner}/{repo}/{branch}"), } - if aliasesUsed["create_pr"] != "pull_request_create" { - t.Errorf("expected aliasesUsed['create_pr'] = 'pull_request_create', got '%s'", aliasesUsed["create_pr"]) + prompts := []ServerPrompt{ + mockPrompt("prompt1", "repos"), + } + + tsg := NewToolsetGroup(tools, resources, prompts) + filtered := tsg.ForMCPRequest(MCPMethodResourcesList, "") + + if len(filtered.AvailableTools(context.Background())) != 0 { + t.Errorf("Expected 0 tools for resources/list, got %d", len(filtered.AvailableTools(context.Background()))) + } + if len(filtered.AvailableResourceTemplates(context.Background())) != 2 { + t.Errorf("Expected 2 resources for resources/list, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) + } + if len(filtered.AvailablePrompts(context.Background())) != 0 { + t.Errorf("Expected 0 prompts for resources/list, got %d", len(filtered.AvailablePrompts(context.Background()))) } } -func TestFindToolByName(t *testing.T) { - tsg := NewToolsetGroup(false) +func TestForMCPRequest_ResourcesRead(t *testing.T) { + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), + mockResource("res2", "repos", "branch://{owner}/{repo}/{branch}"), + } - // Create a toolset with a tool - toolset := NewToolset("test-toolset", "Test toolset") - toolset.readTools = append(toolset.readTools, mockTool("issue_read", true)) - tsg.AddToolset(toolset) + tsg := NewToolsetGroup(nil, resources, nil) + filtered := tsg.ForMCPRequest(MCPMethodResourcesRead, "repo://{owner}/{repo}") - // Find by canonical name - tool, toolsetName, err := tsg.FindToolByName("issue_read") - if err != nil { - t.Fatalf("expected no error, got %v", err) + available := filtered.AvailableResourceTemplates(context.Background()) + if len(available) != 1 { + t.Fatalf("Expected 1 resource for resources/read, got %d", len(available)) } - if tool.Tool.Name != "issue_read" { - t.Errorf("expected tool name 'issue_read', got '%s'", tool.Tool.Name) + if available[0].Template.URITemplate != "repo://{owner}/{repo}" { + t.Errorf("Expected URI template 'repo://{owner}/{repo}', got %q", available[0].Template.URITemplate) + } +} + +func TestForMCPRequest_PromptsList(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "repos", true), + } + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), } - if toolsetName != "test-toolset" { - t.Errorf("expected toolset name 'test-toolset', got '%s'", toolsetName) + prompts := []ServerPrompt{ + mockPrompt("prompt1", "repos"), + mockPrompt("prompt2", "issues"), } - // FindToolByName does NOT resolve aliases - it expects canonical names - _, _, err = tsg.FindToolByName("get_issue") - if err == nil { - t.Error("expected error when using alias directly with FindToolByName") + tsg := NewToolsetGroup(tools, resources, prompts) + filtered := tsg.ForMCPRequest(MCPMethodPromptsList, "") + + if len(filtered.AvailableTools(context.Background())) != 0 { + t.Errorf("Expected 0 tools for prompts/list, got %d", len(filtered.AvailableTools(context.Background()))) + } + if len(filtered.AvailableResourceTemplates(context.Background())) != 0 { + t.Errorf("Expected 0 resources for prompts/list, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) + } + if len(filtered.AvailablePrompts(context.Background())) != 2 { + t.Errorf("Expected 2 prompts for prompts/list, got %d", len(filtered.AvailablePrompts(context.Background()))) } } -func TestRegisterSpecificTools(t *testing.T) { - tsg := NewToolsetGroup(false) +func TestForMCPRequest_PromptsGet(t *testing.T) { + prompts := []ServerPrompt{ + mockPrompt("prompt1", "repos"), + mockPrompt("prompt2", "issues"), + } + + tsg := NewToolsetGroup(nil, nil, prompts) + filtered := tsg.ForMCPRequest(MCPMethodPromptsGet, "prompt1") - // Create a toolset with both read and write tools - toolset := NewToolset("test-toolset", "Test toolset") - toolset.readTools = append(toolset.readTools, mockTool("issue_read", true)) - toolset.writeTools = append(toolset.writeTools, mockTool("issue_write", false)) - tsg.AddToolset(toolset) + available := filtered.AvailablePrompts(context.Background()) + if len(available) != 1 { + t.Fatalf("Expected 1 prompt for prompts/get, got %d", len(available)) + } + if available[0].Prompt.Name != "prompt1" { + t.Errorf("Expected prompt name 'prompt1', got %q", available[0].Prompt.Name) + } +} - // deps is typed as any in toolsets package (to avoid circular deps) - var deps any +func TestForMCPRequest_UnknownMethod(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "repos", true), + } + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), + } + prompts := []ServerPrompt{ + mockPrompt("prompt1", "repos"), + } - // Create a real server for testing - server := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "1.0.0"}, nil) + tsg := NewToolsetGroup(tools, resources, prompts) + filtered := tsg.ForMCPRequest("unknown/method", "") - // Test registering with canonical names - err := tsg.RegisterSpecificTools(server, []string{"issue_read"}, false, deps) - if err != nil { - t.Errorf("expected no error registering tool, got %v", err) + // Unknown methods should return empty + if len(filtered.AvailableTools(context.Background())) != 0 { + t.Errorf("Expected 0 tools for unknown method, got %d", len(filtered.AvailableTools(context.Background()))) + } + if len(filtered.AvailableResourceTemplates(context.Background())) != 0 { + t.Errorf("Expected 0 resources for unknown method, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) } + if len(filtered.AvailablePrompts(context.Background())) != 0 { + t.Errorf("Expected 0 prompts for unknown method, got %d", len(filtered.AvailablePrompts(context.Background()))) + } +} - // Test registering write tool in read-only mode (should skip but not error) - err = tsg.RegisterSpecificTools(server, []string{"issue_write"}, true, deps) - if err != nil { - t.Errorf("expected no error when skipping write tool in read-only mode, got %v", err) +func TestForMCPRequest_Immutability(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "repos", true), + mockTool("tool2", "issues", true), + } + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), + } + prompts := []ServerPrompt{ + mockPrompt("prompt1", "repos"), } - // Test registering non-existent tool (should error) - err = tsg.RegisterSpecificTools(server, []string{"nonexistent"}, false, deps) - if err == nil { - t.Error("expected error for non-existent tool") + original := NewToolsetGroup(tools, resources, prompts) + filtered := original.ForMCPRequest(MCPMethodToolsCall, "tool1") + + // Original should be unchanged + if len(original.AvailableTools(context.Background())) != 2 { + t.Errorf("Original was mutated! Expected 2 tools, got %d", len(original.AvailableTools(context.Background()))) + } + if len(original.AvailableResourceTemplates(context.Background())) != 1 { + t.Errorf("Original was mutated! Expected 1 resource, got %d", len(original.AvailableResourceTemplates(context.Background()))) + } + if len(original.AvailablePrompts(context.Background())) != 1 { + t.Errorf("Original was mutated! Expected 1 prompt, got %d", len(original.AvailablePrompts(context.Background()))) + } + + // Filtered should have only the requested tool + if len(filtered.AvailableTools(context.Background())) != 1 { + t.Errorf("Expected 1 tool in filtered, got %d", len(filtered.AvailableTools(context.Background()))) + } + if len(filtered.AvailableResourceTemplates(context.Background())) != 0 { + t.Errorf("Expected 0 resources in filtered, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) + } + if len(filtered.AvailablePrompts(context.Background())) != 0 { + t.Errorf("Expected 0 prompts in filtered, got %d", len(filtered.AvailablePrompts(context.Background()))) + } +} + +func TestForMCPRequest_ChainedWithOtherFilters(t *testing.T) { + tools := []ServerTool{ + mockTool("get_me", "context", true), + mockTool("create_issue", "issues", false), + mockTool("list_repos", "repos", true), + mockTool("delete_repo", "repos", false), + } + + tsg := NewToolsetGroup(tools, nil, nil) + tsg.SetDefaultToolsetIDs([]ToolsetID{"context", "repos"}) + + // Chain: default toolsets -> read-only -> specific method + filtered := tsg. + WithToolsets([]string{"default"}). + WithReadOnly(true). + ForMCPRequest(MCPMethodToolsList, "") + + available := filtered.AvailableTools(context.Background()) + + // Should have: get_me (context, read), list_repos (repos, read) + // Should NOT have: create_issue (issues not in default), delete_repo (write) + if len(available) != 2 { + t.Fatalf("Expected 2 tools after filter chain, got %d", len(available)) + } + + toolNames := make(map[string]bool) + for _, tool := range available { + toolNames[tool.Tool.Name] = true + } + + if !toolNames["get_me"] { + t.Error("Expected get_me to be available") + } + if !toolNames["list_repos"] { + t.Error("Expected list_repos to be available") + } + if toolNames["create_issue"] { + t.Error("create_issue should not be available (toolset not enabled)") + } + if toolNames["delete_repo"] { + t.Error("delete_repo should not be available (write tool in read-only mode)") + } +} + +func TestForMCPRequest_ResourcesTemplatesList(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "repos", true), + } + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), + } + + tsg := NewToolsetGroup(tools, resources, nil) + filtered := tsg.ForMCPRequest(MCPMethodResourcesTemplatesList, "") + + // Same behavior as resources/list + if len(filtered.AvailableTools(context.Background())) != 0 { + t.Errorf("Expected 0 tools, got %d", len(filtered.AvailableTools(context.Background()))) + } + if len(filtered.AvailableResourceTemplates(context.Background())) != 1 { + t.Errorf("Expected 1 resource, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) + } +} + +func TestMCPMethodConstants(t *testing.T) { + // Verify constants match expected MCP method names + tests := []struct { + constant string + expected string + }{ + {MCPMethodInitialize, "initialize"}, + {MCPMethodToolsList, "tools/list"}, + {MCPMethodToolsCall, "tools/call"}, + {MCPMethodResourcesList, "resources/list"}, + {MCPMethodResourcesRead, "resources/read"}, + {MCPMethodResourcesTemplatesList, "resources/templates/list"}, + {MCPMethodPromptsList, "prompts/list"}, + {MCPMethodPromptsGet, "prompts/get"}, + } + + for _, tt := range tests { + if tt.constant != tt.expected { + t.Errorf("Constant mismatch: got %q, expected %q", tt.constant, tt.expected) + } + } +} + +// mockToolWithFlags creates a ServerTool with feature flags for testing +func mockToolWithFlags(name string, toolsetID string, readOnly bool, enableFlag, disableFlag string) ServerTool { + tool := mockTool(name, toolsetID, readOnly) + tool.FeatureFlagEnable = enableFlag + tool.FeatureFlagDisable = disableFlag + return tool +} + +func TestFeatureFlagEnable(t *testing.T) { + tools := []ServerTool{ + mockTool("always_available", "toolset1", true), + mockToolWithFlags("needs_flag", "toolset1", true, "my_feature", ""), + } + + tsg := NewToolsetGroup(tools, nil, nil) + + // Without feature checker, tool with FeatureFlagEnable should be excluded + available := tsg.AvailableTools(context.Background()) + if len(available) != 1 { + t.Fatalf("Expected 1 tool without feature checker, got %d", len(available)) + } + if available[0].Tool.Name != "always_available" { + t.Errorf("Expected always_available, got %s", available[0].Tool.Name) + } + + // With feature checker returning false, tool should still be excluded + checkerFalse := func(_ context.Context, _ string) (bool, error) { return false, nil } + filteredFalse := tsg.WithFeatureChecker(checkerFalse) + availableFalse := filteredFalse.AvailableTools(context.Background()) + if len(availableFalse) != 1 { + t.Fatalf("Expected 1 tool with false checker, got %d", len(availableFalse)) + } + + // With feature checker returning true for "my_feature", tool should be included + checkerTrue := func(_ context.Context, flag string) (bool, error) { + return flag == "my_feature", nil + } + filteredTrue := tsg.WithFeatureChecker(checkerTrue) + availableTrue := filteredTrue.AvailableTools(context.Background()) + if len(availableTrue) != 2 { + t.Fatalf("Expected 2 tools with true checker, got %d", len(availableTrue)) + } +} + +func TestFeatureFlagDisable(t *testing.T) { + tools := []ServerTool{ + mockTool("always_available", "toolset1", true), + mockToolWithFlags("disabled_by_flag", "toolset1", true, "", "kill_switch"), + } + + tsg := NewToolsetGroup(tools, nil, nil) + + // Without feature checker, tool with FeatureFlagDisable should be included (flag is false) + available := tsg.AvailableTools(context.Background()) + if len(available) != 2 { + t.Fatalf("Expected 2 tools without feature checker, got %d", len(available)) + } + + // With feature checker returning true for "kill_switch", tool should be excluded + checkerTrue := func(_ context.Context, flag string) (bool, error) { + return flag == "kill_switch", nil + } + filtered := tsg.WithFeatureChecker(checkerTrue) + availableFiltered := filtered.AvailableTools(context.Background()) + if len(availableFiltered) != 1 { + t.Fatalf("Expected 1 tool with kill_switch enabled, got %d", len(availableFiltered)) + } + if availableFiltered[0].Tool.Name != "always_available" { + t.Errorf("Expected always_available, got %s", availableFiltered[0].Tool.Name) + } +} + +func TestFeatureFlagBoth(t *testing.T) { + // Tool that requires "new_feature" AND is disabled by "kill_switch" + tools := []ServerTool{ + mockToolWithFlags("complex_tool", "toolset1", true, "new_feature", "kill_switch"), + } + + tsg := NewToolsetGroup(tools, nil, nil) + + // Enable flag not set -> excluded + checker1 := func(_ context.Context, _ string) (bool, error) { return false, nil } + if len(tsg.WithFeatureChecker(checker1).AvailableTools(context.Background())) != 0 { + t.Error("Tool should be excluded when enable flag is false") + } + + // Enable flag set, disable flag not set -> included + checker2 := func(_ context.Context, flag string) (bool, error) { return flag == "new_feature", nil } + if len(tsg.WithFeatureChecker(checker2).AvailableTools(context.Background())) != 1 { + t.Error("Tool should be included when enable flag is true and disable flag is false") + } + + // Enable flag set, disable flag also set -> excluded (disable wins) + checker3 := func(_ context.Context, _ string) (bool, error) { return true, nil } + if len(tsg.WithFeatureChecker(checker3).AvailableTools(context.Background())) != 0 { + t.Error("Tool should be excluded when both flags are true (disable wins)") + } +} + +func TestFeatureFlagError(t *testing.T) { + tools := []ServerTool{ + mockToolWithFlags("needs_flag", "toolset1", true, "my_feature", ""), + } + + tsg := NewToolsetGroup(tools, nil, nil) + + // Checker that returns error should treat as false (tool excluded) + checkerError := func(_ context.Context, _ string) (bool, error) { + return false, fmt.Errorf("simulated error") + } + filtered := tsg.WithFeatureChecker(checkerError) + available := filtered.AvailableTools(context.Background()) + if len(available) != 0 { + t.Errorf("Expected 0 tools when checker errors, got %d", len(available)) + } +} + +func TestFeatureFlagResources(t *testing.T) { + resources := []ServerResourceTemplate{ + mockResource("always_available", "toolset1", "uri1"), + { + Template: mcp.ResourceTemplate{Name: "needs_flag", URITemplate: "uri2"}, + Toolset: testToolsetMetadata("toolset1"), + FeatureFlagEnable: "my_feature", + }, + } + + tsg := NewToolsetGroup(nil, resources, nil) + + // Without checker, resource with enable flag should be excluded + available := tsg.AvailableResourceTemplates(context.Background()) + if len(available) != 1 { + t.Fatalf("Expected 1 resource without checker, got %d", len(available)) + } + + // With checker returning true, both should be included + checker := func(_ context.Context, _ string) (bool, error) { return true, nil } + filtered := tsg.WithFeatureChecker(checker) + if len(filtered.AvailableResourceTemplates(context.Background())) != 2 { + t.Errorf("Expected 2 resources with checker, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) + } +} + +func TestFeatureFlagPrompts(t *testing.T) { + prompts := []ServerPrompt{ + mockPrompt("always_available", "toolset1"), + { + Prompt: mcp.Prompt{Name: "needs_flag"}, + Toolset: testToolsetMetadata("toolset1"), + FeatureFlagEnable: "my_feature", + }, + } + + tsg := NewToolsetGroup(nil, nil, prompts) + + // Without checker, prompt with enable flag should be excluded + available := tsg.AvailablePrompts(context.Background()) + if len(available) != 1 { + t.Fatalf("Expected 1 prompt without checker, got %d", len(available)) + } + + // With checker returning true, both should be included + checker := func(_ context.Context, _ string) (bool, error) { return true, nil } + filtered := tsg.WithFeatureChecker(checker) + if len(filtered.AvailablePrompts(context.Background())) != 2 { + t.Errorf("Expected 2 prompts with checker, got %d", len(filtered.AvailablePrompts(context.Background()))) } } From 7a78ec6f72e8e0c9def6c7699a79230d7d0a0a34 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Sun, 14 Dec 2025 00:05:26 +0100 Subject: [PATCH 02/27] Add validation tests for tools, resources, and prompts metadata This commit adds comprehensive validation tests to ensure all MCP items have required metadata: - TestAllToolsHaveRequiredMetadata: Validates Toolset.ID and Annotations - TestAllToolsHaveValidToolsetID: Ensures toolsets are in AvailableToolsets() - TestAllResourcesHaveRequiredMetadata: Validates resource metadata - TestAllPromptsHaveRequiredMetadata: Validates prompt metadata - TestToolReadOnlyHintConsistency: Validates IsReadOnly() matches annotation - TestNoDuplicate*Names: Ensures unique names across tools/resources/prompts - TestAllToolsHaveHandlerFunc: Ensures all tools have handlers - TestDefaultToolsetsAreValid: Validates default toolset IDs - TestToolsetMetadataConsistency: Ensures consistent descriptions per toolset Also fixes a bug discovered by these tests: ToolsetMetadataGit was defined but not added to AvailableToolsets(), causing get_repository_tree to have an invalid toolset ID. --- pkg/github/tools.go | 1 + pkg/github/tools_validation_test.go | 219 ++++++++++++++++++++++++++++ 2 files changed, 220 insertions(+) create mode 100644 pkg/github/tools_validation_test.go diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 1fe23dfd2..f39ae43c1 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -111,6 +111,7 @@ func AvailableToolsets() []toolsets.ToolsetMetadata { return []toolsets.ToolsetMetadata{ ToolsetMetadataContext, ToolsetMetadataRepos, + ToolsetMetadataGit, ToolsetMetadataIssues, ToolsetMetadataPullRequests, ToolsetMetadataUsers, diff --git a/pkg/github/tools_validation_test.go b/pkg/github/tools_validation_test.go new file mode 100644 index 000000000..e8ce93f67 --- /dev/null +++ b/pkg/github/tools_validation_test.go @@ -0,0 +1,219 @@ +package github + +import ( + "testing" + + "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// stubTranslation is a simple translation function for testing +func stubTranslation(_, fallback string) string { + return fallback +} + +// TestAllToolsHaveRequiredMetadata validates that all tools have mandatory metadata: +// - Toolset must be set (non-empty ID) +// - ReadOnlyHint annotation must be explicitly set (not nil) +func TestAllToolsHaveRequiredMetadata(t *testing.T) { + tools := AllTools(stubTranslation) + + require.NotEmpty(t, tools, "AllTools should return at least one tool") + + for _, tool := range tools { + t.Run(tool.Tool.Name, func(t *testing.T) { + // Toolset ID must be set + assert.NotEmpty(t, tool.Toolset.ID, + "Tool %q must have a Toolset.ID", tool.Tool.Name) + + // Toolset description should be set for documentation + assert.NotEmpty(t, tool.Toolset.Description, + "Tool %q should have a Toolset.Description", tool.Tool.Name) + + // Annotations must exist and have ReadOnlyHint explicitly set + require.NotNil(t, tool.Tool.Annotations, + "Tool %q must have Annotations set (for ReadOnlyHint)", tool.Tool.Name) + + // We can't distinguish between "not set" and "set to false" for a bool, + // but having Annotations non-nil confirms the developer thought about it. + // The ReadOnlyHint value itself is validated by ensuring Annotations exist. + }) + } +} + +// TestAllToolsHaveValidToolsetID validates that all tools belong to known toolsets +func TestAllToolsHaveValidToolsetID(t *testing.T) { + tools := AllTools(stubTranslation) + validToolsetIDs := GetValidToolsetIDs() + + for _, tool := range tools { + t.Run(tool.Tool.Name, func(t *testing.T) { + assert.True(t, validToolsetIDs[tool.Toolset.ID], + "Tool %q has invalid Toolset.ID %q - must be one of the defined toolsets", + tool.Tool.Name, tool.Toolset.ID) + }) + } +} + +// TestAllResourcesHaveRequiredMetadata validates that all resources have mandatory metadata +func TestAllResourcesHaveRequiredMetadata(t *testing.T) { + // Resources need client functions, but we can pass nil for validation + // since we're not actually calling handlers + resources := AllResources(stubTranslation, nil, nil) + + require.NotEmpty(t, resources, "AllResources should return at least one resource") + + for _, res := range resources { + t.Run(res.Template.Name, func(t *testing.T) { + // Toolset ID must be set + assert.NotEmpty(t, res.Toolset.ID, + "Resource %q must have a Toolset.ID", res.Template.Name) + + // Handler must be set + assert.NotNil(t, res.Handler, + "Resource %q must have a Handler", res.Template.Name) + }) + } +} + +// TestAllPromptsHaveRequiredMetadata validates that all prompts have mandatory metadata +func TestAllPromptsHaveRequiredMetadata(t *testing.T) { + prompts := AllPrompts(stubTranslation) + + require.NotEmpty(t, prompts, "AllPrompts should return at least one prompt") + + for _, prompt := range prompts { + t.Run(prompt.Prompt.Name, func(t *testing.T) { + // Toolset ID must be set + assert.NotEmpty(t, prompt.Toolset.ID, + "Prompt %q must have a Toolset.ID", prompt.Prompt.Name) + + // Handler must be set + assert.NotNil(t, prompt.Handler, + "Prompt %q must have a Handler", prompt.Prompt.Name) + }) + } +} + +// TestAllResourcesHaveValidToolsetID validates that all resources belong to known toolsets +func TestAllResourcesHaveValidToolsetID(t *testing.T) { + resources := AllResources(stubTranslation, nil, nil) + validToolsetIDs := GetValidToolsetIDs() + + for _, res := range resources { + t.Run(res.Template.Name, func(t *testing.T) { + assert.True(t, validToolsetIDs[res.Toolset.ID], + "Resource %q has invalid Toolset.ID %q", res.Template.Name, res.Toolset.ID) + }) + } +} + +// TestAllPromptsHaveValidToolsetID validates that all prompts belong to known toolsets +func TestAllPromptsHaveValidToolsetID(t *testing.T) { + prompts := AllPrompts(stubTranslation) + validToolsetIDs := GetValidToolsetIDs() + + for _, prompt := range prompts { + t.Run(prompt.Prompt.Name, func(t *testing.T) { + assert.True(t, validToolsetIDs[prompt.Toolset.ID], + "Prompt %q has invalid Toolset.ID %q", prompt.Prompt.Name, prompt.Toolset.ID) + }) + } +} + +// TestToolReadOnlyHintConsistency validates that read-only tools are correctly annotated +func TestToolReadOnlyHintConsistency(t *testing.T) { + tools := AllTools(stubTranslation) + + for _, tool := range tools { + t.Run(tool.Tool.Name, func(t *testing.T) { + require.NotNil(t, tool.Tool.Annotations, + "Tool %q must have Annotations", tool.Tool.Name) + + // Verify IsReadOnly() method matches the annotation + assert.Equal(t, tool.Tool.Annotations.ReadOnlyHint, tool.IsReadOnly(), + "Tool %q: IsReadOnly() should match Annotations.ReadOnlyHint", tool.Tool.Name) + }) + } +} + +// TestNoDuplicateToolNames ensures all tools have unique names +func TestNoDuplicateToolNames(t *testing.T) { + tools := AllTools(stubTranslation) + seen := make(map[string]bool) + + for _, tool := range tools { + name := tool.Tool.Name + assert.False(t, seen[name], + "Duplicate tool name found: %q", name) + seen[name] = true + } +} + +// TestNoDuplicateResourceNames ensures all resources have unique names +func TestNoDuplicateResourceNames(t *testing.T) { + resources := AllResources(stubTranslation, nil, nil) + seen := make(map[string]bool) + + for _, res := range resources { + name := res.Template.Name + assert.False(t, seen[name], + "Duplicate resource name found: %q", name) + seen[name] = true + } +} + +// TestNoDuplicatePromptNames ensures all prompts have unique names +func TestNoDuplicatePromptNames(t *testing.T) { + prompts := AllPrompts(stubTranslation) + seen := make(map[string]bool) + + for _, prompt := range prompts { + name := prompt.Prompt.Name + assert.False(t, seen[name], + "Duplicate prompt name found: %q", name) + seen[name] = true + } +} + +// TestAllToolsHaveHandlerFunc ensures all tools have a handler function +func TestAllToolsHaveHandlerFunc(t *testing.T) { + tools := AllTools(stubTranslation) + + for _, tool := range tools { + t.Run(tool.Tool.Name, func(t *testing.T) { + assert.NotNil(t, tool.HandlerFunc, + "Tool %q must have a HandlerFunc", tool.Tool.Name) + }) + } +} + +// TestDefaultToolsetsAreValid ensures default toolset IDs are all valid +func TestDefaultToolsetsAreValid(t *testing.T) { + defaults := GetDefaultToolsetIDs() + valid := GetValidToolsetIDs() + + for _, id := range defaults { + assert.True(t, valid[id], + "Default toolset ID %q is not in the valid toolset list", id) + } +} + +// TestToolsetMetadataConsistency ensures tools in the same toolset have consistent descriptions +func TestToolsetMetadataConsistency(t *testing.T) { + tools := AllTools(stubTranslation) + toolsetDescriptions := make(map[toolsets.ToolsetID]string) + + for _, tool := range tools { + id := tool.Toolset.ID + desc := tool.Toolset.Description + + if existing, ok := toolsetDescriptions[id]; ok { + assert.Equal(t, existing, desc, + "Toolset %q has inconsistent descriptions across tools", id) + } else { + toolsetDescriptions[id] = desc + } + } +} From 689a040fe38ab9a82181268db1ae6c37ce10065c Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Sun, 14 Dec 2025 00:09:45 +0100 Subject: [PATCH 03/27] Fix default toolsets behavior when not in dynamic mode When no toolsets are specified and dynamic mode is disabled, the server should use the default toolsets. The bug was introduced when adding dynamic toolsets support: 1. CleanToolsets(nil) was converting nil to empty slice 2. Empty slice passed to WithToolsets means 'no toolsets' 3. This resulted in zero tools being registered Fix: Preserve nil for non-dynamic mode (nil = use defaults in WithToolsets) and only set empty slice when dynamic mode is enabled without explicit toolsets. --- cmd/github-mcp-server/main.go | 10 ++++++++-- internal/ghmcp/server.go | 32 +++++++++++++++++--------------- 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 84c974dad..034b0e238 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -41,10 +41,16 @@ var ( // it's because viper doesn't handle comma-separated values correctly for env // vars when using GetStringSlice. // https://github.com/spf13/viper/issues/380 + // + // Additionally, viper.UnmarshalKey returns an empty slice even when the flag + // is not set, but we need nil to indicate "use defaults". So we check IsSet first. var enabledToolsets []string - if err := viper.UnmarshalKey("toolsets", &enabledToolsets); err != nil { - return fmt.Errorf("failed to unmarshal toolsets: %w", err) + if viper.IsSet("toolsets") { + if err := viper.UnmarshalKey("toolsets", &enabledToolsets); err != nil { + return fmt.Errorf("failed to unmarshal toolsets: %w", err) + } } + // else: enabledToolsets stays nil, meaning "use defaults" // Parse tools (similar to toolsets) var enabledTools []string diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 0edca88ed..99898f8b1 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -103,14 +103,24 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { repoAccessCache = lockdown.GetInstance(gqlClient, repoAccessOpts...) } - enabledToolsets := cfg.EnabledToolsets - - // Clean up the passed toolsets (removes duplicates, whitespace) - enabledToolsets, invalidToolsets := github.CleanToolsets(enabledToolsets) - - if len(invalidToolsets) > 0 { - fmt.Fprintf(os.Stderr, "Invalid toolsets ignored: %s\n", strings.Join(invalidToolsets, ", ")) + // Determine enabled toolsets based on configuration: + // - nil means "use defaults" (unless dynamic mode without explicit toolsets) + // - empty slice means "no toolsets" (for dynamic mode to enable on demand) + // - explicit list means "use these toolsets" + var enabledToolsets []string + if cfg.EnabledToolsets != nil { + // Clean up explicitly passed toolsets (removes duplicates, whitespace) + var invalidToolsets []string + enabledToolsets, invalidToolsets = github.CleanToolsets(cfg.EnabledToolsets) + if len(invalidToolsets) > 0 { + fmt.Fprintf(os.Stderr, "Invalid toolsets ignored: %s\n", strings.Join(invalidToolsets, ", ")) + } + } else if cfg.DynamicToolsets { + // Dynamic mode with no toolsets specified: start with no toolsets enabled + // so users can enable them on demand via the dynamic tools + enabledToolsets = []string{} } + // else: enabledToolsets stays nil, which means "use defaults" in WithToolsets // Generate instructions based on enabled toolsets instructions := github.GenerateInstructions(enabledToolsets) @@ -162,14 +172,6 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { // Clean tool names (WithTools will resolve any deprecated aliases) enabledTools := github.CleanTools(cfg.EnabledTools) - // For dynamic toolsets mode: - // - If toolsets are explicitly provided (including "default"), honor them - // - If no toolsets are specified (nil), start with no toolsets enabled (empty slice) - // so users can enable them on demand via the dynamic tools - if cfg.DynamicToolsets && cfg.EnabledToolsets == nil { - enabledToolsets = []string{} - } - // Apply filters based on configuration // - WithReadOnly: filters out write tools when true // - WithToolsets: nil=defaults, empty=none, handles "all"/"default" keywords From a03b3bf0bfbd4338c4280ad09150795e66cddf08 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Sun, 14 Dec 2025 12:43:04 +0100 Subject: [PATCH 04/27] refactor: address PR review feedback for toolsets - Rename AddDeprecatedToolAliases to WithDeprecatedToolAliases for immutable filter chain consistency (returns new ToolsetGroup) - Remove unused mockGetRawClient from generate_docs.go (use nil instead) - Remove legacy ServerTool functions (NewServerToolLegacy and NewServerToolFromHandlerLegacy) - no usages - Add panic in Handler()/RegisterFunc() when HandlerFunc is nil - Add HasHandler() method for checking if tool has a handler - Add tests for HasHandler and nil handler panic behavior - Update all tests to use new WithDeprecatedToolAliases pattern --- cmd/github-mcp-server/generate_docs.go | 10 +--- internal/ghmcp/server.go | 6 +-- pkg/github/tools_validation_test.go | 2 + pkg/toolsets/server_tool.go | 44 ++++------------ pkg/toolsets/toolsets.go | 15 ++++-- pkg/toolsets/toolsets_test.go | 73 +++++++++++++++++++------- 6 files changed, 81 insertions(+), 69 deletions(-) diff --git a/cmd/github-mcp-server/generate_docs.go b/cmd/github-mcp-server/generate_docs.go index 785bd3ff0..dd3e183cb 100644 --- a/cmd/github-mcp-server/generate_docs.go +++ b/cmd/github-mcp-server/generate_docs.go @@ -10,7 +10,6 @@ import ( "strings" "github.com/github/github-mcp-server/pkg/github" - "github.com/github/github-mcp-server/pkg/raw" "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" gogithub "github.com/google/go-github/v79/github" @@ -37,11 +36,6 @@ func mockGetClient(_ context.Context) (*gogithub.Client, error) { return gogithub.NewClient(nil), nil } -// mockGetRawClient returns a mock raw client for documentation generation -func mockGetRawClient(_ context.Context) (*raw.Client, error) { - return nil, nil -} - func generateAllDocs() error { if err := generateReadmeDocs("README.md"); err != nil { return fmt.Errorf("failed to generate README docs: %w", err) @@ -63,7 +57,7 @@ func generateReadmeDocs(readmePath string) error { t, _ := translations.TranslationHelper() // Create toolset group with mock clients (no deps needed for doc generation) - tsg := github.NewToolsetGroup(t, mockGetClient, mockGetRawClient) + tsg := github.NewToolsetGroup(t, mockGetClient, nil) // Generate toolsets documentation toolsetsDoc := generateToolsetsDoc(tsg) @@ -306,7 +300,7 @@ func generateRemoteToolsetsDoc() string { t, _ := translations.TranslationHelper() // Create toolset group with mock clients - tsg := github.NewToolsetGroup(t, mockGetClient, mockGetRawClient) + tsg := github.NewToolsetGroup(t, mockGetClient, nil) // Generate table header buf.WriteString("| Name | Description | API URL | 1-Click Install (VS Code) | Read-only Link | 1-Click Read-only Install (VS Code) |\n") diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 99898f8b1..54f10290d 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -165,19 +165,17 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { // Create toolset group with all tools, resources, and prompts tsg := github.NewToolsetGroup(cfg.Translator, getClient, getRawClient) - // Add deprecated tool aliases for backward compatibility - // See docs/deprecated-tool-aliases.md for the full list of renames - tsg.AddDeprecatedToolAliases(github.DeprecatedToolAliases) - // Clean tool names (WithTools will resolve any deprecated aliases) enabledTools := github.CleanTools(cfg.EnabledTools) // Apply filters based on configuration + // - WithDeprecatedToolAliases: adds backward compatibility aliases // - WithReadOnly: filters out write tools when true // - WithToolsets: nil=defaults, empty=none, handles "all"/"default" keywords // - WithTools: additional tools that bypass toolset filtering (additive, resolves aliases) // - WithFeatureChecker: filters based on feature flags filteredTsg := tsg. + WithDeprecatedToolAliases(github.DeprecatedToolAliases). WithReadOnly(cfg.ReadOnly). WithToolsets(enabledToolsets). WithTools(enabledTools). diff --git a/pkg/github/tools_validation_test.go b/pkg/github/tools_validation_test.go index e8ce93f67..3b956326a 100644 --- a/pkg/github/tools_validation_test.go +++ b/pkg/github/tools_validation_test.go @@ -185,6 +185,8 @@ func TestAllToolsHaveHandlerFunc(t *testing.T) { t.Run(tool.Tool.Name, func(t *testing.T) { assert.NotNil(t, tool.HandlerFunc, "Tool %q must have a HandlerFunc", tool.Tool.Name) + assert.True(t, tool.HasHandler(), + "Tool %q HasHandler() should return true", tool.Tool.Name) }) } } diff --git a/pkg/toolsets/server_tool.go b/pkg/toolsets/server_tool.go index 334492c74..0e782e631 100644 --- a/pkg/toolsets/server_tool.go +++ b/pkg/toolsets/server_tool.go @@ -57,17 +57,24 @@ func (st *ServerTool) IsReadOnly() bool { return st.Tool.Annotations != nil && st.Tool.Annotations.ReadOnlyHint } +// HasHandler returns true if this tool has a handler function. +func (st *ServerTool) HasHandler() bool { + return st.HandlerFunc != nil +} + // Handler returns a tool handler by calling HandlerFunc with the given dependencies. +// Panics if HandlerFunc is nil - all tools should have handlers. func (st *ServerTool) Handler(deps any) mcp.ToolHandler { if st.HandlerFunc == nil { - return nil + panic("HandlerFunc is nil for tool: " + st.Tool.Name) } return st.HandlerFunc(deps) } // RegisterFunc registers the tool with the server using the provided dependencies. +// Panics if the tool has no handler - all tools should have handlers. func (st *ServerTool) RegisterFunc(s *mcp.Server, deps any) { - handler := st.Handler(deps) + handler := st.Handler(deps) // This will panic if HandlerFunc is nil s.AddTool(&st.Tool, handler) } @@ -97,36 +104,3 @@ func NewServerTool[In any, Out any](tool mcp.Tool, toolset ToolsetMetadata, hand func NewServerToolFromHandler(tool mcp.Tool, toolset ToolsetMetadata, handlerFn func(deps any) mcp.ToolHandler) ServerTool { return ServerTool{Tool: tool, Toolset: toolset, HandlerFunc: handlerFn} } - -// NewServerToolLegacy creates a ServerTool from a tool definition, toolset metadata, and an already-bound typed handler. -// This is for backward compatibility during the refactor - the handler doesn't use dependencies. -// Deprecated: Use NewServerTool instead for new code. -func NewServerToolLegacy[In any, Out any](tool mcp.Tool, toolset ToolsetMetadata, handler mcp.ToolHandlerFor[In, Out]) ServerTool { - return ServerTool{ - Tool: tool, - Toolset: toolset, - HandlerFunc: func(_ any) mcp.ToolHandler { - return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { - var arguments In - if err := json.Unmarshal(req.Params.Arguments, &arguments); err != nil { - return nil, err - } - resp, _, err := handler(ctx, req, arguments) - return resp, err - } - }, - } -} - -// NewServerToolFromHandlerLegacy creates a ServerTool from a tool definition, toolset metadata, and an already-bound raw handler. -// This is for backward compatibility during the refactor - the handler doesn't use dependencies. -// Deprecated: Use NewServerToolFromHandler instead for new code. -func NewServerToolFromHandlerLegacy(tool mcp.Tool, toolset ToolsetMetadata, handler mcp.ToolHandler) ServerTool { - return ServerTool{ - Tool: tool, - Toolset: toolset, - HandlerFunc: func(_ any) mcp.ToolHandler { - return handler - }, - } -} diff --git a/pkg/toolsets/toolsets.go b/pkg/toolsets/toolsets.go index e42a9bcae..5b1df37ff 100644 --- a/pkg/toolsets/toolsets.go +++ b/pkg/toolsets/toolsets.go @@ -413,12 +413,19 @@ func (tg *ToolsetGroup) filterPromptsByName(name string) []ServerPrompt { return []ServerPrompt{} } -// AddDeprecatedToolAliases adds mappings from old tool names to new canonical names. -func (tg *ToolsetGroup) AddDeprecatedToolAliases(aliases map[string]string) *ToolsetGroup { +// WithDeprecatedToolAliases returns a new ToolsetGroup with the given deprecated aliases added. +// Aliases map old tool names to new canonical names. +func (tg *ToolsetGroup) WithDeprecatedToolAliases(aliases map[string]string) *ToolsetGroup { + newTG := tg.copy() + // Ensure we have a fresh map + newTG.deprecatedAliases = make(map[string]string, len(tg.deprecatedAliases)+len(aliases)) + for k, v := range tg.deprecatedAliases { + newTG.deprecatedAliases[k] = v + } for oldName, newName := range aliases { - tg.deprecatedAliases[oldName] = newName + newTG.deprecatedAliases[oldName] = newName } - return tg + return newTG } // isToolsetEnabled checks if a toolset is enabled based on current filters. diff --git a/pkg/toolsets/toolsets_test.go b/pkg/toolsets/toolsets_test.go index 7bec55848..34d9ba753 100644 --- a/pkg/toolsets/toolsets_test.go +++ b/pkg/toolsets/toolsets_test.go @@ -252,22 +252,27 @@ func TestToolsForToolset(t *testing.T) { } } -func TestAddDeprecatedToolAliases(t *testing.T) { +func TestWithDeprecatedToolAliases(t *testing.T) { tools := []ServerTool{ mockTool("new_name", "toolset1", true), } tsg := NewToolsetGroup(tools, nil, nil) - tsg.AddDeprecatedToolAliases(map[string]string{ + tsgWithAliases := tsg.WithDeprecatedToolAliases(map[string]string{ "old_name": "new_name", "get_issue": "issue_read", }) - if len(tsg.deprecatedAliases) != 2 { - t.Errorf("expected 2 aliases, got %d", len(tsg.deprecatedAliases)) + // Original should be unchanged (immutable) + if len(tsg.deprecatedAliases) != 0 { + t.Errorf("original should have 0 aliases, got %d", len(tsg.deprecatedAliases)) } - if tsg.deprecatedAliases["old_name"] != "new_name" { - t.Errorf("expected alias 'old_name' -> 'new_name', got '%s'", tsg.deprecatedAliases["old_name"]) + + if len(tsgWithAliases.deprecatedAliases) != 2 { + t.Errorf("expected 2 aliases, got %d", len(tsgWithAliases.deprecatedAliases)) + } + if tsgWithAliases.deprecatedAliases["old_name"] != "new_name" { + t.Errorf("expected alias 'old_name' -> 'new_name', got '%s'", tsgWithAliases.deprecatedAliases["old_name"]) } } @@ -277,10 +282,10 @@ func TestResolveToolAliases(t *testing.T) { mockTool("some_tool", "toolset1", true), } - tsg := NewToolsetGroup(tools, nil, nil) - tsg.AddDeprecatedToolAliases(map[string]string{ - "get_issue": "issue_read", - }) + tsg := NewToolsetGroup(tools, nil, nil). + WithDeprecatedToolAliases(map[string]string{ + "get_issue": "issue_read", + }) // Test resolving a mix of aliases and canonical names input := []string{"get_issue", "some_tool"} @@ -384,10 +389,10 @@ func TestWithToolsResolvesAliases(t *testing.T) { mockTool("issue_read", "toolset1", true), } - tsg := NewToolsetGroup(tools, nil, nil) - tsg.AddDeprecatedToolAliases(map[string]string{ - "get_issue": "issue_read", - }) + tsg := NewToolsetGroup(tools, nil, nil). + WithDeprecatedToolAliases(map[string]string{ + "get_issue": "issue_read", + }) // Using deprecated alias should resolve to canonical name filtered := tsg.WithToolsets([]string{}).WithTools([]string{"get_issue"}) @@ -593,10 +598,10 @@ func TestForMCPRequest_ToolsCall_DeprecatedAlias(t *testing.T) { mockTool("list_commits", "repos", true), } - tsg := NewToolsetGroup(tools, nil, nil) - tsg.AddDeprecatedToolAliases(map[string]string{ - "old_get_me": "get_me", - }) + tsg := NewToolsetGroup(tools, nil, nil). + WithDeprecatedToolAliases(map[string]string{ + "old_get_me": "get_me", + }) // Request using the deprecated alias filtered := tsg.ForMCPRequest(MCPMethodToolsCall, "old_get_me") @@ -1033,3 +1038,35 @@ func TestFeatureFlagPrompts(t *testing.T) { t.Errorf("Expected 2 prompts with checker, got %d", len(filtered.AvailablePrompts(context.Background()))) } } + +func TestServerToolHasHandler(t *testing.T) { + // Tool with handler + toolWithHandler := mockTool("has_handler", "toolset1", true) + if !toolWithHandler.HasHandler() { + t.Error("Expected HasHandler() to return true for tool with handler") + } + + // Tool without handler + toolWithoutHandler := ServerTool{ + Tool: mcp.Tool{Name: "no_handler"}, + Toolset: testToolsetMetadata("toolset1"), + } + if toolWithoutHandler.HasHandler() { + t.Error("Expected HasHandler() to return false for tool without handler") + } +} + +func TestServerToolHandlerPanicOnNil(t *testing.T) { + tool := ServerTool{ + Tool: mcp.Tool{Name: "no_handler"}, + Toolset: testToolsetMetadata("toolset1"), + } + + defer func() { + if r := recover(); r == nil { + t.Error("Expected Handler() to panic when HandlerFunc is nil") + } + }() + + tool.Handler(nil) +} From 8b850d00f2d966d7a2be0a84402115d2a42e0474 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Sun, 14 Dec 2025 19:06:00 +0100 Subject: [PATCH 05/27] refactor: Apply HandlerFunc pattern to resources for stateless NewToolsetGroup This change applies the same HandlerFunc pattern used by tools to resources, allowing NewToolsetGroup to be fully stateless (only requiring translations). Key changes: - Add ResourceHandlerFunc type to toolsets package - Update ServerResourceTemplate to use HandlerFunc instead of direct Handler - Add HasHandler() and Handler(deps) methods to ServerResourceTemplate - Update RegisterResourceTemplates to take deps parameter - Refactor repository resource definitions to use HandlerFunc pattern - Make AllResources(t) stateless (only takes translations) - Make NewToolsetGroup(t) stateless (only takes translations) - Update generate_docs.go - no longer needs mock clients - Update tests to use new patterns This resolves the concern about mixed concerns in doc generation - the toolset metadata and resource templates can now be created without any runtime dependencies, while handlers are generated on-demand when deps are provided during registration. --- cmd/github-mcp-server/generate_docs.go | 15 +++------ internal/ghmcp/server.go | 4 +-- pkg/github/repository_resource.go | 38 +++++++++++++--------- pkg/github/repository_resource_test.go | 45 +++++++++++++------------- pkg/github/resources.go | 15 ++++----- pkg/github/tools_validation_test.go | 15 ++++----- pkg/github/toolset_group.go | 9 +++--- pkg/toolsets/toolsets.go | 38 +++++++++++++++++----- pkg/toolsets/toolsets_test.go | 6 ++-- 9 files changed, 104 insertions(+), 81 deletions(-) diff --git a/cmd/github-mcp-server/generate_docs.go b/cmd/github-mcp-server/generate_docs.go index dd3e183cb..1e3f3252b 100644 --- a/cmd/github-mcp-server/generate_docs.go +++ b/cmd/github-mcp-server/generate_docs.go @@ -1,7 +1,6 @@ package main import ( - "context" "fmt" "net/url" "os" @@ -12,7 +11,6 @@ import ( "github.com/github/github-mcp-server/pkg/github" "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" - gogithub "github.com/google/go-github/v79/github" "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/spf13/cobra" @@ -31,11 +29,6 @@ func init() { rootCmd.AddCommand(generateDocsCmd) } -// mockGetClient returns a mock GitHub client for documentation generation -func mockGetClient(_ context.Context) (*gogithub.Client, error) { - return gogithub.NewClient(nil), nil -} - func generateAllDocs() error { if err := generateReadmeDocs("README.md"); err != nil { return fmt.Errorf("failed to generate README docs: %w", err) @@ -56,8 +49,8 @@ func generateReadmeDocs(readmePath string) error { // Create translation helper t, _ := translations.TranslationHelper() - // Create toolset group with mock clients (no deps needed for doc generation) - tsg := github.NewToolsetGroup(t, mockGetClient, nil) + // Create toolset group - stateless, no dependencies needed for doc generation + tsg := github.NewToolsetGroup(t) // Generate toolsets documentation toolsetsDoc := generateToolsetsDoc(tsg) @@ -299,8 +292,8 @@ func generateRemoteToolsetsDoc() string { // Create translation helper t, _ := translations.TranslationHelper() - // Create toolset group with mock clients - tsg := github.NewToolsetGroup(t, mockGetClient, nil) + // Create toolset group - stateless + tsg := github.NewToolsetGroup(t) // Generate table header buf.WriteString("| Name | Description | API URL | 1-Click Install (VS Code) | Read-only Link | 1-Click Read-only Install (VS Code) |\n") diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 54f10290d..1dec59381 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -162,8 +162,8 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { ContentWindowSize: cfg.ContentWindowSize, } - // Create toolset group with all tools, resources, and prompts - tsg := github.NewToolsetGroup(cfg.Translator, getClient, getRawClient) + // Create toolset group with all tools, resources, and prompts (stateless) + tsg := github.NewToolsetGroup(cfg.Translator) // Clean tool names (WithTools will resolve any deprecated aliases) enabledTools := github.CleanTools(cfg.EnabledTools) diff --git a/pkg/github/repository_resource.go b/pkg/github/repository_resource.go index d8fd13963..6dbbe90ec 100644 --- a/pkg/github/repository_resource.go +++ b/pkg/github/repository_resource.go @@ -29,8 +29,8 @@ var ( repositoryResourcePrContentURITemplate = uritemplate.MustNew("repo://{owner}/{repo}/refs/pull/{prNumber}/head/contents{/path*}") ) -// GetRepositoryResourceContent defines the resource template and handler for getting repository content. -func GetRepositoryResourceContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { +// GetRepositoryResourceContent defines the resource template for getting repository content. +func GetRepositoryResourceContent(t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { return toolsets.NewServerResourceTemplate( ToolsetMetadataRepos, mcp.ResourceTemplate{ @@ -38,12 +38,12 @@ func GetRepositoryResourceContent(getClient GetClientFn, getRawClient raw.GetRaw URITemplate: repositoryResourceContentURITemplate.Raw(), Description: t("RESOURCE_REPOSITORY_CONTENT_DESCRIPTION", "Repository Content"), }, - RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceContentURITemplate), + repositoryResourceContentsHandlerFunc(repositoryResourceContentURITemplate), ) } -// GetRepositoryResourceBranchContent defines the resource template and handler for getting repository content for a branch. -func GetRepositoryResourceBranchContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { +// GetRepositoryResourceBranchContent defines the resource template for getting repository content for a branch. +func GetRepositoryResourceBranchContent(t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { return toolsets.NewServerResourceTemplate( ToolsetMetadataRepos, mcp.ResourceTemplate{ @@ -51,12 +51,12 @@ func GetRepositoryResourceBranchContent(getClient GetClientFn, getRawClient raw. URITemplate: repositoryResourceBranchContentURITemplate.Raw(), Description: t("RESOURCE_REPOSITORY_CONTENT_BRANCH_DESCRIPTION", "Repository Content for specific branch"), }, - RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceBranchContentURITemplate), + repositoryResourceContentsHandlerFunc(repositoryResourceBranchContentURITemplate), ) } -// GetRepositoryResourceCommitContent defines the resource template and handler for getting repository content for a commit. -func GetRepositoryResourceCommitContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { +// GetRepositoryResourceCommitContent defines the resource template for getting repository content for a commit. +func GetRepositoryResourceCommitContent(t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { return toolsets.NewServerResourceTemplate( ToolsetMetadataRepos, mcp.ResourceTemplate{ @@ -64,12 +64,12 @@ func GetRepositoryResourceCommitContent(getClient GetClientFn, getRawClient raw. URITemplate: repositoryResourceCommitContentURITemplate.Raw(), Description: t("RESOURCE_REPOSITORY_CONTENT_COMMIT_DESCRIPTION", "Repository Content for specific commit"), }, - RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceCommitContentURITemplate), + repositoryResourceContentsHandlerFunc(repositoryResourceCommitContentURITemplate), ) } -// GetRepositoryResourceTagContent defines the resource template and handler for getting repository content for a tag. -func GetRepositoryResourceTagContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { +// GetRepositoryResourceTagContent defines the resource template for getting repository content for a tag. +func GetRepositoryResourceTagContent(t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { return toolsets.NewServerResourceTemplate( ToolsetMetadataRepos, mcp.ResourceTemplate{ @@ -77,12 +77,12 @@ func GetRepositoryResourceTagContent(getClient GetClientFn, getRawClient raw.Get URITemplate: repositoryResourceTagContentURITemplate.Raw(), Description: t("RESOURCE_REPOSITORY_CONTENT_TAG_DESCRIPTION", "Repository Content for specific tag"), }, - RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceTagContentURITemplate), + repositoryResourceContentsHandlerFunc(repositoryResourceTagContentURITemplate), ) } -// GetRepositoryResourcePrContent defines the resource template and handler for getting repository content for a pull request. -func GetRepositoryResourcePrContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { +// GetRepositoryResourcePrContent defines the resource template for getting repository content for a pull request. +func GetRepositoryResourcePrContent(t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { return toolsets.NewServerResourceTemplate( ToolsetMetadataRepos, mcp.ResourceTemplate{ @@ -90,10 +90,18 @@ func GetRepositoryResourcePrContent(getClient GetClientFn, getRawClient raw.GetR URITemplate: repositoryResourcePrContentURITemplate.Raw(), Description: t("RESOURCE_REPOSITORY_CONTENT_PR_DESCRIPTION", "Repository Content for specific pull request"), }, - RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourcePrContentURITemplate), + repositoryResourceContentsHandlerFunc(repositoryResourcePrContentURITemplate), ) } +// repositoryResourceContentsHandlerFunc returns a ResourceHandlerFunc that creates handlers on-demand. +func repositoryResourceContentsHandlerFunc(resourceURITemplate *uritemplate.Template) toolsets.ResourceHandlerFunc { + return func(deps any) mcp.ResourceHandler { + d := deps.(ToolDependencies) + return RepositoryResourceContentsHandler(d.GetClient, d.GetRawClient, resourceURITemplate) + } +} + // RepositoryResourceContentsHandler returns a handler function for repository content requests. func RepositoryResourceContentsHandler(getClient GetClientFn, getRawClient raw.GetRawClientFn, resourceURITemplate *uritemplate.Template) mcp.ResourceHandler { return func(ctx context.Context, request *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { diff --git a/pkg/github/repository_resource_test.go b/pkg/github/repository_resource_test.go index 1b4120ff0..f938a57f5 100644 --- a/pkg/github/repository_resource_test.go +++ b/pkg/github/repository_resource_test.go @@ -7,7 +7,6 @@ import ( "testing" "github.com/github/github-mcp-server/pkg/raw" - "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" "github.com/migueleliasweb/go-github-mock/src/mock" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -28,7 +27,7 @@ func Test_repositoryResourceContents(t *testing.T) { name string mockedClient *http.Client uri string - handlerFn func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler + handlerFn func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler expectedResponseType resourceResponseType expectError string expectedResult *mcp.ReadResourceResult @@ -46,8 +45,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo:///repo/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - return GetRepositoryResourceContent(getClient, getRawClient, t).Handler + handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeText, // Ignored as error is expected expectError: "owner is required", @@ -65,8 +64,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner//refs/heads/main/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - return GetRepositoryResourceBranchContent(getClient, getRawClient, t).Handler + handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceBranchContentURITemplate) }, expectedResponseType: resourceResponseTypeText, // Ignored as error is expected expectError: "repo is required", @@ -84,8 +83,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/contents/data.png", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - return GetRepositoryResourceContent(getClient, getRawClient, t).Handler + handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeBlob, expectedResult: &mcp.ReadResourceResult{ @@ -108,8 +107,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - return GetRepositoryResourceContent(getClient, getRawClient, t).Handler + handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -134,8 +133,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/contents/pkg/github/actions.go", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - return GetRepositoryResourceContent(getClient, getRawClient, t).Handler + handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -158,8 +157,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/refs/heads/main/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - return GetRepositoryResourceBranchContent(getClient, getRawClient, t).Handler + handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceBranchContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -182,8 +181,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/refs/tags/v1.0.0/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - return GetRepositoryResourceTagContent(getClient, getRawClient, t).Handler + handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceTagContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -206,8 +205,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/sha/abc123/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - return GetRepositoryResourceCommitContent(getClient, getRawClient, t).Handler + handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceCommitContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -238,8 +237,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/refs/pull/42/head/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - return GetRepositoryResourcePrContent(getClient, getRawClient, t).Handler + handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourcePrContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -261,8 +260,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/contents/nonexistent.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - return GetRepositoryResourceContent(getClient, getRawClient, t).Handler + handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeText, // Ignored as error is expected expectError: "404 Not Found", @@ -273,7 +272,7 @@ func Test_repositoryResourceContents(t *testing.T) { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) mockRawClient := raw.NewClient(client, base) - handler := tc.handlerFn(stubGetClientFn(client), stubGetRawClientFn(mockRawClient), translations.NullTranslationHelper) + handler := tc.handlerFn(stubGetClientFn(client), stubGetRawClientFn(mockRawClient)) request := &mcp.ReadResourceRequest{ Params: &mcp.ReadResourceParams{ diff --git a/pkg/github/resources.go b/pkg/github/resources.go index f0b07e831..253c4bc11 100644 --- a/pkg/github/resources.go +++ b/pkg/github/resources.go @@ -1,20 +1,19 @@ package github import ( - "github.com/github/github-mcp-server/pkg/raw" "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" ) // AllResources returns all resource templates with their embedded toolset metadata. -// Resource template functions return ServerResourceTemplate directly with toolset info. -func AllResources(t translations.TranslationHelperFunc, getClient GetClientFn, getRawClient raw.GetRawClientFn) []toolsets.ServerResourceTemplate { +// Resource definitions are stateless - handlers are generated on-demand during registration. +func AllResources(t translations.TranslationHelperFunc) []toolsets.ServerResourceTemplate { return []toolsets.ServerResourceTemplate{ // Repository resources - GetRepositoryResourceContent(getClient, getRawClient, t), - GetRepositoryResourceBranchContent(getClient, getRawClient, t), - GetRepositoryResourceCommitContent(getClient, getRawClient, t), - GetRepositoryResourceTagContent(getClient, getRawClient, t), - GetRepositoryResourcePrContent(getClient, getRawClient, t), + GetRepositoryResourceContent(t), + GetRepositoryResourceBranchContent(t), + GetRepositoryResourceCommitContent(t), + GetRepositoryResourceTagContent(t), + GetRepositoryResourcePrContent(t), } } diff --git a/pkg/github/tools_validation_test.go b/pkg/github/tools_validation_test.go index 3b956326a..ae79c455a 100644 --- a/pkg/github/tools_validation_test.go +++ b/pkg/github/tools_validation_test.go @@ -58,9 +58,8 @@ func TestAllToolsHaveValidToolsetID(t *testing.T) { // TestAllResourcesHaveRequiredMetadata validates that all resources have mandatory metadata func TestAllResourcesHaveRequiredMetadata(t *testing.T) { - // Resources need client functions, but we can pass nil for validation - // since we're not actually calling handlers - resources := AllResources(stubTranslation, nil, nil) + // Resources are now stateless - no client functions needed + resources := AllResources(stubTranslation) require.NotEmpty(t, resources, "AllResources should return at least one resource") @@ -70,9 +69,9 @@ func TestAllResourcesHaveRequiredMetadata(t *testing.T) { assert.NotEmpty(t, res.Toolset.ID, "Resource %q must have a Toolset.ID", res.Template.Name) - // Handler must be set - assert.NotNil(t, res.Handler, - "Resource %q must have a Handler", res.Template.Name) + // HandlerFunc must be set + assert.True(t, res.HasHandler(), + "Resource %q must have a HandlerFunc", res.Template.Name) }) } } @@ -98,7 +97,7 @@ func TestAllPromptsHaveRequiredMetadata(t *testing.T) { // TestAllResourcesHaveValidToolsetID validates that all resources belong to known toolsets func TestAllResourcesHaveValidToolsetID(t *testing.T) { - resources := AllResources(stubTranslation, nil, nil) + resources := AllResources(stubTranslation) validToolsetIDs := GetValidToolsetIDs() for _, res := range resources { @@ -153,7 +152,7 @@ func TestNoDuplicateToolNames(t *testing.T) { // TestNoDuplicateResourceNames ensures all resources have unique names func TestNoDuplicateResourceNames(t *testing.T) { - resources := AllResources(stubTranslation, nil, nil) + resources := AllResources(stubTranslation) seen := make(map[string]bool) for _, res := range resources { diff --git a/pkg/github/toolset_group.go b/pkg/github/toolset_group.go index bca1f7ca4..68db02db6 100644 --- a/pkg/github/toolset_group.go +++ b/pkg/github/toolset_group.go @@ -1,18 +1,19 @@ package github import ( - "github.com/github/github-mcp-server/pkg/raw" "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" ) // NewToolsetGroup creates a ToolsetGroup with all available tools, resources, and prompts. -// Tools are self-describing with their toolset metadata embedded. +// Tools, resources, and prompts are self-describing with their toolset metadata embedded. +// This function is stateless - no dependencies are captured. +// Handlers are generated on-demand during registration via RegisterAll(ctx, server, deps). // The "default" keyword in WithToolsets will expand to GetDefaultToolsetIDs(). -func NewToolsetGroup(t translations.TranslationHelperFunc, getClient GetClientFn, getRawClient raw.GetRawClientFn) *toolsets.ToolsetGroup { +func NewToolsetGroup(t translations.TranslationHelperFunc) *toolsets.ToolsetGroup { tsg := toolsets.NewToolsetGroup( AllTools(t), - AllResources(t, getClient, getRawClient), + AllResources(t), AllPrompts(t), ) tsg.SetDefaultToolsetIDs(GetDefaultToolsetIDs()) diff --git a/pkg/toolsets/toolsets.go b/pkg/toolsets/toolsets.go index 5b1df37ff..960b7e8a4 100644 --- a/pkg/toolsets/toolsets.go +++ b/pkg/toolsets/toolsets.go @@ -48,10 +48,18 @@ func NewToolDoesNotExistError(name string) *ToolDoesNotExistError { // ServerTool is defined in server_tool.go +// ResourceHandlerFunc is a function that takes dependencies and returns an MCP resource handler. +// This allows resources to be defined statically while their handlers are generated +// on-demand with the appropriate dependencies. +type ResourceHandlerFunc func(deps any) mcp.ResourceHandler + // ServerResourceTemplate pairs a resource template with its toolset metadata. type ServerResourceTemplate struct { Template mcp.ResourceTemplate - Handler mcp.ResourceHandler + // HandlerFunc generates the handler when given dependencies. + // This allows resources to be passed around without handlers being set up, + // and handlers are only created when needed. + HandlerFunc ResourceHandlerFunc // Toolset identifies which toolset this resource belongs to Toolset ToolsetMetadata // FeatureFlagEnable specifies a feature flag that must be enabled for this resource @@ -62,12 +70,26 @@ type ServerResourceTemplate struct { FeatureFlagDisable string } +// HasHandler returns true if this resource has a handler function. +func (sr *ServerResourceTemplate) HasHandler() bool { + return sr.HandlerFunc != nil +} + +// Handler returns a resource handler by calling HandlerFunc with the given dependencies. +// Panics if HandlerFunc is nil - all resources should have handlers. +func (sr *ServerResourceTemplate) Handler(deps any) mcp.ResourceHandler { + if sr.HandlerFunc == nil { + panic("HandlerFunc is nil for resource: " + sr.Template.Name) + } + return sr.HandlerFunc(deps) +} + // NewServerResourceTemplate creates a new ServerResourceTemplate with toolset metadata. -func NewServerResourceTemplate(toolset ToolsetMetadata, resourceTemplate mcp.ResourceTemplate, handler mcp.ResourceHandler) ServerResourceTemplate { +func NewServerResourceTemplate(toolset ToolsetMetadata, resourceTemplate mcp.ResourceTemplate, handlerFn ResourceHandlerFunc) ServerResourceTemplate { return ServerResourceTemplate{ - Template: resourceTemplate, - Handler: handler, - Toolset: toolset, + Template: resourceTemplate, + HandlerFunc: handlerFn, + Toolset: toolset, } } @@ -643,9 +665,9 @@ func (tg *ToolsetGroup) RegisterTools(ctx context.Context, s *mcp.Server, deps a // RegisterResourceTemplates registers all available resource templates with the server. // The context is used for feature flag evaluation. -func (tg *ToolsetGroup) RegisterResourceTemplates(ctx context.Context, s *mcp.Server) { +func (tg *ToolsetGroup) RegisterResourceTemplates(ctx context.Context, s *mcp.Server, deps any) { for _, res := range tg.AvailableResourceTemplates(ctx) { - s.AddResourceTemplate(&res.Template, res.Handler) + s.AddResourceTemplate(&res.Template, res.Handler(deps)) } } @@ -661,7 +683,7 @@ func (tg *ToolsetGroup) RegisterPrompts(ctx context.Context, s *mcp.Server) { // The context is used for feature flag evaluation. func (tg *ToolsetGroup) RegisterAll(ctx context.Context, s *mcp.Server, deps any) { tg.RegisterTools(ctx, s, deps) - tg.RegisterResourceTemplates(ctx, s) + tg.RegisterResourceTemplates(ctx, s, deps) tg.RegisterPrompts(ctx, s) } diff --git a/pkg/toolsets/toolsets_test.go b/pkg/toolsets/toolsets_test.go index 34d9ba753..317dafbf6 100644 --- a/pkg/toolsets/toolsets_test.go +++ b/pkg/toolsets/toolsets_test.go @@ -489,8 +489,10 @@ func mockResource(name string, toolsetID string, uriTemplate string) ServerResou Name: name, URITemplate: uriTemplate, }, - func(_ context.Context, _ *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { - return nil, nil + func(_ any) mcp.ResourceHandler { + return func(_ context.Context, _ *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { + return nil, nil + } }, ) } From b13d9fecb8b4c31a0712de9cc9a03fe0b7ab8541 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Sun, 14 Dec 2025 19:29:13 +0100 Subject: [PATCH 06/27] refactor: simplify ForMCPRequest switch cases --- pkg/toolsets/toolsets.go | 45 ++++++++++++---------------------------- 1 file changed, 13 insertions(+), 32 deletions(-) diff --git a/pkg/toolsets/toolsets.go b/pkg/toolsets/toolsets.go index 960b7e8a4..32105dbe8 100644 --- a/pkg/toolsets/toolsets.go +++ b/pkg/toolsets/toolsets.go @@ -337,58 +337,39 @@ const ( func (tg *ToolsetGroup) ForMCPRequest(method string, itemName string) *ToolsetGroup { result := tg.copy() - switch method { - case MCPMethodInitialize: - // Capabilities only - no items need to be registered - // The server capabilities (tools, resources, prompts support) are set via ServerOptions + // Helper to clear all item types + clearAll := func() { result.tools = []ServerTool{} result.resourceTemplates = []ServerResourceTemplate{} result.prompts = []ServerPrompt{} + } + switch method { + case MCPMethodInitialize: + clearAll() case MCPMethodToolsList: - // All available tools, but no resources or prompts - result.resourceTemplates = []ServerResourceTemplate{} - result.prompts = []ServerPrompt{} - + result.resourceTemplates, result.prompts = nil, nil case MCPMethodToolsCall: - // Only the specific tool (if found), no resources or prompts - result.resourceTemplates = []ServerResourceTemplate{} - result.prompts = []ServerPrompt{} + result.resourceTemplates, result.prompts = nil, nil if itemName != "" { result.tools = tg.filterToolsByName(itemName) } - case MCPMethodResourcesList, MCPMethodResourcesTemplatesList: - // All available resources, but no tools or prompts - result.tools = []ServerTool{} - result.prompts = []ServerPrompt{} - + result.tools, result.prompts = nil, nil case MCPMethodResourcesRead: - // Only the specific resource template, no tools or prompts - result.tools = []ServerTool{} - result.prompts = []ServerPrompt{} + result.tools, result.prompts = nil, nil if itemName != "" { result.resourceTemplates = tg.filterResourcesByURI(itemName) } - case MCPMethodPromptsList: - // All available prompts, but no tools or resources - result.tools = []ServerTool{} - result.resourceTemplates = []ServerResourceTemplate{} - + result.tools, result.resourceTemplates = nil, nil case MCPMethodPromptsGet: - // Only the specific prompt, no tools or resources - result.tools = []ServerTool{} - result.resourceTemplates = []ServerResourceTemplate{} + result.tools, result.resourceTemplates = nil, nil if itemName != "" { result.prompts = tg.filterPromptsByName(itemName) } - default: - // Unknown method - register nothing - result.tools = []ServerTool{} - result.resourceTemplates = []ServerResourceTemplate{} - result.prompts = []ServerPrompt{} + clearAll() } return result From 74d2c56c3056a375004297cf44da56bb3146b0a8 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Sun, 14 Dec 2025 22:30:13 +0100 Subject: [PATCH 07/27] refactor(generate_docs): use strings.Builder and AllTools() iteration - Replace slice joining with strings.Builder for all doc generation - Iterate AllTools() directly instead of ToolsetIDs()/ToolsForToolset() - Removes need for special 'dynamic' toolset handling (no tools = no output) - Context toolset still explicitly handled for custom description - Consistent pattern across generateToolsetsDoc, generateToolsDoc, generateRemoteToolsetsDoc, and generateDeprecatedAliasesTable --- cmd/github-mcp-server/generate_docs.go | 226 +++++++++++++------------ docs/remote-server.md | 1 - 2 files changed, 120 insertions(+), 107 deletions(-) diff --git a/cmd/github-mcp-server/generate_docs.go b/cmd/github-mcp-server/generate_docs.go index 1e3f3252b..6a3100ac2 100644 --- a/cmd/github-mcp-server/generate_docs.go +++ b/cmd/github-mcp-server/generate_docs.go @@ -107,64 +107,73 @@ func generateRemoteServerDocs(docsPath string) error { } func generateToolsetsDoc(tsg *toolsets.ToolsetGroup) string { - var lines []string + var buf strings.Builder // Add table header and separator - lines = append(lines, "| Toolset | Description |") - lines = append(lines, "| ----------------------- | ------------------------------------------------------------- |") - - // Add the context toolset row (handled separately in README) - lines = append(lines, "| `context` | **Strongly recommended**: Tools that provide context about the current user and GitHub context you are operating in |") - - // Get toolset IDs and descriptions - toolsetIDs := tsg.ToolsetIDs() - descriptions := tsg.ToolsetDescriptions() - - // Filter out context and dynamic toolsets (handled separately) - for _, id := range toolsetIDs { - if id != "context" && id != "dynamic" { - description := descriptions[id] - lines = append(lines, fmt.Sprintf("| `%s` | %s |", id, description)) + buf.WriteString("| Toolset | Description |\n") + buf.WriteString("| ----------------------- | ------------------------------------------------------------- |\n") + + // Add the context toolset row with custom description (strongly recommended) + buf.WriteString("| `context` | **Strongly recommended**: Tools that provide context about the current user and GitHub context you are operating in |\n") + + // AllTools() is sorted by toolset ID then tool name. + // We iterate once, collecting unique toolsets (skipping context which has custom description above). + tools := tsg.AllTools() + var lastToolsetID toolsets.ToolsetID + for _, tool := range tools { + if tool.Toolset.ID != lastToolsetID { + lastToolsetID = tool.Toolset.ID + // Skip context (handled above with custom description) + if lastToolsetID == "context" { + continue + } + fmt.Fprintf(&buf, "| `%s` | %s |\n", lastToolsetID, tool.Toolset.Description) } } - return strings.Join(lines, "\n") + return strings.TrimSuffix(buf.String(), "\n") } func generateToolsDoc(tsg *toolsets.ToolsetGroup) string { - var sections []string - - // Get toolset IDs (already sorted deterministically) - toolsetIDs := tsg.ToolsetIDs() + // AllTools() returns tools sorted by toolset ID then tool name. + // We iterate once, grouping by toolset as we encounter them. + tools := tsg.AllTools() + if len(tools) == 0 { + return "" + } - for _, toolsetID := range toolsetIDs { - if toolsetID == "dynamic" { // Skip dynamic toolset as it's handled separately - continue - } + var buf strings.Builder + var toolBuf strings.Builder + var currentToolsetID toolsets.ToolsetID + firstSection := true - // Get tools for this toolset (already sorted deterministically) - tools := tsg.ToolsForToolset(toolsetID) - if len(tools) == 0 { - continue + writeSection := func() { + if toolBuf.Len() == 0 { + return } - - // Generate section header - capitalize first letter and replace underscores - sectionName := formatToolsetName(string(toolsetID)) - - var toolDocs []string - for _, serverTool := range tools { - toolDoc := generateToolDoc(serverTool.Tool) - toolDocs = append(toolDocs, toolDoc) + if !firstSection { + buf.WriteString("\n\n") } + firstSection = false + sectionName := formatToolsetName(string(currentToolsetID)) + fmt.Fprintf(&buf, "
\n\n%s\n\n%s\n\n
", sectionName, strings.TrimSuffix(toolBuf.String(), "\n\n")) + toolBuf.Reset() + } - if len(toolDocs) > 0 { - section := fmt.Sprintf("
\n\n%s\n\n%s\n\n
", - sectionName, strings.Join(toolDocs, "\n\n")) - sections = append(sections, section) + for _, tool := range tools { + // When toolset changes, emit the previous section + if tool.Toolset.ID != currentToolsetID { + writeSection() + currentToolsetID = tool.Toolset.ID } + writeToolDoc(&toolBuf, tool.Tool) + toolBuf.WriteString("\n\n") } - return strings.Join(sections, "\n\n") + // Emit the last section + writeSection() + + return buf.String() } func formatToolsetName(name string) string { @@ -191,21 +200,19 @@ func formatToolsetName(name string) string { } } -func generateToolDoc(tool mcp.Tool) string { - var lines []string - +func writeToolDoc(buf *strings.Builder, tool mcp.Tool) { // Tool name only (using annotation name instead of verbose description) - lines = append(lines, fmt.Sprintf("- **%s** - %s", tool.Name, tool.Annotations.Title)) + fmt.Fprintf(buf, "- **%s** - %s\n", tool.Name, tool.Annotations.Title) // Parameters if tool.InputSchema == nil { - lines = append(lines, " - No parameters required") - return strings.Join(lines, "\n") + buf.WriteString(" - No parameters required") + return } schema, ok := tool.InputSchema.(*jsonschema.Schema) if !ok || schema == nil { - lines = append(lines, " - No parameters required") - return strings.Join(lines, "\n") + buf.WriteString(" - No parameters required") + return } if len(schema.Properties) > 0 { @@ -216,7 +223,7 @@ func generateToolDoc(tool mcp.Tool) string { } sort.Strings(paramNames) - for _, propName := range paramNames { + for i, propName := range paramNames { prop := schema.Properties[propName] required := contains(schema.Required, propName) requiredStr := "optional" @@ -224,7 +231,7 @@ func generateToolDoc(tool mcp.Tool) string { requiredStr = "required" } - var typeStr, description string + var typeStr string // Get the type and description switch prop.Type { @@ -238,19 +245,17 @@ func generateToolDoc(tool mcp.Tool) string { typeStr = prop.Type } - description = prop.Description - // Indent any continuation lines in the description to maintain markdown formatting - description = indentMultilineDescription(description, " ") + description := indentMultilineDescription(prop.Description, " ") - paramLine := fmt.Sprintf(" - `%s`: %s (%s, %s)", propName, description, typeStr, requiredStr) - lines = append(lines, paramLine) + fmt.Fprintf(buf, " - `%s`: %s (%s, %s)", propName, description, typeStr, requiredStr) + if i < len(paramNames)-1 { + buf.WriteString("\n") + } } } else { - lines = append(lines, " - No parameters required") + buf.WriteString(" - No parameters required") } - - return strings.Join(lines, "\n") } func contains(slice []string, item string) bool { @@ -265,14 +270,18 @@ func contains(slice []string, item string) bool { // indentMultilineDescription adds the specified indent to all lines after the first line. // This ensures that multi-line descriptions maintain proper markdown list formatting. func indentMultilineDescription(description, indent string) string { - lines := strings.Split(description, "\n") - if len(lines) <= 1 { + if !strings.Contains(description, "\n") { return description } + var buf strings.Builder + lines := strings.Split(description, "\n") + buf.WriteString(lines[0]) for i := 1; i < len(lines); i++ { - lines[i] = indent + lines[i] + buf.WriteString("\n") + buf.WriteString(indent) + buf.WriteString(lines[i]) } - return strings.Join(lines, "\n") + return buf.String() } func replaceSection(content, startMarker, endMarker, newContent string) string { @@ -299,47 +308,49 @@ func generateRemoteToolsetsDoc() string { buf.WriteString("| Name | Description | API URL | 1-Click Install (VS Code) | Read-only Link | 1-Click Read-only Install (VS Code) |\n") buf.WriteString("|----------------|--------------------------------------------------|-------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n") - // Get toolset IDs and descriptions - toolsetIDs := tsg.ToolsetIDs() - descriptions := tsg.ToolsetDescriptions() - // Add "all" toolset first (special case) buf.WriteString("| all | All available GitHub MCP tools | https://api.githubcopilot.com/mcp/ | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=github&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2F%22%7D) | [read-only](https://api.githubcopilot.com/mcp/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=github&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Freadonly%22%7D) |\n") - // Add individual toolsets - for _, id := range toolsetIDs { - idStr := string(id) - if idStr == "context" || idStr == "dynamic" { // Skip context and dynamic toolsets as they're handled separately - continue - } + // AllTools() is sorted by toolset ID then tool name. + // We iterate once, collecting unique toolsets (skipping context which is handled separately). + tools := tsg.AllTools() + var lastToolsetID toolsets.ToolsetID + for _, tool := range tools { + if tool.Toolset.ID != lastToolsetID { + lastToolsetID = tool.Toolset.ID + idStr := string(lastToolsetID) + // Skip context toolset (handled separately) + if idStr == "context" { + continue + } - description := descriptions[id] - formattedName := formatToolsetName(idStr) - apiURL := fmt.Sprintf("https://api.githubcopilot.com/mcp/x/%s", idStr) - readonlyURL := fmt.Sprintf("https://api.githubcopilot.com/mcp/x/%s/readonly", idStr) - - // Create install config JSON (URL encoded) - installConfig := url.QueryEscape(fmt.Sprintf(`{"type": "http","url": "%s"}`, apiURL)) - readonlyConfig := url.QueryEscape(fmt.Sprintf(`{"type": "http","url": "%s"}`, readonlyURL)) - - // Fix URL encoding to use %20 instead of + for spaces - installConfig = strings.ReplaceAll(installConfig, "+", "%20") - readonlyConfig = strings.ReplaceAll(readonlyConfig, "+", "%20") - - installLink := fmt.Sprintf("[Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-%s&config=%s)", idStr, installConfig) - readonlyInstallLink := fmt.Sprintf("[Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-%s&config=%s)", idStr, readonlyConfig) - - buf.WriteString(fmt.Sprintf("| %-14s | %-48s | %-53s | %-218s | %-110s | %-288s |\n", - formattedName, - description, - apiURL, - installLink, - fmt.Sprintf("[read-only](%s)", readonlyURL), - readonlyInstallLink, - )) + formattedName := formatToolsetName(idStr) + apiURL := fmt.Sprintf("https://api.githubcopilot.com/mcp/x/%s", idStr) + readonlyURL := fmt.Sprintf("https://api.githubcopilot.com/mcp/x/%s/readonly", idStr) + + // Create install config JSON (URL encoded) + installConfig := url.QueryEscape(fmt.Sprintf(`{"type": "http","url": "%s"}`, apiURL)) + readonlyConfig := url.QueryEscape(fmt.Sprintf(`{"type": "http","url": "%s"}`, readonlyURL)) + + // Fix URL encoding to use %20 instead of + for spaces + installConfig = strings.ReplaceAll(installConfig, "+", "%20") + readonlyConfig = strings.ReplaceAll(readonlyConfig, "+", "%20") + + installLink := fmt.Sprintf("[Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-%s&config=%s)", idStr, installConfig) + readonlyInstallLink := fmt.Sprintf("[Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-%s&config=%s)", idStr, readonlyConfig) + + fmt.Fprintf(&buf, "| %-14s | %-48s | %-53s | %-218s | %-110s | %-288s |\n", + formattedName, + tool.Toolset.Description, + apiURL, + installLink, + fmt.Sprintf("[read-only](%s)", readonlyURL), + readonlyInstallLink, + ) + } } - return buf.String() + return strings.TrimSuffix(buf.String(), "\n") } func generateDeprecatedAliasesDocs(docsPath string) error { @@ -366,15 +377,15 @@ func generateDeprecatedAliasesDocs(docsPath string) error { } func generateDeprecatedAliasesTable() string { - var lines []string + var buf strings.Builder // Add table header - lines = append(lines, "| Old Name | New Name |") - lines = append(lines, "|----------|----------|") + buf.WriteString("| Old Name | New Name |\n") + buf.WriteString("|----------|----------|\n") aliases := github.DeprecatedToolAliases if len(aliases) == 0 { - lines = append(lines, "| *(none currently)* | |") + buf.WriteString("| *(none currently)* | |") } else { // Sort keys for deterministic output var oldNames []string @@ -383,11 +394,14 @@ func generateDeprecatedAliasesTable() string { } sort.Strings(oldNames) - for _, oldName := range oldNames { + for i, oldName := range oldNames { newName := aliases[oldName] - lines = append(lines, fmt.Sprintf("| `%s` | `%s` |", oldName, newName)) + fmt.Fprintf(&buf, "| `%s` | `%s` |", oldName, newName) + if i < len(oldNames)-1 { + buf.WriteString("\n") + } } } - return strings.Join(lines, "\n") + return buf.String() } diff --git a/docs/remote-server.md b/docs/remote-server.md index ffdf526a4..53fe36127 100644 --- a/docs/remote-server.md +++ b/docs/remote-server.md @@ -37,7 +37,6 @@ Below is a table of available toolsets for the remote GitHub MCP Server. Each to | Security Advisories | Security advisories related tools | https://api.githubcopilot.com/mcp/x/security_advisories | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-security_advisories&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fsecurity_advisories%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/security_advisories/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-security_advisories&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fsecurity_advisories%2Freadonly%22%7D) | | Stargazers | GitHub Stargazers related tools | https://api.githubcopilot.com/mcp/x/stargazers | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-stargazers&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fstargazers%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/stargazers/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-stargazers&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fstargazers%2Freadonly%22%7D) | | Users | GitHub User related tools | https://api.githubcopilot.com/mcp/x/users | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-users&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fusers%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/users/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-users&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fusers%2Freadonly%22%7D) | - ### Additional _Remote_ Server Toolsets From 9bcf526d726e061ff374009999bc8e24daa60292 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Sun, 14 Dec 2025 22:42:50 +0100 Subject: [PATCH 08/27] feat(toolsets): add AvailableToolsets() with exclude filter - Add AvailableToolsets() method that returns toolsets with actual tools - Support variadic exclude parameter for filtering out specific toolsets - Simplifies doc generation by removing manual skip logic - Naturally excludes empty toolsets (like 'dynamic') without special cases --- cmd/github-mcp-server/generate_docs.go | 82 ++++++++++---------------- pkg/toolsets/toolsets.go | 29 +++++++++ 2 files changed, 61 insertions(+), 50 deletions(-) diff --git a/cmd/github-mcp-server/generate_docs.go b/cmd/github-mcp-server/generate_docs.go index 6a3100ac2..a473787c6 100644 --- a/cmd/github-mcp-server/generate_docs.go +++ b/cmd/github-mcp-server/generate_docs.go @@ -116,19 +116,10 @@ func generateToolsetsDoc(tsg *toolsets.ToolsetGroup) string { // Add the context toolset row with custom description (strongly recommended) buf.WriteString("| `context` | **Strongly recommended**: Tools that provide context about the current user and GitHub context you are operating in |\n") - // AllTools() is sorted by toolset ID then tool name. - // We iterate once, collecting unique toolsets (skipping context which has custom description above). - tools := tsg.AllTools() - var lastToolsetID toolsets.ToolsetID - for _, tool := range tools { - if tool.Toolset.ID != lastToolsetID { - lastToolsetID = tool.Toolset.ID - // Skip context (handled above with custom description) - if lastToolsetID == "context" { - continue - } - fmt.Fprintf(&buf, "| `%s` | %s |\n", lastToolsetID, tool.Toolset.Description) - } + // AvailableToolsets() returns toolsets that have tools, sorted by ID + // Exclude context (custom description above) and dynamic (internal only) + for _, ts := range tsg.AvailableToolsets("context", "dynamic") { + fmt.Fprintf(&buf, "| `%s` | %s |\n", ts.ID, ts.Description) } return strings.TrimSuffix(buf.String(), "\n") @@ -311,43 +302,34 @@ func generateRemoteToolsetsDoc() string { // Add "all" toolset first (special case) buf.WriteString("| all | All available GitHub MCP tools | https://api.githubcopilot.com/mcp/ | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=github&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2F%22%7D) | [read-only](https://api.githubcopilot.com/mcp/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=github&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Freadonly%22%7D) |\n") - // AllTools() is sorted by toolset ID then tool name. - // We iterate once, collecting unique toolsets (skipping context which is handled separately). - tools := tsg.AllTools() - var lastToolsetID toolsets.ToolsetID - for _, tool := range tools { - if tool.Toolset.ID != lastToolsetID { - lastToolsetID = tool.Toolset.ID - idStr := string(lastToolsetID) - // Skip context toolset (handled separately) - if idStr == "context" { - continue - } - - formattedName := formatToolsetName(idStr) - apiURL := fmt.Sprintf("https://api.githubcopilot.com/mcp/x/%s", idStr) - readonlyURL := fmt.Sprintf("https://api.githubcopilot.com/mcp/x/%s/readonly", idStr) - - // Create install config JSON (URL encoded) - installConfig := url.QueryEscape(fmt.Sprintf(`{"type": "http","url": "%s"}`, apiURL)) - readonlyConfig := url.QueryEscape(fmt.Sprintf(`{"type": "http","url": "%s"}`, readonlyURL)) - - // Fix URL encoding to use %20 instead of + for spaces - installConfig = strings.ReplaceAll(installConfig, "+", "%20") - readonlyConfig = strings.ReplaceAll(readonlyConfig, "+", "%20") - - installLink := fmt.Sprintf("[Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-%s&config=%s)", idStr, installConfig) - readonlyInstallLink := fmt.Sprintf("[Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-%s&config=%s)", idStr, readonlyConfig) - - fmt.Fprintf(&buf, "| %-14s | %-48s | %-53s | %-218s | %-110s | %-288s |\n", - formattedName, - tool.Toolset.Description, - apiURL, - installLink, - fmt.Sprintf("[read-only](%s)", readonlyURL), - readonlyInstallLink, - ) - } + // AvailableToolsets() returns toolsets that have tools, sorted by ID + // Exclude context (handled separately) and dynamic (internal only) + for _, ts := range tsg.AvailableToolsets("context", "dynamic") { + idStr := string(ts.ID) + + formattedName := formatToolsetName(idStr) + apiURL := fmt.Sprintf("https://api.githubcopilot.com/mcp/x/%s", idStr) + readonlyURL := fmt.Sprintf("https://api.githubcopilot.com/mcp/x/%s/readonly", idStr) + + // Create install config JSON (URL encoded) + installConfig := url.QueryEscape(fmt.Sprintf(`{"type": "http","url": "%s"}`, apiURL)) + readonlyConfig := url.QueryEscape(fmt.Sprintf(`{"type": "http","url": "%s"}`, readonlyURL)) + + // Fix URL encoding to use %20 instead of + for spaces + installConfig = strings.ReplaceAll(installConfig, "+", "%20") + readonlyConfig = strings.ReplaceAll(readonlyConfig, "+", "%20") + + installLink := fmt.Sprintf("[Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-%s&config=%s)", idStr, installConfig) + readonlyInstallLink := fmt.Sprintf("[Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-%s&config=%s)", idStr, readonlyConfig) + + fmt.Fprintf(&buf, "| %-14s | %-48s | %-53s | %-218s | %-110s | %-288s |\n", + formattedName, + ts.Description, + apiURL, + installLink, + fmt.Sprintf("[read-only](%s)", readonlyURL), + readonlyInstallLink, + ) } return strings.TrimSuffix(buf.String(), "\n") diff --git a/pkg/toolsets/toolsets.go b/pkg/toolsets/toolsets.go index 32105dbe8..82896510e 100644 --- a/pkg/toolsets/toolsets.go +++ b/pkg/toolsets/toolsets.go @@ -767,3 +767,32 @@ func (tg *ToolsetGroup) AllTools() []ServerTool { return result } + +// AvailableToolsets returns the unique toolsets that have tools, in sorted order. +// This is the ordered intersection of toolsets with reality - only toolsets that +// actually contain tools are returned, sorted by toolset ID. +// Optional exclude parameter filters out specific toolset IDs from the result. +func (tg *ToolsetGroup) AvailableToolsets(exclude ...ToolsetID) []ToolsetMetadata { + tools := tg.AllTools() + if len(tools) == 0 { + return nil + } + + // Build exclude set for O(1) lookup + excludeSet := make(map[ToolsetID]bool, len(exclude)) + for _, id := range exclude { + excludeSet[id] = true + } + + var result []ToolsetMetadata + var lastID ToolsetID + for _, tool := range tools { + if tool.Toolset.ID != lastID { + lastID = tool.Toolset.ID + if !excludeSet[lastID] { + result = append(result, tool.Toolset) + } + } + } + return result +} From 2142b025c8386866552a731eb1b7e2f4bc8f59ec Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Sun, 14 Dec 2025 22:47:13 +0100 Subject: [PATCH 09/27] refactor(generate_docs): hoist success logging to generateAllDocs --- cmd/github-mcp-server/generate_docs.go | 81 ++++++++++++++------------ 1 file changed, 45 insertions(+), 36 deletions(-) diff --git a/cmd/github-mcp-server/generate_docs.go b/cmd/github-mcp-server/generate_docs.go index a473787c6..30a5667b3 100644 --- a/cmd/github-mcp-server/generate_docs.go +++ b/cmd/github-mcp-server/generate_docs.go @@ -4,7 +4,6 @@ import ( "fmt" "net/url" "os" - "regexp" "sort" "strings" @@ -30,18 +29,20 @@ func init() { } func generateAllDocs() error { - if err := generateReadmeDocs("README.md"); err != nil { - return fmt.Errorf("failed to generate README docs: %w", err) - } - - if err := generateRemoteServerDocs("docs/remote-server.md"); err != nil { - return fmt.Errorf("failed to generate remote-server docs: %w", err) - } - - if err := generateDeprecatedAliasesDocs("docs/deprecated-tool-aliases.md"); err != nil { - return fmt.Errorf("failed to generate deprecated aliases docs: %w", err) + for _, doc := range []struct { + path string + fn func(string) error + }{ + // File to edit, function to generate its docs + {"README.md", generateReadmeDocs}, + {"docs/remote-server.md", generateRemoteServerDocs}, + {"docs/deprecated-tool-aliases.md", generateDeprecatedAliasesDocs}, + } { + if err := doc.fn(doc.path); err != nil { + return fmt.Errorf("failed to generate docs for %s: %w", doc.path, err) + } + fmt.Printf("Successfully updated %s with automated documentation\n", doc.path) } - return nil } @@ -66,10 +67,16 @@ func generateReadmeDocs(readmePath string) error { } // Replace toolsets section - updatedContent := replaceSection(string(content), "START AUTOMATED TOOLSETS", "END AUTOMATED TOOLSETS", toolsetsDoc) + updatedContent, err := replaceSection(string(content), "START AUTOMATED TOOLSETS", "END AUTOMATED TOOLSETS", toolsetsDoc) + if err != nil { + return err + } // Replace tools section - updatedContent = replaceSection(updatedContent, "START AUTOMATED TOOLS", "END AUTOMATED TOOLS", toolsDoc) + updatedContent, err = replaceSection(updatedContent, "START AUTOMATED TOOLS", "END AUTOMATED TOOLS", toolsDoc) + if err != nil { + return err + } // Write back to file err = os.WriteFile(readmePath, []byte(updatedContent), 0600) @@ -77,7 +84,6 @@ func generateReadmeDocs(readmePath string) error { return fmt.Errorf("failed to write README.md: %w", err) } - fmt.Println("Successfully updated README.md with automated documentation") return nil } @@ -90,20 +96,12 @@ func generateRemoteServerDocs(docsPath string) error { toolsetsDoc := generateRemoteToolsetsDoc() // Replace content between markers - startMarker := "" - endMarker := "" - - contentStr := string(content) - startIndex := strings.Index(contentStr, startMarker) - endIndex := strings.Index(contentStr, endMarker) - - if startIndex == -1 || endIndex == -1 { - return fmt.Errorf("automation markers not found in %s", docsPath) + updatedContent, err := replaceSection(string(content), "START AUTOMATED TOOLSETS", "END AUTOMATED TOOLSETS", toolsetsDoc) + if err != nil { + return err } - newContent := contentStr[:startIndex] + startMarker + "\n" + toolsetsDoc + "\n" + endMarker + contentStr[endIndex+len(endMarker):] - - return os.WriteFile(docsPath, []byte(newContent), 0600) //#nosec G306 + return os.WriteFile(docsPath, []byte(updatedContent), 0600) //#nosec G306 } func generateToolsetsDoc(tsg *toolsets.ToolsetGroup) string { @@ -275,15 +273,24 @@ func indentMultilineDescription(description, indent string) string { return buf.String() } -func replaceSection(content, startMarker, endMarker, newContent string) string { - startPattern := fmt.Sprintf(``, regexp.QuoteMeta(startMarker)) - endPattern := fmt.Sprintf(``, regexp.QuoteMeta(endMarker)) - - re := regexp.MustCompile(fmt.Sprintf(`(?s)%s.*?%s`, startPattern, endPattern)) +func replaceSection(content, startMarker, endMarker, newContent string) (string, error) { + start := fmt.Sprintf("", startMarker) + end := fmt.Sprintf("", endMarker) - replacement := fmt.Sprintf("\n%s\n", startMarker, newContent, endMarker) + startIdx := strings.Index(content, start) + endIdx := strings.Index(content, end) + if startIdx == -1 || endIdx == -1 { + return "", fmt.Errorf("markers not found: %s / %s", start, end) + } - return re.ReplaceAllString(content, replacement) + var buf strings.Builder + buf.WriteString(content[:startIdx]) + buf.WriteString(start) + buf.WriteString("\n") + buf.WriteString(newContent) + buf.WriteString("\n") + buf.WriteString(content[endIdx:]) + return buf.String(), nil } func generateRemoteToolsetsDoc() string { @@ -346,7 +353,10 @@ func generateDeprecatedAliasesDocs(docsPath string) error { aliasesDoc := generateDeprecatedAliasesTable() // Replace content between markers - updatedContent := replaceSection(string(content), "START AUTOMATED ALIASES", "END AUTOMATED ALIASES", aliasesDoc) + updatedContent, err := replaceSection(string(content), "START AUTOMATED ALIASES", "END AUTOMATED ALIASES", aliasesDoc) + if err != nil { + return err + } // Write back to file err = os.WriteFile(docsPath, []byte(updatedContent), 0600) @@ -354,7 +364,6 @@ func generateDeprecatedAliasesDocs(docsPath string) error { return fmt.Errorf("failed to write deprecated aliases docs: %w", err) } - fmt.Println("Successfully updated docs/deprecated-tool-aliases.md with automated documentation") return nil } From 6b6a87455e3bb45da241c9919521fe69c4e7956a Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Mon, 15 Dec 2025 00:20:16 +0100 Subject: [PATCH 10/27] refactor: consolidate toolset validation into ToolsetGroup - Add Default field to ToolsetMetadata and derive defaults from metadata - Move toolset validation into WithToolsets (trims whitespace, dedupes, tracks unrecognized) - Add UnrecognizedToolsets() method for warning about typos - Add DefaultToolsetIDs() method to derive defaults from metadata - Remove redundant functions: CleanToolsets, GetValidToolsetIDs, AvailableToolsets, GetDefaultToolsetIDs - Update DynamicTools to take ToolsetGroup for schema enum generation - Add stubTranslator for cases needing ToolsetGroup without translations This eliminates hardcoded toolset lists - everything is now derived from the actual registered tools and their metadata. --- cmd/github-mcp-server/generate_docs.go | 18 +- e2e/e2e_test.go | 2 +- internal/ghmcp/server.go | 28 +- pkg/github/dynamic_tools.go | 56 ++-- pkg/github/tools.go | 104 ++----- pkg/github/tools_test.go | 129 -------- pkg/github/tools_validation_test.go | 51 ---- pkg/github/toolset_group.go | 17 +- pkg/toolsets/server_tool.go | 2 + pkg/toolsets/toolsets.go | 406 +++++++++++++++---------- pkg/toolsets/toolsets_test.go | 291 +++++++++++++++--- 11 files changed, 556 insertions(+), 548 deletions(-) diff --git a/cmd/github-mcp-server/generate_docs.go b/cmd/github-mcp-server/generate_docs.go index 30a5667b3..ddfcd10ba 100644 --- a/cmd/github-mcp-server/generate_docs.go +++ b/cmd/github-mcp-server/generate_docs.go @@ -51,13 +51,13 @@ func generateReadmeDocs(readmePath string) error { t, _ := translations.TranslationHelper() // Create toolset group - stateless, no dependencies needed for doc generation - tsg := github.NewToolsetGroup(t) + r := github.NewRegistry(t) // Generate toolsets documentation - toolsetsDoc := generateToolsetsDoc(tsg) + toolsetsDoc := generateToolsetsDoc(r) // Generate tools documentation - toolsDoc := generateToolsDoc(tsg) + toolsDoc := generateToolsDoc(r) // Read the current README.md // #nosec G304 - readmePath is controlled by command line flag, not user input @@ -104,7 +104,7 @@ func generateRemoteServerDocs(docsPath string) error { return os.WriteFile(docsPath, []byte(updatedContent), 0600) //#nosec G306 } -func generateToolsetsDoc(tsg *toolsets.ToolsetGroup) string { +func generateToolsetsDoc(r *toolsets.Registry) string { var buf strings.Builder // Add table header and separator @@ -116,17 +116,17 @@ func generateToolsetsDoc(tsg *toolsets.ToolsetGroup) string { // AvailableToolsets() returns toolsets that have tools, sorted by ID // Exclude context (custom description above) and dynamic (internal only) - for _, ts := range tsg.AvailableToolsets("context", "dynamic") { + for _, ts := range r.AvailableToolsets("context", "dynamic") { fmt.Fprintf(&buf, "| `%s` | %s |\n", ts.ID, ts.Description) } return strings.TrimSuffix(buf.String(), "\n") } -func generateToolsDoc(tsg *toolsets.ToolsetGroup) string { +func generateToolsDoc(r *toolsets.Registry) string { // AllTools() returns tools sorted by toolset ID then tool name. // We iterate once, grouping by toolset as we encounter them. - tools := tsg.AllTools() + tools := r.AllTools() if len(tools) == 0 { return "" } @@ -300,7 +300,7 @@ func generateRemoteToolsetsDoc() string { t, _ := translations.TranslationHelper() // Create toolset group - stateless - tsg := github.NewToolsetGroup(t) + r := github.NewRegistry(t) // Generate table header buf.WriteString("| Name | Description | API URL | 1-Click Install (VS Code) | Read-only Link | 1-Click Read-only Install (VS Code) |\n") @@ -311,7 +311,7 @@ func generateRemoteToolsetsDoc() string { // AvailableToolsets() returns toolsets that have tools, sorted by ID // Exclude context (handled separately) and dynamic (internal only) - for _, ts := range tsg.AvailableToolsets("context", "dynamic") { + for _, ts := range r.AvailableToolsets("context", "dynamic") { idStr := string(ts.ID) formattedName := formatToolsetName(idStr) diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 5f67fb84c..ad9ebb190 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -178,7 +178,7 @@ func setupMCPClient(t *testing.T, options ...clientOption) *mcp.ClientSession { // so that there is a shared setup mechanism, but let's wait till we feel more friction. enabledToolsets := opts.enabledToolsets if enabledToolsets == nil { - enabledToolsets = github.GetDefaultToolsetIDs() + enabledToolsets = github.NewRegistry(translations.NullTranslationHelper).DefaultToolsetIDs() } ghServer, err := ghmcp.NewMCPServer(ghmcp.MCPServerConfig{ diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 1dec59381..67fcad4a7 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -109,12 +109,7 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { // - explicit list means "use these toolsets" var enabledToolsets []string if cfg.EnabledToolsets != nil { - // Clean up explicitly passed toolsets (removes duplicates, whitespace) - var invalidToolsets []string - enabledToolsets, invalidToolsets = github.CleanToolsets(cfg.EnabledToolsets) - if len(invalidToolsets) > 0 { - fmt.Fprintf(os.Stderr, "Invalid toolsets ignored: %s\n", strings.Join(invalidToolsets, ", ")) - } + enabledToolsets = cfg.EnabledToolsets } else if cfg.DynamicToolsets { // Dynamic mode with no toolsets specified: start with no toolsets enabled // so users can enable them on demand via the dynamic tools @@ -163,7 +158,7 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { } // Create toolset group with all tools, resources, and prompts (stateless) - tsg := github.NewToolsetGroup(cfg.Translator) + r := github.NewRegistry(cfg.Translator) // Clean tool names (WithTools will resolve any deprecated aliases) enabledTools := github.CleanTools(cfg.EnabledTools) @@ -174,16 +169,21 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { // - WithToolsets: nil=defaults, empty=none, handles "all"/"default" keywords // - WithTools: additional tools that bypass toolset filtering (additive, resolves aliases) // - WithFeatureChecker: filters based on feature flags - filteredTsg := tsg. + filteredReg := r. WithDeprecatedToolAliases(github.DeprecatedToolAliases). WithReadOnly(cfg.ReadOnly). WithToolsets(enabledToolsets). WithTools(enabledTools). WithFeatureChecker(createFeatureChecker(cfg.EnabledFeatures)) + // Warn about unrecognized toolset names (likely typos) + if unrecognized := filteredReg.UnrecognizedToolsets(); len(unrecognized) > 0 { + fmt.Fprintf(os.Stderr, "Warning: unrecognized toolsets ignored: %s\n", strings.Join(unrecognized, ", ")) + } + // Register all mcp functionality with the server // Use background context for local server (no per-request actor context) - filteredTsg.RegisterAll(context.Background(), ghServer, deps) + filteredReg.RegisterAll(context.Background(), ghServer, deps) // Register dynamic toolset management if configured // Dynamic tools get access to the filtered toolset group which tracks enabled state. @@ -191,12 +191,12 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { // so dynamic tools can enable any toolset at runtime. if cfg.DynamicToolsets { dynamicDeps := github.DynamicToolDependencies{ - Server: ghServer, - ToolsetGroup: filteredTsg, - ToolDeps: deps, - T: cfg.Translator, + Server: ghServer, + Registry: filteredReg, + ToolDeps: deps, + T: cfg.Translator, } - dynamicTools := github.DynamicTools() + dynamicTools := github.DynamicTools(filteredReg) for _, tool := range dynamicTools { tool.RegisterFunc(ghServer, dynamicDeps) } diff --git a/pkg/github/dynamic_tools.go b/pkg/github/dynamic_tools.go index cc44e85f5..93c24a07b 100644 --- a/pkg/github/dynamic_tools.go +++ b/pkg/github/dynamic_tools.go @@ -13,13 +13,13 @@ import ( ) // DynamicToolDependencies contains dependencies for dynamic toolset management tools. -// It includes the managed ToolsetGroup, the server for registration, and the deps +// It includes the managed Registry, the server for registration, and the deps // that will be passed to tools when they are dynamically enabled. type DynamicToolDependencies struct { // Server is the MCP server to register tools with Server *mcp.Server - // ToolsetGroup contains all available tools that can be enabled dynamically - ToolsetGroup *toolsets.ToolsetGroup + // Registry contains all available tools that can be enabled dynamically + Registry *toolsets.Registry // ToolDeps are the dependencies passed to tools when they are registered ToolDeps any // T is the translation helper function @@ -33,20 +33,9 @@ func NewDynamicTool(toolset toolsets.ToolsetMetadata, tool mcp.Tool, handler fun }) } -// AllToolsetIDsEnum returns all available toolset IDs as an enum for JSON Schema. -func AllToolsetIDsEnum() []any { - toolsets := AvailableToolsets() - result := make([]any, len(toolsets)) - for i, ts := range toolsets { - result[i] = ts.ID - } - return result -} - -// ToolsetEnum returns the list of toolset IDs as an enum for JSON Schema. -// Deprecated: Use AllToolsetIDsEnum() instead. -func ToolsetEnum(toolsetGroup *toolsets.ToolsetGroup) []any { - toolsetIDs := toolsetGroup.ToolsetIDs() +// toolsetIDsEnum returns the list of toolset IDs as an enum for JSON Schema. +func toolsetIDsEnum(r *toolsets.Registry) []any { + toolsetIDs := r.ToolsetIDs() result := make([]any, len(toolsetIDs)) for i, id := range toolsetIDs { result[i] = id @@ -56,16 +45,17 @@ func ToolsetEnum(toolsetGroup *toolsets.ToolsetGroup) []any { // DynamicTools returns the tools for dynamic toolset management. // These tools allow runtime discovery and enablement of toolsets. -func DynamicTools() []toolsets.ServerTool { +// The r parameter provides the available toolset IDs for JSON Schema enums. +func DynamicTools(r *toolsets.Registry) []toolsets.ServerTool { return []toolsets.ServerTool{ ListAvailableToolsets(), - GetToolsetsTools(), - EnableToolset(), + GetToolsetsTools(r), + EnableToolset(r), } } // EnableToolset creates a tool that enables a toolset at runtime. -func EnableToolset() toolsets.ServerTool { +func EnableToolset(r *toolsets.Registry) toolsets.ServerTool { return NewDynamicTool( ToolsetMetadataDynamic, mcp.Tool{ @@ -81,7 +71,7 @@ func EnableToolset() toolsets.ServerTool { "toolset": { Type: "string", Description: "The name of the toolset to enable", - Enum: AllToolsetIDsEnum(), + Enum: toolsetIDsEnum(r), }, }, Required: []string{"toolset"}, @@ -96,19 +86,19 @@ func EnableToolset() toolsets.ServerTool { toolsetID := toolsets.ToolsetID(toolsetName) - if !deps.ToolsetGroup.HasToolset(toolsetID) { + if !deps.Registry.HasToolset(toolsetID) { return utils.NewToolResultError(fmt.Sprintf("Toolset %s not found", toolsetName)), nil, nil } - if deps.ToolsetGroup.IsToolsetEnabled(toolsetID) { + if deps.Registry.IsToolsetEnabled(toolsetID) { return utils.NewToolResultText(fmt.Sprintf("Toolset %s is already enabled", toolsetName)), nil, nil } // Mark the toolset as enabled so IsToolsetEnabled returns true - deps.ToolsetGroup.EnableToolset(toolsetID) + deps.Registry.EnableToolset(toolsetID) // Get tools for this toolset and register them with the managed deps - toolsForToolset := deps.ToolsetGroup.ToolsForToolset(toolsetID) + toolsForToolset := deps.Registry.ToolsForToolset(toolsetID) for _, st := range toolsForToolset { st.RegisterFunc(deps.Server, deps.ToolDeps) } @@ -137,8 +127,8 @@ func ListAvailableToolsets() toolsets.ServerTool { }, func(deps DynamicToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { return func(_ context.Context, _ *mcp.CallToolRequest, _ map[string]any) (*mcp.CallToolResult, any, error) { - toolsetIDs := deps.ToolsetGroup.ToolsetIDs() - descriptions := deps.ToolsetGroup.ToolsetDescriptions() + toolsetIDs := deps.Registry.ToolsetIDs() + descriptions := deps.Registry.ToolsetDescriptions() payload := make([]map[string]string, 0, len(toolsetIDs)) for _, id := range toolsetIDs { @@ -146,7 +136,7 @@ func ListAvailableToolsets() toolsets.ServerTool { "name": string(id), "description": descriptions[id], "can_enable": "true", - "currently_enabled": fmt.Sprintf("%t", deps.ToolsetGroup.IsToolsetEnabled(id)), + "currently_enabled": fmt.Sprintf("%t", deps.Registry.IsToolsetEnabled(id)), } payload = append(payload, t) } @@ -163,7 +153,7 @@ func ListAvailableToolsets() toolsets.ServerTool { } // GetToolsetsTools creates a tool that lists all tools in a specific toolset. -func GetToolsetsTools() toolsets.ServerTool { +func GetToolsetsTools(r *toolsets.Registry) toolsets.ServerTool { return NewDynamicTool( ToolsetMetadataDynamic, mcp.Tool{ @@ -179,7 +169,7 @@ func GetToolsetsTools() toolsets.ServerTool { "toolset": { Type: "string", Description: "The name of the toolset you want to get the tools for", - Enum: AllToolsetIDsEnum(), + Enum: toolsetIDsEnum(r), }, }, Required: []string{"toolset"}, @@ -194,12 +184,12 @@ func GetToolsetsTools() toolsets.ServerTool { toolsetID := toolsets.ToolsetID(toolsetName) - if !deps.ToolsetGroup.HasToolset(toolsetID) { + if !deps.Registry.HasToolset(toolsetID) { return utils.NewToolResultError(fmt.Sprintf("Toolset %s not found", toolsetName)), nil, nil } // Get all tools for this toolset (ignoring current filters for discovery) - toolsInToolset := deps.ToolsetGroup.ToolsForToolset(toolsetID) + toolsInToolset := deps.Registry.ToolsForToolset(toolsetID) payload := make([]map[string]string, 0, len(toolsInToolset)) for _, st := range toolsInToolset { diff --git a/pkg/github/tools.go b/pkg/github/tools.go index f39ae43c1..dd2ad4ff4 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -28,10 +28,12 @@ var ( ToolsetMetadataContext = toolsets.ToolsetMetadata{ ID: "context", Description: "Tools that provide context about the current user and GitHub context you are operating in", + Default: true, } ToolsetMetadataRepos = toolsets.ToolsetMetadata{ ID: "repos", Description: "GitHub Repository related tools", + Default: true, } ToolsetMetadataGit = toolsets.ToolsetMetadata{ ID: "git", @@ -40,14 +42,17 @@ var ( ToolsetMetadataIssues = toolsets.ToolsetMetadata{ ID: "issues", Description: "GitHub Issues related tools", + Default: true, } ToolsetMetadataPullRequests = toolsets.ToolsetMetadata{ ID: "pull_requests", Description: "GitHub Pull Request related tools", + Default: true, } ToolsetMetadataUsers = toolsets.ToolsetMetadata{ ID: "users", Description: "GitHub User related tools", + Default: true, } ToolsetMetadataOrgs = toolsets.ToolsetMetadata{ ID: "orgs", @@ -107,53 +112,6 @@ var ( } ) -func AvailableToolsets() []toolsets.ToolsetMetadata { - return []toolsets.ToolsetMetadata{ - ToolsetMetadataContext, - ToolsetMetadataRepos, - ToolsetMetadataGit, - ToolsetMetadataIssues, - ToolsetMetadataPullRequests, - ToolsetMetadataUsers, - ToolsetMetadataOrgs, - ToolsetMetadataActions, - ToolsetMetadataCodeSecurity, - ToolsetMetadataSecretProtection, - ToolsetMetadataDependabot, - ToolsetMetadataNotifications, - ToolsetMetadataExperiments, - ToolsetMetadataDiscussions, - ToolsetMetadataGists, - ToolsetMetadataSecurityAdvisories, - ToolsetMetadataProjects, - ToolsetMetadataStargazers, - ToolsetMetadataDynamic, - ToolsetLabels, - } -} - -// GetValidToolsetIDs returns a map of all valid toolset IDs for quick lookup -func GetValidToolsetIDs() map[toolsets.ToolsetID]bool { - validIDs := make(map[toolsets.ToolsetID]bool) - for _, toolset := range AvailableToolsets() { - validIDs[toolset.ID] = true - } - // Add special keywords - validIDs[ToolsetMetadataAll.ID] = true - validIDs[ToolsetMetadataDefault.ID] = true - return validIDs -} - -func GetDefaultToolsetIDs() []toolsets.ToolsetID { - return []toolsets.ToolsetID{ - ToolsetMetadataContext.ID, - ToolsetMetadataRepos.ID, - ToolsetMetadataIssues.ID, - ToolsetMetadataPullRequests.ID, - ToolsetMetadataUsers.ID, - } -} - // AllTools returns all tools with their embedded toolset metadata. // Tool functions return ServerTool directly with toolset info. func AllTools(t translations.TranslationHelperFunc) []toolsets.ServerTool { @@ -304,16 +262,19 @@ func ToStringPtr(s string) *string { // GenerateToolsetsHelp generates the help text for the toolsets flag func GenerateToolsetsHelp() string { - // Format default tools - defaultIDs := GetDefaultToolsetIDs() + // Get toolset group to derive defaults and available toolsets + r := NewRegistry(stubTranslator) + + // Format default tools from metadata + defaultIDs := r.DefaultToolsetIDs() defaultStrings := make([]string, len(defaultIDs)) for i, id := range defaultIDs { defaultStrings[i] = string(id) } defaultTools := strings.Join(defaultStrings, ", ") - // Format available tools with line breaks for better readability - allToolsets := AvailableToolsets() + // Get all available toolsets (excludes context and dynamic for display) + allToolsets := r.AvailableToolsets("context", "dynamic") var availableToolsLines []string const maxLineLength = 70 currentLine := "" @@ -349,6 +310,10 @@ func GenerateToolsetsHelp() string { return toolsetsHelp } +// stubTranslator is a passthrough translator for cases where we need a Registry +// but don't need actual translations (e.g., getting toolset IDs for CLI help). +func stubTranslator(_, fallback string) string { return fallback } + // AddDefaultToolset removes the default toolset and expands it to the actual default toolset IDs func AddDefaultToolset(result []string) []string { hasDefault := false @@ -367,43 +332,16 @@ func AddDefaultToolset(result []string) []string { result = RemoveToolset(result, string(ToolsetMetadataDefault.ID)) - for _, defaultToolset := range GetDefaultToolsetIDs() { - if !seen[string(defaultToolset)] { - result = append(result, string(defaultToolset)) + // Get default toolset IDs from the Registry + r := NewRegistry(stubTranslator) + for _, id := range r.DefaultToolsetIDs() { + if !seen[string(id)] { + result = append(result, string(id)) } } return result } -// cleanToolsets cleans and handles special toolset keywords: -// - Duplicates are removed from the result -// - Removes whitespaces -// - Validates toolset names and returns invalid ones separately - for warning reporting -// Returns: (toolsets, invalidToolsets) -func CleanToolsets(enabledToolsets []string) ([]string, []string) { - seen := make(map[string]bool) - result := make([]string, 0, len(enabledToolsets)) - invalid := make([]string, 0) - validIDs := GetValidToolsetIDs() - - // Add non-default toolsets, removing duplicates and trimming whitespace - for _, toolset := range enabledToolsets { - trimmed := strings.TrimSpace(toolset) - if trimmed == "" { - continue - } - if !seen[trimmed] { - seen[trimmed] = true - result = append(result, trimmed) - if !validIDs[toolsets.ToolsetID(trimmed)] { - invalid = append(invalid, trimmed) - } - } - } - - return result, invalid -} - func RemoveToolset(tools []string, toRemove string) []string { result := make([]string, 0, len(tools)) for _, tool := range tools { diff --git a/pkg/github/tools_test.go b/pkg/github/tools_test.go index 45c1e746f..4e6d91980 100644 --- a/pkg/github/tools_test.go +++ b/pkg/github/tools_test.go @@ -7,135 +7,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestCleanToolsets(t *testing.T) { - tests := []struct { - name string - input []string - expected []string - expectedInvalid []string - }{ - { - name: "empty slice", - input: []string{}, - expected: []string{}, - }, - { - name: "nil input slice", - input: nil, - expected: []string{}, - }, - // CleanToolsets only cleans - it does NOT filter out special keywords - { - name: "default keyword preserved", - input: []string{"default"}, - expected: []string{"default"}, - }, - { - name: "default with additional toolsets", - input: []string{"default", "actions", "gists"}, - expected: []string{"default", "actions", "gists"}, - }, - { - name: "all keyword preserved", - input: []string{"all", "actions"}, - expected: []string{"all", "actions"}, - }, - { - name: "no special keywords", - input: []string{"actions", "gists", "notifications"}, - expected: []string{"actions", "gists", "notifications"}, - }, - { - name: "duplicate toolsets without special keywords", - input: []string{"actions", "gists", "actions"}, - expected: []string{"actions", "gists"}, - }, - { - name: "duplicate toolsets with default", - input: []string{"context", "repos", "issues", "pull_requests", "users", "default"}, - expected: []string{"context", "repos", "issues", "pull_requests", "users", "default"}, - }, - { - name: "default appears multiple times - duplicates removed", - input: []string{"default", "actions", "default", "gists", "default"}, - expected: []string{"default", "actions", "gists"}, - }, - // Whitespace test cases - { - name: "whitespace check - leading and trailing whitespace on regular toolsets", - input: []string{" actions ", " gists ", "notifications"}, - expected: []string{"actions", "gists", "notifications"}, - }, - { - name: "whitespace check - default toolset with whitespace", - input: []string{" actions ", " default ", "notifications"}, - expected: []string{"actions", "default", "notifications"}, - }, - { - name: "whitespace check - all toolset with whitespace", - input: []string{" all ", " actions "}, - expected: []string{"all", "actions"}, - }, - // Invalid toolset test cases - { - name: "mix of valid and invalid toolsets", - input: []string{"actions", "invalid_toolset", "gists", "typo_repo"}, - expected: []string{"actions", "invalid_toolset", "gists", "typo_repo"}, - expectedInvalid: []string{"invalid_toolset", "typo_repo"}, - }, - { - name: "invalid with whitespace", - input: []string{" invalid_tool ", " actions ", " typo_gist "}, - expected: []string{"invalid_tool", "actions", "typo_gist"}, - expectedInvalid: []string{"invalid_tool", "typo_gist"}, - }, - { - name: "empty string in toolsets", - input: []string{"", "actions", " ", "gists"}, - expected: []string{"actions", "gists"}, - expectedInvalid: []string{}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, invalid := CleanToolsets(tt.input) - - require.Len(t, result, len(tt.expected), "result length should match expected length") - - if tt.expectedInvalid == nil { - tt.expectedInvalid = []string{} - } - require.Len(t, invalid, len(tt.expectedInvalid), "invalid length should match expected invalid length") - - resultMap := make(map[string]bool) - for _, toolset := range result { - resultMap[toolset] = true - } - - expectedMap := make(map[string]bool) - for _, toolset := range tt.expected { - expectedMap[toolset] = true - } - - invalidMap := make(map[string]bool) - for _, toolset := range invalid { - invalidMap[toolset] = true - } - - expectedInvalidMap := make(map[string]bool) - for _, toolset := range tt.expectedInvalid { - expectedInvalidMap[toolset] = true - } - - assert.Equal(t, expectedMap, resultMap, "result should contain all expected toolsets without duplicates") - assert.Equal(t, expectedInvalidMap, invalidMap, "invalid should contain all expected invalid toolsets") - - assert.Len(t, resultMap, len(result), "result should not contain duplicates") - }) - } -} - func TestAddDefaultToolset(t *testing.T) { tests := []struct { name string diff --git a/pkg/github/tools_validation_test.go b/pkg/github/tools_validation_test.go index ae79c455a..d53243b42 100644 --- a/pkg/github/tools_validation_test.go +++ b/pkg/github/tools_validation_test.go @@ -42,20 +42,6 @@ func TestAllToolsHaveRequiredMetadata(t *testing.T) { } } -// TestAllToolsHaveValidToolsetID validates that all tools belong to known toolsets -func TestAllToolsHaveValidToolsetID(t *testing.T) { - tools := AllTools(stubTranslation) - validToolsetIDs := GetValidToolsetIDs() - - for _, tool := range tools { - t.Run(tool.Tool.Name, func(t *testing.T) { - assert.True(t, validToolsetIDs[tool.Toolset.ID], - "Tool %q has invalid Toolset.ID %q - must be one of the defined toolsets", - tool.Tool.Name, tool.Toolset.ID) - }) - } -} - // TestAllResourcesHaveRequiredMetadata validates that all resources have mandatory metadata func TestAllResourcesHaveRequiredMetadata(t *testing.T) { // Resources are now stateless - no client functions needed @@ -95,32 +81,6 @@ func TestAllPromptsHaveRequiredMetadata(t *testing.T) { } } -// TestAllResourcesHaveValidToolsetID validates that all resources belong to known toolsets -func TestAllResourcesHaveValidToolsetID(t *testing.T) { - resources := AllResources(stubTranslation) - validToolsetIDs := GetValidToolsetIDs() - - for _, res := range resources { - t.Run(res.Template.Name, func(t *testing.T) { - assert.True(t, validToolsetIDs[res.Toolset.ID], - "Resource %q has invalid Toolset.ID %q", res.Template.Name, res.Toolset.ID) - }) - } -} - -// TestAllPromptsHaveValidToolsetID validates that all prompts belong to known toolsets -func TestAllPromptsHaveValidToolsetID(t *testing.T) { - prompts := AllPrompts(stubTranslation) - validToolsetIDs := GetValidToolsetIDs() - - for _, prompt := range prompts { - t.Run(prompt.Prompt.Name, func(t *testing.T) { - assert.True(t, validToolsetIDs[prompt.Toolset.ID], - "Prompt %q has invalid Toolset.ID %q", prompt.Prompt.Name, prompt.Toolset.ID) - }) - } -} - // TestToolReadOnlyHintConsistency validates that read-only tools are correctly annotated func TestToolReadOnlyHintConsistency(t *testing.T) { tools := AllTools(stubTranslation) @@ -190,17 +150,6 @@ func TestAllToolsHaveHandlerFunc(t *testing.T) { } } -// TestDefaultToolsetsAreValid ensures default toolset IDs are all valid -func TestDefaultToolsetsAreValid(t *testing.T) { - defaults := GetDefaultToolsetIDs() - valid := GetValidToolsetIDs() - - for _, id := range defaults { - assert.True(t, valid[id], - "Default toolset ID %q is not in the valid toolset list", id) - } -} - // TestToolsetMetadataConsistency ensures tools in the same toolset have consistent descriptions func TestToolsetMetadataConsistency(t *testing.T) { tools := AllTools(stubTranslation) diff --git a/pkg/github/toolset_group.go b/pkg/github/toolset_group.go index 68db02db6..7330e08d3 100644 --- a/pkg/github/toolset_group.go +++ b/pkg/github/toolset_group.go @@ -5,17 +5,14 @@ import ( "github.com/github/github-mcp-server/pkg/translations" ) -// NewToolsetGroup creates a ToolsetGroup with all available tools, resources, and prompts. +// NewRegistry creates a Registry with all available tools, resources, and prompts. // Tools, resources, and prompts are self-describing with their toolset metadata embedded. // This function is stateless - no dependencies are captured. // Handlers are generated on-demand during registration via RegisterAll(ctx, server, deps). -// The "default" keyword in WithToolsets will expand to GetDefaultToolsetIDs(). -func NewToolsetGroup(t translations.TranslationHelperFunc) *toolsets.ToolsetGroup { - tsg := toolsets.NewToolsetGroup( - AllTools(t), - AllResources(t), - AllPrompts(t), - ) - tsg.SetDefaultToolsetIDs(GetDefaultToolsetIDs()) - return tsg +// The "default" keyword in WithToolsets will expand to toolsets marked with Default: true. +func NewRegistry(t translations.TranslationHelperFunc) *toolsets.Registry { + return toolsets.NewRegistry(). + SetTools(AllTools(t)). + SetResources(AllResources(t)). + SetPrompts(AllPrompts(t)) } diff --git a/pkg/toolsets/server_tool.go b/pkg/toolsets/server_tool.go index 0e782e631..eb30f01f4 100644 --- a/pkg/toolsets/server_tool.go +++ b/pkg/toolsets/server_tool.go @@ -24,6 +24,8 @@ type ToolsetMetadata struct { ID ToolsetID // Description provides a human-readable description of the toolset Description string + // Default indicates this toolset should be enabled by default + Default bool } // ServerTool represents an MCP tool with metadata and a handler generator function. diff --git a/pkg/toolsets/toolsets.go b/pkg/toolsets/toolsets.go index 82896510e..34e5fa923 100644 --- a/pkg/toolsets/toolsets.go +++ b/pkg/toolsets/toolsets.go @@ -6,6 +6,7 @@ import ( "os" "slices" "sort" + "strings" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -116,14 +117,14 @@ func NewServerPrompt(toolset ToolsetMetadata, prompt mcp.Prompt, handler mcp.Pro } } -// ToolsetGroup holds a collection of tools, resources, and prompts. -// It supports immutable filtering operations that return new ToolsetGroups +// Registry holds a collection of tools, resources, and prompts. +// It supports immutable filtering operations that return new Registrys // without modifying the original. This design allows for: // - Building a full set of tools/resources/prompts once // - Applying filters (read-only, feature flags, enabled toolsets) without mutation // - Deterministic ordering for documentation generation // - Lazy dependency injection only when registering with a server -type ToolsetGroup struct { +type Registry struct { // tools holds all tools in this group tools []ServerTool // resourceTemplates holds all resource templates in this group @@ -132,8 +133,6 @@ type ToolsetGroup struct { prompts []ServerPrompt // deprecatedAliases maps old tool names to new canonical names deprecatedAliases map[string]string - // defaultToolsetIDs are the toolset IDs that "default" expands to - defaultToolsetIDs []ToolsetID // Filters - these control what's returned by Available* methods // readOnly when true filters out write tools @@ -148,6 +147,8 @@ type ToolsetGroup struct { // Takes context and flag name, returns (enabled, error). If error, log and treat as false. // If checker is nil, all flag checks return false. featureChecker FeatureFlagChecker + // unrecognizedToolsets holds toolset IDs that were requested but don't match any registered toolsets + unrecognizedToolsets []string } // FeatureFlagChecker is a function that checks if a feature flag is enabled. @@ -155,78 +156,99 @@ type ToolsetGroup struct { // Returns (enabled, error). If error occurs, the caller should log and treat as false. type FeatureFlagChecker func(ctx context.Context, flagName string) (bool, error) -// NewToolsetGroup creates a new ToolsetGroup from the provided tools, resources, and prompts. -// The group is created with no filters applied. -func NewToolsetGroup(tools []ServerTool, resources []ServerResourceTemplate, prompts []ServerPrompt) *ToolsetGroup { - return &ToolsetGroup{ - tools: tools, - resourceTemplates: resources, - prompts: prompts, +// NewRegistry creates a new empty Registry. +// Use SetTools, SetResources, SetPrompts to populate it. +func NewRegistry() *Registry { + return &Registry{ deprecatedAliases: make(map[string]string), - readOnly: false, - enabledToolsets: nil, - additionalTools: nil, - featureChecker: nil, } } -// copy creates a shallow copy of the ToolsetGroup for immutable operations. -func (tg *ToolsetGroup) copy() *ToolsetGroup { - newTG := &ToolsetGroup{ - tools: tg.tools, // slices are shared (immutable) - resourceTemplates: tg.resourceTemplates, - prompts: tg.prompts, - deprecatedAliases: tg.deprecatedAliases, - defaultToolsetIDs: tg.defaultToolsetIDs, - readOnly: tg.readOnly, - featureChecker: tg.featureChecker, +// SetTools sets the tools for this group. Returns self for chaining. +func (r *Registry) SetTools(tools []ServerTool) *Registry { + r.tools = tools + return r +} + +// SetResources sets the resource templates for this group. Returns self for chaining. +func (r *Registry) SetResources(resources []ServerResourceTemplate) *Registry { + r.resourceTemplates = resources + return r +} + +// SetPrompts sets the prompts for this group. Returns self for chaining. +func (r *Registry) SetPrompts(prompts []ServerPrompt) *Registry { + r.prompts = prompts + return r +} + +// copy creates a shallow copy of the Registry for immutable operations. +func (r *Registry) copy() *Registry { + newTG := &Registry{ + tools: r.tools, // slices are shared (immutable) + resourceTemplates: r.resourceTemplates, + prompts: r.prompts, + deprecatedAliases: r.deprecatedAliases, + readOnly: r.readOnly, + featureChecker: r.featureChecker, } // Copy maps if they exist - if tg.enabledToolsets != nil { - newTG.enabledToolsets = make(map[ToolsetID]bool, len(tg.enabledToolsets)) - for k, v := range tg.enabledToolsets { + if r.enabledToolsets != nil { + newTG.enabledToolsets = make(map[ToolsetID]bool, len(r.enabledToolsets)) + for k, v := range r.enabledToolsets { newTG.enabledToolsets[k] = v } } - if tg.additionalTools != nil { - newTG.additionalTools = make(map[string]bool, len(tg.additionalTools)) - for k, v := range tg.additionalTools { + if r.additionalTools != nil { + newTG.additionalTools = make(map[string]bool, len(r.additionalTools)) + for k, v := range r.additionalTools { newTG.additionalTools[k] = v } } + newTG.unrecognizedToolsets = r.unrecognizedToolsets return newTG } -// WithReadOnly returns a new ToolsetGroup with read-only mode set. +// WithReadOnly returns a new Registry with read-only mode set. // When true, write tools are filtered out from Available* methods. -func (tg *ToolsetGroup) WithReadOnly(readOnly bool) *ToolsetGroup { - newTG := tg.copy() +func (r *Registry) WithReadOnly(readOnly bool) *Registry { + newTG := r.copy() newTG.readOnly = readOnly return newTG } -// SetDefaultToolsetIDs configures which toolset IDs the "default" keyword expands to. -// This should be called before WithToolsets if you want "default" to be recognized. -func (tg *ToolsetGroup) SetDefaultToolsetIDs(ids []ToolsetID) *ToolsetGroup { - tg.defaultToolsetIDs = ids - return tg -} - -// WithToolsets returns a new ToolsetGroup that only includes items from the specified toolsets. +// WithToolsets returns a new Registry that only includes items from the specified toolsets. // Special keywords: // - "all": enables all toolsets -// - "default": expands to the default toolset IDs (set via SetDefaultToolsetIDs) +// - "default": expands to toolsets marked with Default: true in their metadata +// +// Input strings are trimmed of whitespace and duplicates are removed. +// Toolset IDs that don't match any registered toolsets are tracked and can be +// retrieved via UnrecognizedToolsets() for warning purposes. // // Pass nil to use default toolsets. Pass an empty slice to disable all toolsets // (useful for dynamic toolsets mode where tools are enabled on demand). -func (tg *ToolsetGroup) WithToolsets(toolsetIDs []string) *ToolsetGroup { - newTG := tg.copy() +func (r *Registry) WithToolsets(toolsetIDs []string) *Registry { + newTG := r.copy() + newTG.unrecognizedToolsets = nil // reset for fresh calculation + + // Build a set of valid toolset IDs for validation + validIDs := make(map[ToolsetID]bool) + for _, t := range r.tools { + validIDs[t.Toolset.ID] = true + } + for _, r := range r.resourceTemplates { + validIDs[r.Toolset.ID] = true + } + for _, p := range r.prompts { + validIDs[p.Toolset.ID] = true + } // Check for "all" keyword - enables all toolsets for _, id := range toolsetIDs { - if id == "all" { + if strings.TrimSpace(id) == "all" { newTG.enabledToolsets = nil return newTG } @@ -237,26 +259,38 @@ func (tg *ToolsetGroup) WithToolsets(toolsetIDs []string) *ToolsetGroup { toolsetIDs = []string{"default"} } - // Expand "default" keyword and collect other IDs + // Expand "default" keyword, trim whitespace, collect other IDs, and track unrecognized seen := make(map[ToolsetID]bool) expanded := make([]ToolsetID, 0, len(toolsetIDs)) + var unrecognized []string + for _, id := range toolsetIDs { - if id == "default" { - for _, defaultID := range tg.defaultToolsetIDs { + trimmed := strings.TrimSpace(id) + if trimmed == "" { + continue + } + if trimmed == "default" { + for _, defaultID := range r.DefaultToolsetIDs() { if !seen[defaultID] { seen[defaultID] = true expanded = append(expanded, defaultID) } } } else { - tsID := ToolsetID(id) + tsID := ToolsetID(trimmed) if !seen[tsID] { seen[tsID] = true expanded = append(expanded, tsID) + // Track if this toolset doesn't exist + if !validIDs[tsID] { + unrecognized = append(unrecognized, trimmed) + } } } } + newTG.unrecognizedToolsets = unrecognized + if len(expanded) == 0 { newTG.enabledToolsets = make(map[ToolsetID]bool) return newTG @@ -269,13 +303,19 @@ func (tg *ToolsetGroup) WithToolsets(toolsetIDs []string) *ToolsetGroup { return newTG } -// WithTools returns a new ToolsetGroup with additional tools that bypass toolset filtering. +// UnrecognizedToolsets returns toolset IDs that were passed to WithToolsets but don't +// match any registered toolsets. This is useful for warning users about typos. +func (r *Registry) UnrecognizedToolsets() []string { + return r.unrecognizedToolsets +} + +// WithTools returns a new Registry with additional tools that bypass toolset filtering. // These tools are additive - they will be included even if their toolset is not enabled. // Read-only filtering still applies to these tools. // Deprecated tool aliases are automatically resolved to their canonical names. // Pass nil or empty slice to clear additional tools. -func (tg *ToolsetGroup) WithTools(toolNames []string) *ToolsetGroup { - newTG := tg.copy() +func (r *Registry) WithTools(toolNames []string) *Registry { + newTG := r.copy() if len(toolNames) == 0 { newTG.additionalTools = nil return newTG @@ -283,7 +323,7 @@ func (tg *ToolsetGroup) WithTools(toolNames []string) *ToolsetGroup { newTG.additionalTools = make(map[string]bool, len(toolNames)) for _, name := range toolNames { // Resolve deprecated aliases to canonical names - if canonical, isAlias := tg.deprecatedAliases[name]; isAlias { + if canonical, isAlias := r.deprecatedAliases[name]; isAlias { newTG.additionalTools[canonical] = true } else { newTG.additionalTools[name] = true @@ -292,13 +332,13 @@ func (tg *ToolsetGroup) WithTools(toolNames []string) *ToolsetGroup { return newTG } -// WithFeatureChecker returns a new ToolsetGroup with a feature checker function. +// WithFeatureChecker returns a new Registry with a feature checker function. // The checker receives a context (for actor extraction) and feature flag name, returns (enabled, error). // If error occurs, it will be logged and treated as false. // If checker is nil, all feature flag checks return false (items with FeatureFlagEnable are excluded, // items with FeatureFlagDisable are included). -func (tg *ToolsetGroup) WithFeatureChecker(checker FeatureFlagChecker) *ToolsetGroup { - newTG := tg.copy() +func (r *Registry) WithFeatureChecker(checker FeatureFlagChecker) *Registry { + newTG := r.copy() newTG.featureChecker = checker return newTG } @@ -315,7 +355,7 @@ const ( MCPMethodPromptsGet = "prompts/get" ) -// ForMCPRequest returns a ToolsetGroup optimized for a specific MCP request. +// ForMCPRequest returns a Registry optimized for a specific MCP request. // This is designed for servers that create a new instance per request (like the remote server), // allowing them to only register the items needed for that specific request rather than all ~90 tools. // @@ -323,7 +363,7 @@ const ( // - method: The MCP method being called (use MCP* constants) // - itemName: Name of specific item for call/get methods (tool name, resource URI, or prompt name) // -// Returns a new ToolsetGroup containing only the items relevant to the request: +// Returns a new Registry containing only the items relevant to the request: // - MCPMethodInitialize: Empty (capabilities are set via ServerOptions, not registration) // - MCPMethodToolsList: All available tools (no resources/prompts) // - MCPMethodToolsCall: Only the named tool @@ -334,8 +374,8 @@ const ( // - Unknown methods: Empty (no items registered) // // All existing filters (read-only, toolsets, etc.) still apply to the returned items. -func (tg *ToolsetGroup) ForMCPRequest(method string, itemName string) *ToolsetGroup { - result := tg.copy() +func (r *Registry) ForMCPRequest(method string, itemName string) *Registry { + result := r.copy() // Helper to clear all item types clearAll := func() { @@ -352,21 +392,21 @@ func (tg *ToolsetGroup) ForMCPRequest(method string, itemName string) *ToolsetGr case MCPMethodToolsCall: result.resourceTemplates, result.prompts = nil, nil if itemName != "" { - result.tools = tg.filterToolsByName(itemName) + result.tools = r.filterToolsByName(itemName) } case MCPMethodResourcesList, MCPMethodResourcesTemplatesList: result.tools, result.prompts = nil, nil case MCPMethodResourcesRead: result.tools, result.prompts = nil, nil if itemName != "" { - result.resourceTemplates = tg.filterResourcesByURI(itemName) + result.resourceTemplates = r.filterResourcesByURI(itemName) } case MCPMethodPromptsList: result.tools, result.resourceTemplates = nil, nil case MCPMethodPromptsGet: result.tools, result.resourceTemplates = nil, nil if itemName != "" { - result.prompts = tg.filterPromptsByName(itemName) + result.prompts = r.filterPromptsByName(itemName) } default: clearAll() @@ -377,18 +417,18 @@ func (tg *ToolsetGroup) ForMCPRequest(method string, itemName string) *ToolsetGr // filterToolsByName returns tools matching the given name, checking deprecated aliases. // Returns from the current tools slice (respects existing filter chain). -func (tg *ToolsetGroup) filterToolsByName(name string) []ServerTool { +func (r *Registry) filterToolsByName(name string) []ServerTool { // First check for exact match - for i := range tg.tools { - if tg.tools[i].Tool.Name == name { - return []ServerTool{tg.tools[i]} + for i := range r.tools { + if r.tools[i].Tool.Name == name { + return []ServerTool{r.tools[i]} } } // Check if name is a deprecated alias - if canonical, isAlias := tg.deprecatedAliases[name]; isAlias { - for i := range tg.tools { - if tg.tools[i].Tool.Name == canonical { - return []ServerTool{tg.tools[i]} + if canonical, isAlias := r.deprecatedAliases[name]; isAlias { + for i := range r.tools { + if r.tools[i].Tool.Name == canonical { + return []ServerTool{r.tools[i]} } } } @@ -396,33 +436,33 @@ func (tg *ToolsetGroup) filterToolsByName(name string) []ServerTool { } // filterResourcesByURI returns resource templates matching the given URI pattern. -func (tg *ToolsetGroup) filterResourcesByURI(uri string) []ServerResourceTemplate { - for i := range tg.resourceTemplates { +func (r *Registry) filterResourcesByURI(uri string) []ServerResourceTemplate { + for i := range r.resourceTemplates { // Check if URI matches the template pattern (exact match on URITemplate string) - if tg.resourceTemplates[i].Template.URITemplate == uri { - return []ServerResourceTemplate{tg.resourceTemplates[i]} + if r.resourceTemplates[i].Template.URITemplate == uri { + return []ServerResourceTemplate{r.resourceTemplates[i]} } } return []ServerResourceTemplate{} } // filterPromptsByName returns prompts matching the given name. -func (tg *ToolsetGroup) filterPromptsByName(name string) []ServerPrompt { - for i := range tg.prompts { - if tg.prompts[i].Prompt.Name == name { - return []ServerPrompt{tg.prompts[i]} +func (r *Registry) filterPromptsByName(name string) []ServerPrompt { + for i := range r.prompts { + if r.prompts[i].Prompt.Name == name { + return []ServerPrompt{r.prompts[i]} } } return []ServerPrompt{} } -// WithDeprecatedToolAliases returns a new ToolsetGroup with the given deprecated aliases added. +// WithDeprecatedToolAliases returns a new Registry with the given deprecated aliases added. // Aliases map old tool names to new canonical names. -func (tg *ToolsetGroup) WithDeprecatedToolAliases(aliases map[string]string) *ToolsetGroup { - newTG := tg.copy() +func (r *Registry) WithDeprecatedToolAliases(aliases map[string]string) *Registry { + newTG := r.copy() // Ensure we have a fresh map - newTG.deprecatedAliases = make(map[string]string, len(tg.deprecatedAliases)+len(aliases)) - for k, v := range tg.deprecatedAliases { + newTG.deprecatedAliases = make(map[string]string, len(r.deprecatedAliases)+len(aliases)) + for k, v := range r.deprecatedAliases { newTG.deprecatedAliases[k] = v } for oldName, newName := range aliases { @@ -432,21 +472,21 @@ func (tg *ToolsetGroup) WithDeprecatedToolAliases(aliases map[string]string) *To } // isToolsetEnabled checks if a toolset is enabled based on current filters. -func (tg *ToolsetGroup) isToolsetEnabled(toolsetID ToolsetID) bool { +func (r *Registry) isToolsetEnabled(toolsetID ToolsetID) bool { // Check enabled toolsets filter - if tg.enabledToolsets != nil { - return tg.enabledToolsets[toolsetID] + if r.enabledToolsets != nil { + return r.enabledToolsets[toolsetID] } return true } // checkFeatureFlag checks a feature flag using the feature checker. // Returns false if checker is nil or returns an error (errors are logged). -func (tg *ToolsetGroup) checkFeatureFlag(ctx context.Context, flagName string) bool { - if tg.featureChecker == nil || flagName == "" { +func (r *Registry) checkFeatureFlag(ctx context.Context, flagName string) bool { + if r.featureChecker == nil || flagName == "" { return false } - enabled, err := tg.featureChecker(ctx, flagName) + enabled, err := r.featureChecker(ctx, flagName) if err != nil { fmt.Fprintf(os.Stderr, "Feature flag check error for %q: %v\n", flagName, err) return false @@ -457,34 +497,34 @@ func (tg *ToolsetGroup) checkFeatureFlag(ctx context.Context, flagName string) b // isFeatureFlagAllowed checks if an item passes feature flag filtering. // - If FeatureFlagEnable is set, the item is only allowed if the flag is enabled // - If FeatureFlagDisable is set, the item is excluded if the flag is enabled -func (tg *ToolsetGroup) isFeatureFlagAllowed(ctx context.Context, enableFlag, disableFlag string) bool { +func (r *Registry) isFeatureFlagAllowed(ctx context.Context, enableFlag, disableFlag string) bool { // Check enable flag - item requires this flag to be on - if enableFlag != "" && !tg.checkFeatureFlag(ctx, enableFlag) { + if enableFlag != "" && !r.checkFeatureFlag(ctx, enableFlag) { return false } // Check disable flag - item is excluded if this flag is on - if disableFlag != "" && tg.checkFeatureFlag(ctx, disableFlag) { + if disableFlag != "" && r.checkFeatureFlag(ctx, disableFlag) { return false } return true } // isToolEnabled checks if a specific tool is enabled based on current filters. -func (tg *ToolsetGroup) isToolEnabled(ctx context.Context, tool *ServerTool) bool { +func (r *Registry) isToolEnabled(ctx context.Context, tool *ServerTool) bool { // Check read-only filter first (applies to all tools) - if tg.readOnly && !tool.IsReadOnly() { + if r.readOnly && !tool.IsReadOnly() { return false } // Check feature flags - if !tg.isFeatureFlagAllowed(ctx, tool.FeatureFlagEnable, tool.FeatureFlagDisable) { + if !r.isFeatureFlagAllowed(ctx, tool.FeatureFlagEnable, tool.FeatureFlagDisable) { return false } // Check if tool is in additionalTools (bypasses toolset filter) - if tg.additionalTools != nil && tg.additionalTools[tool.Tool.Name] { + if r.additionalTools != nil && r.additionalTools[tool.Tool.Name] { return true } // Check toolset filter - if !tg.isToolsetEnabled(tool.Toolset.ID) { + if !r.isToolsetEnabled(tool.Toolset.ID) { return false } return true @@ -493,11 +533,11 @@ func (tg *ToolsetGroup) isToolEnabled(ctx context.Context, tool *ServerTool) boo // AvailableTools returns the tools that pass all current filters, // sorted deterministically by toolset ID, then tool name. // The context is used for feature flag evaluation. -func (tg *ToolsetGroup) AvailableTools(ctx context.Context) []ServerTool { +func (r *Registry) AvailableTools(ctx context.Context) []ServerTool { var result []ServerTool - for i := range tg.tools { - tool := &tg.tools[i] - if tg.isToolEnabled(ctx, tool) { + for i := range r.tools { + tool := &r.tools[i] + if r.isToolEnabled(ctx, tool) { result = append(result, *tool) } } @@ -516,15 +556,15 @@ func (tg *ToolsetGroup) AvailableTools(ctx context.Context) []ServerTool { // AvailableResourceTemplates returns resource templates that pass all current filters, // sorted deterministically by toolset ID, then template name. // The context is used for feature flag evaluation. -func (tg *ToolsetGroup) AvailableResourceTemplates(ctx context.Context) []ServerResourceTemplate { +func (r *Registry) AvailableResourceTemplates(ctx context.Context) []ServerResourceTemplate { var result []ServerResourceTemplate - for i := range tg.resourceTemplates { - res := &tg.resourceTemplates[i] + for i := range r.resourceTemplates { + res := &r.resourceTemplates[i] // Check feature flags - if !tg.isFeatureFlagAllowed(ctx, res.FeatureFlagEnable, res.FeatureFlagDisable) { + if !r.isFeatureFlagAllowed(ctx, res.FeatureFlagEnable, res.FeatureFlagDisable) { continue } - if tg.isToolsetEnabled(res.Toolset.ID) { + if r.isToolsetEnabled(res.Toolset.ID) { result = append(result, *res) } } @@ -543,15 +583,15 @@ func (tg *ToolsetGroup) AvailableResourceTemplates(ctx context.Context) []Server // AvailablePrompts returns prompts that pass all current filters, // sorted deterministically by toolset ID, then prompt name. // The context is used for feature flag evaluation. -func (tg *ToolsetGroup) AvailablePrompts(ctx context.Context) []ServerPrompt { +func (r *Registry) AvailablePrompts(ctx context.Context) []ServerPrompt { var result []ServerPrompt - for i := range tg.prompts { - prompt := &tg.prompts[i] + for i := range r.prompts { + prompt := &r.prompts[i] // Check feature flags - if !tg.isFeatureFlagAllowed(ctx, prompt.FeatureFlagEnable, prompt.FeatureFlagDisable) { + if !r.isFeatureFlagAllowed(ctx, prompt.FeatureFlagEnable, prompt.FeatureFlagDisable) { continue } - if tg.isToolsetEnabled(prompt.Toolset.ID) { + if r.isToolsetEnabled(prompt.Toolset.ID) { result = append(result, *prompt) } } @@ -568,16 +608,44 @@ func (tg *ToolsetGroup) AvailablePrompts(ctx context.Context) []ServerPrompt { } // ToolsetIDs returns a sorted list of unique toolset IDs from all tools in this group. -func (tg *ToolsetGroup) ToolsetIDs() []ToolsetID { +func (r *Registry) ToolsetIDs() []ToolsetID { + seen := make(map[ToolsetID]bool) + for i := range r.tools { + seen[r.tools[i].Toolset.ID] = true + } + for i := range r.resourceTemplates { + seen[r.resourceTemplates[i].Toolset.ID] = true + } + for i := range r.prompts { + seen[r.prompts[i].Toolset.ID] = true + } + + ids := make([]ToolsetID, 0, len(seen)) + for id := range seen { + ids = append(ids, id) + } + sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) + return ids +} + +// DefaultToolsetIDs returns the IDs of toolsets marked as Default in their metadata. +// The IDs are returned in sorted order for deterministic output. +func (r *Registry) DefaultToolsetIDs() []ToolsetID { seen := make(map[ToolsetID]bool) - for i := range tg.tools { - seen[tg.tools[i].Toolset.ID] = true + for i := range r.tools { + if r.tools[i].Toolset.Default { + seen[r.tools[i].Toolset.ID] = true + } } - for i := range tg.resourceTemplates { - seen[tg.resourceTemplates[i].Toolset.ID] = true + for i := range r.resourceTemplates { + if r.resourceTemplates[i].Toolset.Default { + seen[r.resourceTemplates[i].Toolset.ID] = true + } } - for i := range tg.prompts { - seen[tg.prompts[i].Toolset.ID] = true + for i := range r.prompts { + if r.prompts[i].Toolset.Default { + seen[r.prompts[i].Toolset.ID] = true + } } ids := make([]ToolsetID, 0, len(seen)) @@ -589,22 +657,22 @@ func (tg *ToolsetGroup) ToolsetIDs() []ToolsetID { } // ToolsetDescriptions returns a map of toolset ID to description for all toolsets. -func (tg *ToolsetGroup) ToolsetDescriptions() map[ToolsetID]string { +func (r *Registry) ToolsetDescriptions() map[ToolsetID]string { descriptions := make(map[ToolsetID]string) - for i := range tg.tools { - t := &tg.tools[i] + for i := range r.tools { + t := &r.tools[i] if t.Toolset.Description != "" { descriptions[t.Toolset.ID] = t.Toolset.Description } } - for i := range tg.resourceTemplates { - r := &tg.resourceTemplates[i] + for i := range r.resourceTemplates { + r := &r.resourceTemplates[i] if r.Toolset.Description != "" { descriptions[r.Toolset.ID] = r.Toolset.Description } } - for i := range tg.prompts { - p := &tg.prompts[i] + for i := range r.prompts { + p := &r.prompts[i] if p.Toolset.Description != "" { descriptions[p.Toolset.ID] = p.Toolset.Description } @@ -615,13 +683,13 @@ func (tg *ToolsetGroup) ToolsetDescriptions() map[ToolsetID]string { // ToolsForToolset returns all tools belonging to a specific toolset. // This method bypasses the toolset enabled filter (for dynamic toolset registration), // but still respects the read-only filter. -func (tg *ToolsetGroup) ToolsForToolset(toolsetID ToolsetID) []ServerTool { +func (r *Registry) ToolsForToolset(toolsetID ToolsetID) []ServerTool { var result []ServerTool - for i := range tg.tools { - tool := &tg.tools[i] + for i := range r.tools { + tool := &r.tools[i] // Only check read-only filter, not toolset enabled filter if tool.Toolset.ID == toolsetID { - if tg.readOnly && !tool.IsReadOnly() { + if r.readOnly && !tool.IsReadOnly() { continue } result = append(result, *tool) @@ -638,34 +706,34 @@ func (tg *ToolsetGroup) ToolsForToolset(toolsetID ToolsetID) []ServerTool { // RegisterTools registers all available tools with the server using the provided dependencies. // The context is used for feature flag evaluation. -func (tg *ToolsetGroup) RegisterTools(ctx context.Context, s *mcp.Server, deps any) { - for _, tool := range tg.AvailableTools(ctx) { +func (r *Registry) RegisterTools(ctx context.Context, s *mcp.Server, deps any) { + for _, tool := range r.AvailableTools(ctx) { tool.RegisterFunc(s, deps) } } // RegisterResourceTemplates registers all available resource templates with the server. // The context is used for feature flag evaluation. -func (tg *ToolsetGroup) RegisterResourceTemplates(ctx context.Context, s *mcp.Server, deps any) { - for _, res := range tg.AvailableResourceTemplates(ctx) { +func (r *Registry) RegisterResourceTemplates(ctx context.Context, s *mcp.Server, deps any) { + for _, res := range r.AvailableResourceTemplates(ctx) { s.AddResourceTemplate(&res.Template, res.Handler(deps)) } } // RegisterPrompts registers all available prompts with the server. // The context is used for feature flag evaluation. -func (tg *ToolsetGroup) RegisterPrompts(ctx context.Context, s *mcp.Server) { - for _, prompt := range tg.AvailablePrompts(ctx) { +func (r *Registry) RegisterPrompts(ctx context.Context, s *mcp.Server) { + for _, prompt := range r.AvailablePrompts(ctx) { s.AddPrompt(&prompt.Prompt, prompt.Handler) } } // RegisterAll registers all available tools, resources, and prompts with the server. // The context is used for feature flag evaluation. -func (tg *ToolsetGroup) RegisterAll(ctx context.Context, s *mcp.Server, deps any) { - tg.RegisterTools(ctx, s, deps) - tg.RegisterResourceTemplates(ctx, s, deps) - tg.RegisterPrompts(ctx, s) +func (r *Registry) RegisterAll(ctx context.Context, s *mcp.Server, deps any) { + r.RegisterTools(ctx, s, deps) + r.RegisterResourceTemplates(ctx, s, deps) + r.RegisterPrompts(ctx, s) } // ResolveToolAliases resolves deprecated tool aliases to their canonical names. @@ -673,11 +741,11 @@ func (tg *ToolsetGroup) RegisterAll(ctx context.Context, s *mcp.Server, deps any // Returns: // - resolved: tool names with aliases replaced by canonical names // - aliasesUsed: map of oldName → newName for each alias that was resolved -func (tg *ToolsetGroup) ResolveToolAliases(toolNames []string) (resolved []string, aliasesUsed map[string]string) { +func (r *Registry) ResolveToolAliases(toolNames []string) (resolved []string, aliasesUsed map[string]string) { resolved = make([]string, 0, len(toolNames)) aliasesUsed = make(map[string]string) for _, toolName := range toolNames { - if canonicalName, isAlias := tg.deprecatedAliases[toolName]; isAlias { + if canonicalName, isAlias := r.deprecatedAliases[toolName]; isAlias { fmt.Fprintf(os.Stderr, "Warning: tool %q is deprecated, use %q instead\n", toolName, canonicalName) aliasesUsed[toolName] = canonicalName resolved = append(resolved, canonicalName) @@ -691,9 +759,9 @@ func (tg *ToolsetGroup) ResolveToolAliases(toolNames []string) (resolved []strin // FindToolByName searches all tools for one matching the given name. // Returns the tool, its toolset ID, and an error if not found. // This searches ALL tools regardless of filters. -func (tg *ToolsetGroup) FindToolByName(toolName string) (*ServerTool, ToolsetID, error) { - for i := range tg.tools { - tool := &tg.tools[i] +func (r *Registry) FindToolByName(toolName string) (*ServerTool, ToolsetID, error) { + for i := range r.tools { + tool := &r.tools[i] if tool.Tool.Name == toolName { return tool, tool.Toolset.ID, nil } @@ -702,19 +770,19 @@ func (tg *ToolsetGroup) FindToolByName(toolName string) (*ServerTool, ToolsetID, } // HasToolset checks if any tool/resource/prompt belongs to the given toolset. -func (tg *ToolsetGroup) HasToolset(toolsetID ToolsetID) bool { - for i := range tg.tools { - if tg.tools[i].Toolset.ID == toolsetID { +func (r *Registry) HasToolset(toolsetID ToolsetID) bool { + for i := range r.tools { + if r.tools[i].Toolset.ID == toolsetID { return true } } - for i := range tg.resourceTemplates { - if tg.resourceTemplates[i].Toolset.ID == toolsetID { + for i := range r.resourceTemplates { + if r.resourceTemplates[i].Toolset.ID == toolsetID { return true } } - for i := range tg.prompts { - if tg.prompts[i].Toolset.ID == toolsetID { + for i := range r.prompts { + if r.prompts[i].Toolset.ID == toolsetID { return true } } @@ -723,14 +791,14 @@ func (tg *ToolsetGroup) HasToolset(toolsetID ToolsetID) bool { // EnabledToolsetIDs returns the list of enabled toolset IDs based on current filters. // Returns all toolset IDs if no filter is set. -func (tg *ToolsetGroup) EnabledToolsetIDs() []ToolsetID { - if tg.enabledToolsets == nil { - return tg.ToolsetIDs() +func (r *Registry) EnabledToolsetIDs() []ToolsetID { + if r.enabledToolsets == nil { + return r.ToolsetIDs() } - ids := make([]ToolsetID, 0, len(tg.enabledToolsets)) - for id := range tg.enabledToolsets { - if tg.HasToolset(id) { + ids := make([]ToolsetID, 0, len(r.enabledToolsets)) + for id := range r.enabledToolsets { + if r.HasToolset(id) { ids = append(ids, id) } } @@ -739,23 +807,23 @@ func (tg *ToolsetGroup) EnabledToolsetIDs() []ToolsetID { } // IsToolsetEnabled checks if a toolset is currently enabled based on filters. -func (tg *ToolsetGroup) IsToolsetEnabled(toolsetID ToolsetID) bool { - return tg.isToolsetEnabled(toolsetID) +func (r *Registry) IsToolsetEnabled(toolsetID ToolsetID) bool { + return r.isToolsetEnabled(toolsetID) } // EnableToolset marks a toolset as enabled in this group. // This is used by dynamic toolset management to track which toolsets have been enabled. -func (tg *ToolsetGroup) EnableToolset(toolsetID ToolsetID) { - if tg.enabledToolsets == nil { +func (r *Registry) EnableToolset(toolsetID ToolsetID) { + if r.enabledToolsets == nil { // nil means all enabled, so nothing to do return } - tg.enabledToolsets[toolsetID] = true + r.enabledToolsets[toolsetID] = true } // AllTools returns all tools without any filtering, sorted deterministically. -func (tg *ToolsetGroup) AllTools() []ServerTool { - result := slices.Clone(tg.tools) +func (r *Registry) AllTools() []ServerTool { + result := slices.Clone(r.tools) // Sort deterministically: by toolset ID, then by tool name sort.Slice(result, func(i, j int) bool { @@ -772,8 +840,8 @@ func (tg *ToolsetGroup) AllTools() []ServerTool { // This is the ordered intersection of toolsets with reality - only toolsets that // actually contain tools are returned, sorted by toolset ID. // Optional exclude parameter filters out specific toolset IDs from the result. -func (tg *ToolsetGroup) AvailableToolsets(exclude ...ToolsetID) []ToolsetMetadata { - tools := tg.AllTools() +func (r *Registry) AvailableToolsets(exclude ...ToolsetID) []ToolsetMetadata { + tools := r.AllTools() if len(tools) == 0 { return nil } diff --git a/pkg/toolsets/toolsets_test.go b/pkg/toolsets/toolsets_test.go index 317dafbf6..0d1c35e2e 100644 --- a/pkg/toolsets/toolsets_test.go +++ b/pkg/toolsets/toolsets_test.go @@ -17,6 +17,34 @@ func testToolsetMetadata(id string) ToolsetMetadata { } } +// testToolsetMetadataWithDefault returns a ToolsetMetadata with Default flag for testing +func testToolsetMetadataWithDefault(id string, isDefault bool) ToolsetMetadata { + return ToolsetMetadata{ + ID: ToolsetID(id), + Description: "Test toolset: " + id, + Default: isDefault, + } +} + +// mockToolWithDefault creates a mock tool with a default toolset flag +func mockToolWithDefault(name string, toolsetID string, readOnly bool, isDefault bool) ServerTool { + return NewServerToolFromHandler( + mcp.Tool{ + Name: name, + Annotations: &mcp.ToolAnnotations{ + ReadOnlyHint: readOnly, + }, + InputSchema: json.RawMessage(`{"type":"object","properties":{}}`), + }, + testToolsetMetadataWithDefault(toolsetID, isDefault), + func(_ any) mcp.ToolHandler { + return func(_ context.Context, _ *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return nil, nil + } + }, + ) +} + // mockTool creates a minimal ServerTool for testing func mockTool(name string, toolsetID string, readOnly bool) ServerTool { return NewServerToolFromHandler( @@ -36,8 +64,8 @@ func mockTool(name string, toolsetID string, readOnly bool) ServerTool { ) } -func TestNewToolsetGroupEmpty(t *testing.T) { - tsg := NewToolsetGroup(nil, nil, nil) +func TestNewRegistryEmpty(t *testing.T) { + tsg := NewRegistry() if len(tsg.tools) != 0 { t.Fatalf("Expected tools to be empty, got %d items", len(tsg.tools)) } @@ -49,14 +77,14 @@ func TestNewToolsetGroupEmpty(t *testing.T) { } } -func TestNewToolsetGroupWithTools(t *testing.T) { +func TestNewRegistryWithTools(t *testing.T) { tools := []ServerTool{ mockTool("tool1", "toolset1", true), mockTool("tool2", "toolset1", false), mockTool("tool3", "toolset2", true), } - tsg := NewToolsetGroup(tools, nil, nil) + tsg := NewRegistry().SetTools(tools) if len(tsg.tools) != 3 { t.Errorf("Expected 3 tools, got %d", len(tsg.tools)) @@ -70,7 +98,7 @@ func TestAvailableTools_NoFilters(t *testing.T) { mockTool("tool_c", "toolset2", true), } - tsg := NewToolsetGroup(tools, nil, nil) + tsg := NewRegistry().SetTools(tools) available := tsg.AvailableTools(context.Background()) if len(available) != 3 { @@ -92,7 +120,7 @@ func TestWithReadOnly(t *testing.T) { mockTool("write_tool", "toolset1", false), } - tsg := NewToolsetGroup(tools, nil, nil) + tsg := NewRegistry().SetTools(tools) // Original should have both tools allTools := tsg.AvailableTools(context.Background()) @@ -124,11 +152,11 @@ func TestWithToolsets(t *testing.T) { mockTool("tool3", "toolset3", true), } - tsg := NewToolsetGroup(tools, nil, nil) + tsg := NewRegistry().SetTools(tools) // Filter to specific toolsets - filteredTsg := tsg.WithToolsets([]string{"toolset1", "toolset3"}) - filteredTools := filteredTsg.AvailableTools(context.Background()) + filteredReg := tsg.WithToolsets([]string{"toolset1", "toolset3"}) + filteredTools := filteredReg.AvailableTools(context.Background()) if len(filteredTools) != 2 { t.Fatalf("Expected 2 filtered tools, got %d", len(filteredTools)) @@ -150,6 +178,114 @@ func TestWithToolsets(t *testing.T) { } } +func TestWithToolsetsTrimsWhitespace(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset2", true), + } + + tsg := NewRegistry().SetTools(tools) + + // Whitespace should be trimmed + filteredReg := tsg.WithToolsets([]string{" toolset1 ", " toolset2 "}) + filteredTools := filteredReg.AvailableTools(context.Background()) + + if len(filteredTools) != 2 { + t.Fatalf("Expected 2 tools after whitespace trimming, got %d", len(filteredTools)) + } +} + +func TestWithToolsetsDeduplicates(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + } + + tsg := NewRegistry().SetTools(tools) + + // Duplicates should be removed + filteredReg := tsg.WithToolsets([]string{"toolset1", "toolset1", " toolset1 "}) + filteredTools := filteredReg.AvailableTools(context.Background()) + + if len(filteredTools) != 1 { + t.Fatalf("Expected 1 tool after deduplication, got %d", len(filteredTools)) + } +} + +func TestWithToolsetsIgnoresEmptyStrings(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + } + + tsg := NewRegistry().SetTools(tools) + + // Empty strings should be ignored + filteredReg := tsg.WithToolsets([]string{"", "toolset1", " ", ""}) + filteredTools := filteredReg.AvailableTools(context.Background()) + + if len(filteredTools) != 1 { + t.Fatalf("Expected 1 tool, got %d", len(filteredTools)) + } +} + +func TestUnrecognizedToolsets(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset2", true), + } + + tsg := NewRegistry().SetTools(tools) + + tests := []struct { + name string + input []string + expectedUnrecognized []string + }{ + { + name: "all valid", + input: []string{"toolset1", "toolset2"}, + expectedUnrecognized: nil, + }, + { + name: "one invalid", + input: []string{"toolset1", "invalid_toolset"}, + expectedUnrecognized: []string{"invalid_toolset"}, + }, + { + name: "multiple invalid", + input: []string{"typo1", "toolset1", "typo2"}, + expectedUnrecognized: []string{"typo1", "typo2"}, + }, + { + name: "invalid with whitespace trimmed", + input: []string{" invalid_tool "}, + expectedUnrecognized: []string{"invalid_tool"}, + }, + { + name: "empty input", + input: []string{}, + expectedUnrecognized: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + filtered := tsg.WithToolsets(tt.input) + unrecognized := filtered.UnrecognizedToolsets() + + if len(unrecognized) != len(tt.expectedUnrecognized) { + t.Fatalf("Expected %d unrecognized, got %d: %v", + len(tt.expectedUnrecognized), len(unrecognized), unrecognized) + } + + for i, expected := range tt.expectedUnrecognized { + if unrecognized[i] != expected { + t.Errorf("Expected unrecognized[%d] = %q, got %q", i, expected, unrecognized[i]) + } + } + }) + } +} + func TestWithTools(t *testing.T) { tools := []ServerTool{ mockTool("tool1", "toolset1", true), @@ -157,12 +293,12 @@ func TestWithTools(t *testing.T) { mockTool("tool3", "toolset2", true), } - tsg := NewToolsetGroup(tools, nil, nil) + tsg := NewRegistry().SetTools(tools) // WithTools adds additional tools that bypass toolset filtering // When combined with WithToolsets([]), only the additional tools should be available - filteredTsg := tsg.WithToolsets([]string{}).WithTools([]string{"tool1", "tool3"}) - filteredTools := filteredTsg.AvailableTools(context.Background()) + filteredReg := tsg.WithToolsets([]string{}).WithTools([]string{"tool1", "tool3"}) + filteredTools := filteredReg.AvailableTools(context.Background()) if len(filteredTools) != 2 { t.Fatalf("Expected 2 filtered tools, got %d", len(filteredTools)) @@ -185,7 +321,7 @@ func TestChainedFilters(t *testing.T) { mockTool("write2", "toolset2", false), } - tsg := NewToolsetGroup(tools, nil, nil) + tsg := NewRegistry().SetTools(tools) // Chain read-only and toolset filter filtered := tsg.WithReadOnly(true).WithToolsets([]string{"toolset1"}) @@ -206,7 +342,7 @@ func TestToolsetIDs(t *testing.T) { mockTool("tool3", "toolset_b", true), // duplicate toolset } - tsg := NewToolsetGroup(tools, nil, nil) + tsg := NewRegistry().SetTools(tools) ids := tsg.ToolsetIDs() if len(ids) != 2 { @@ -225,7 +361,7 @@ func TestToolsetDescriptions(t *testing.T) { mockTool("tool2", "toolset2", true), } - tsg := NewToolsetGroup(tools, nil, nil) + tsg := NewRegistry().SetTools(tools) descriptions := tsg.ToolsetDescriptions() if len(descriptions) != 2 { @@ -244,7 +380,7 @@ func TestToolsForToolset(t *testing.T) { mockTool("tool3", "toolset2", true), } - tsg := NewToolsetGroup(tools, nil, nil) + tsg := NewRegistry().SetTools(tools) toolset1Tools := tsg.ToolsForToolset("toolset1") if len(toolset1Tools) != 2 { @@ -257,7 +393,7 @@ func TestWithDeprecatedToolAliases(t *testing.T) { mockTool("new_name", "toolset1", true), } - tsg := NewToolsetGroup(tools, nil, nil) + tsg := NewRegistry().SetTools(tools) tsgWithAliases := tsg.WithDeprecatedToolAliases(map[string]string{ "old_name": "new_name", "get_issue": "issue_read", @@ -282,7 +418,7 @@ func TestResolveToolAliases(t *testing.T) { mockTool("some_tool", "toolset1", true), } - tsg := NewToolsetGroup(tools, nil, nil). + tsg := NewRegistry().SetTools(tools). WithDeprecatedToolAliases(map[string]string{ "get_issue": "issue_read", }) @@ -314,7 +450,7 @@ func TestFindToolByName(t *testing.T) { mockTool("issue_read", "toolset1", true), } - tsg := NewToolsetGroup(tools, nil, nil) + tsg := NewRegistry().SetTools(tools) // Find by name tool, toolsetID, err := tsg.FindToolByName("issue_read") @@ -342,7 +478,7 @@ func TestWithToolsAdditive(t *testing.T) { mockTool("repo_read", "toolset2", true), } - tsg := NewToolsetGroup(tools, nil, nil) + tsg := NewRegistry().SetTools(tools) // Test WithTools bypasses toolset filtering // Enable only toolset2, but add issue_read as additional tool @@ -389,7 +525,7 @@ func TestWithToolsResolvesAliases(t *testing.T) { mockTool("issue_read", "toolset1", true), } - tsg := NewToolsetGroup(tools, nil, nil). + tsg := NewRegistry().SetTools(tools). WithDeprecatedToolAliases(map[string]string{ "get_issue": "issue_read", }) @@ -411,7 +547,7 @@ func TestHasToolset(t *testing.T) { mockTool("tool1", "toolset1", true), } - tsg := NewToolsetGroup(tools, nil, nil) + tsg := NewRegistry().SetTools(tools) if !tsg.HasToolset("toolset1") { t.Error("expected HasToolset to return true for existing toolset") @@ -427,7 +563,7 @@ func TestEnabledToolsetIDs(t *testing.T) { mockTool("tool2", "toolset2", true), } - tsg := NewToolsetGroup(tools, nil, nil) + tsg := NewRegistry().SetTools(tools) // Without filter, all toolsets are enabled ids := tsg.EnabledToolsetIDs() @@ -452,7 +588,7 @@ func TestAllTools(t *testing.T) { mockTool("write_tool", "toolset1", false), } - tsg := NewToolsetGroup(tools, nil, nil) + tsg := NewRegistry().SetTools(tools) // Even with read-only filter, AllTools returns everything readOnlyTsg := tsg.WithReadOnly(true) @@ -520,7 +656,7 @@ func TestForMCPRequest_Initialize(t *testing.T) { mockPrompt("prompt1", "repos"), } - tsg := NewToolsetGroup(tools, resources, prompts) + tsg := NewRegistry().SetTools(tools).SetResources(resources).SetPrompts(prompts) filtered := tsg.ForMCPRequest(MCPMethodInitialize, "") // Initialize should return empty - capabilities come from ServerOptions @@ -547,7 +683,7 @@ func TestForMCPRequest_ToolsList(t *testing.T) { mockPrompt("prompt1", "repos"), } - tsg := NewToolsetGroup(tools, resources, prompts) + tsg := NewRegistry().SetTools(tools).SetResources(resources).SetPrompts(prompts) filtered := tsg.ForMCPRequest(MCPMethodToolsList, "") // tools/list should return all tools, no resources or prompts @@ -569,7 +705,7 @@ func TestForMCPRequest_ToolsCall(t *testing.T) { mockTool("list_repos", "repos", true), } - tsg := NewToolsetGroup(tools, nil, nil) + tsg := NewRegistry().SetTools(tools) filtered := tsg.ForMCPRequest(MCPMethodToolsCall, "get_me") available := filtered.AvailableTools(context.Background()) @@ -586,7 +722,7 @@ func TestForMCPRequest_ToolsCall_NotFound(t *testing.T) { mockTool("get_me", "context", true), } - tsg := NewToolsetGroup(tools, nil, nil) + tsg := NewRegistry().SetTools(tools) filtered := tsg.ForMCPRequest(MCPMethodToolsCall, "nonexistent") if len(filtered.AvailableTools(context.Background())) != 0 { @@ -600,7 +736,7 @@ func TestForMCPRequest_ToolsCall_DeprecatedAlias(t *testing.T) { mockTool("list_commits", "repos", true), } - tsg := NewToolsetGroup(tools, nil, nil). + tsg := NewRegistry().SetTools(tools). WithDeprecatedToolAliases(map[string]string{ "old_get_me": "get_me", }) @@ -622,7 +758,7 @@ func TestForMCPRequest_ToolsCall_RespectsFilters(t *testing.T) { mockTool("create_issue", "issues", false), // write tool } - tsg := NewToolsetGroup(tools, nil, nil) + tsg := NewRegistry().SetTools(tools) // Apply read-only filter, then ForMCPRequest filtered := tsg.WithReadOnly(true).ForMCPRequest(MCPMethodToolsCall, "create_issue") @@ -645,7 +781,7 @@ func TestForMCPRequest_ResourcesList(t *testing.T) { mockPrompt("prompt1", "repos"), } - tsg := NewToolsetGroup(tools, resources, prompts) + tsg := NewRegistry().SetTools(tools).SetResources(resources).SetPrompts(prompts) filtered := tsg.ForMCPRequest(MCPMethodResourcesList, "") if len(filtered.AvailableTools(context.Background())) != 0 { @@ -665,7 +801,7 @@ func TestForMCPRequest_ResourcesRead(t *testing.T) { mockResource("res2", "repos", "branch://{owner}/{repo}/{branch}"), } - tsg := NewToolsetGroup(nil, resources, nil) + tsg := NewRegistry().SetResources(resources) filtered := tsg.ForMCPRequest(MCPMethodResourcesRead, "repo://{owner}/{repo}") available := filtered.AvailableResourceTemplates(context.Background()) @@ -689,7 +825,7 @@ func TestForMCPRequest_PromptsList(t *testing.T) { mockPrompt("prompt2", "issues"), } - tsg := NewToolsetGroup(tools, resources, prompts) + tsg := NewRegistry().SetTools(tools).SetResources(resources).SetPrompts(prompts) filtered := tsg.ForMCPRequest(MCPMethodPromptsList, "") if len(filtered.AvailableTools(context.Background())) != 0 { @@ -709,7 +845,7 @@ func TestForMCPRequest_PromptsGet(t *testing.T) { mockPrompt("prompt2", "issues"), } - tsg := NewToolsetGroup(nil, nil, prompts) + tsg := NewRegistry().SetPrompts(prompts) filtered := tsg.ForMCPRequest(MCPMethodPromptsGet, "prompt1") available := filtered.AvailablePrompts(context.Background()) @@ -732,7 +868,7 @@ func TestForMCPRequest_UnknownMethod(t *testing.T) { mockPrompt("prompt1", "repos"), } - tsg := NewToolsetGroup(tools, resources, prompts) + tsg := NewRegistry().SetTools(tools).SetResources(resources).SetPrompts(prompts) filtered := tsg.ForMCPRequest("unknown/method", "") // Unknown methods should return empty @@ -759,7 +895,7 @@ func TestForMCPRequest_Immutability(t *testing.T) { mockPrompt("prompt1", "repos"), } - original := NewToolsetGroup(tools, resources, prompts) + original := NewRegistry().SetTools(tools).SetResources(resources).SetPrompts(prompts) filtered := original.ForMCPRequest(MCPMethodToolsCall, "tool1") // Original should be unchanged @@ -787,14 +923,13 @@ func TestForMCPRequest_Immutability(t *testing.T) { func TestForMCPRequest_ChainedWithOtherFilters(t *testing.T) { tools := []ServerTool{ - mockTool("get_me", "context", true), - mockTool("create_issue", "issues", false), - mockTool("list_repos", "repos", true), - mockTool("delete_repo", "repos", false), + mockToolWithDefault("get_me", "context", true, true), // default toolset + mockToolWithDefault("create_issue", "issues", false, false), // not default + mockToolWithDefault("list_repos", "repos", true, true), // default toolset + mockToolWithDefault("delete_repo", "repos", false, true), // default but write } - tsg := NewToolsetGroup(tools, nil, nil) - tsg.SetDefaultToolsetIDs([]ToolsetID{"context", "repos"}) + tsg := NewRegistry().SetTools(tools) // Chain: default toolsets -> read-only -> specific method filtered := tsg. @@ -837,7 +972,7 @@ func TestForMCPRequest_ResourcesTemplatesList(t *testing.T) { mockResource("res1", "repos", "repo://{owner}/{repo}"), } - tsg := NewToolsetGroup(tools, resources, nil) + tsg := NewRegistry().SetTools(tools).SetResources(resources) filtered := tsg.ForMCPRequest(MCPMethodResourcesTemplatesList, "") // Same behavior as resources/list @@ -886,7 +1021,7 @@ func TestFeatureFlagEnable(t *testing.T) { mockToolWithFlags("needs_flag", "toolset1", true, "my_feature", ""), } - tsg := NewToolsetGroup(tools, nil, nil) + tsg := NewRegistry().SetTools(tools) // Without feature checker, tool with FeatureFlagEnable should be excluded available := tsg.AvailableTools(context.Background()) @@ -922,7 +1057,7 @@ func TestFeatureFlagDisable(t *testing.T) { mockToolWithFlags("disabled_by_flag", "toolset1", true, "", "kill_switch"), } - tsg := NewToolsetGroup(tools, nil, nil) + tsg := NewRegistry().SetTools(tools) // Without feature checker, tool with FeatureFlagDisable should be included (flag is false) available := tsg.AvailableTools(context.Background()) @@ -950,7 +1085,7 @@ func TestFeatureFlagBoth(t *testing.T) { mockToolWithFlags("complex_tool", "toolset1", true, "new_feature", "kill_switch"), } - tsg := NewToolsetGroup(tools, nil, nil) + tsg := NewRegistry().SetTools(tools) // Enable flag not set -> excluded checker1 := func(_ context.Context, _ string) (bool, error) { return false, nil } @@ -976,7 +1111,7 @@ func TestFeatureFlagError(t *testing.T) { mockToolWithFlags("needs_flag", "toolset1", true, "my_feature", ""), } - tsg := NewToolsetGroup(tools, nil, nil) + tsg := NewRegistry().SetTools(tools) // Checker that returns error should treat as false (tool excluded) checkerError := func(_ context.Context, _ string) (bool, error) { @@ -999,7 +1134,7 @@ func TestFeatureFlagResources(t *testing.T) { }, } - tsg := NewToolsetGroup(nil, resources, nil) + tsg := NewRegistry().SetResources(resources) // Without checker, resource with enable flag should be excluded available := tsg.AvailableResourceTemplates(context.Background()) @@ -1025,7 +1160,7 @@ func TestFeatureFlagPrompts(t *testing.T) { }, } - tsg := NewToolsetGroup(nil, nil, prompts) + tsg := NewRegistry().SetPrompts(prompts) // Without checker, prompt with enable flag should be excluded available := tsg.AvailablePrompts(context.Background()) @@ -1072,3 +1207,61 @@ func TestServerToolHandlerPanicOnNil(t *testing.T) { tool.Handler(nil) } + +// TestRegistryCopyCopiesAllFields ensures the copy() method stays in sync with the struct. +// If you add a new field to Registry, this test will fail until you update copy(). +func TestRegistryCopyCopiesAllFields(t *testing.T) { + // Create a Registry with non-zero/non-nil values for ALL fields + original := &Registry{ + tools: []ServerTool{mockTool("t1", "ts1", true)}, + resourceTemplates: []ServerResourceTemplate{{Template: mcp.ResourceTemplate{Name: "r1"}}}, + prompts: []ServerPrompt{{Prompt: mcp.Prompt{Name: "p1"}}}, + deprecatedAliases: map[string]string{"old": "new"}, + readOnly: true, + enabledToolsets: map[ToolsetID]bool{"ts1": true}, + additionalTools: map[string]bool{"extra": true}, + featureChecker: func(_ context.Context, _ string) (bool, error) { return true, nil }, + unrecognizedToolsets: []string{"unknown"}, + } + + copied := original.copy() + + // Verify all fields are copied correctly + if len(copied.tools) != len(original.tools) || (len(copied.tools) > 0 && copied.tools[0].Tool.Name != original.tools[0].Tool.Name) { + t.Error("tools not copied correctly") + } + if len(copied.resourceTemplates) != len(original.resourceTemplates) { + t.Error("resourceTemplates not copied correctly") + } + if len(copied.prompts) != len(original.prompts) { + t.Error("prompts not copied correctly") + } + if len(copied.deprecatedAliases) != len(original.deprecatedAliases) || copied.deprecatedAliases["old"] != "new" { + t.Error("deprecatedAliases not copied correctly") + } + if copied.readOnly != original.readOnly { + t.Error("readOnly not copied correctly") + } + if len(copied.enabledToolsets) != len(original.enabledToolsets) || !copied.enabledToolsets["ts1"] { + t.Error("enabledToolsets not copied correctly") + } + if len(copied.additionalTools) != len(original.additionalTools) || !copied.additionalTools["extra"] { + t.Error("additionalTools not copied correctly") + } + if copied.featureChecker == nil { + t.Error("featureChecker not copied correctly") + } + if len(copied.unrecognizedToolsets) != len(original.unrecognizedToolsets) || copied.unrecognizedToolsets[0] != "unknown" { + t.Error("unrecognizedToolsets not copied correctly") + } + + // Verify maps are deep copied (mutations don't affect original) + copied.enabledToolsets["ts2"] = true + if original.enabledToolsets["ts2"] { + t.Error("enabledToolsets should be deep copied, not shared") + } + copied.additionalTools["another"] = true + if original.additionalTools["another"] { + t.Error("additionalTools should be deep copied, not shared") + } +} From 3c44fddacb653d4827d3a75aea21ae9a00600886 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Mon, 15 Dec 2025 15:32:21 +0100 Subject: [PATCH 11/27] refactor: rename toolsets package to registry with builder pattern - Rename pkg/toolsets to pkg/registry (better reflects its purpose) - Split monolithic toolsets.go into focused files: - registry.go: Core Registry struct and MCP methods - builder.go: Builder pattern for creating Registry instances - filters.go: All filtering logic (toolsets, read-only, feature flags) - resources.go: ServerResourceTemplate type - prompts.go: ServerPrompt type - errors.go: Error types - server_tool.go: ServerTool and ToolsetMetadata (existing) - Fix lint: Rename RegistryBuilder to Builder (avoid stuttering) - Update all imports across ~45 files This refactoring improves code organization and makes the registry's purpose clearer. The builder pattern provides a clean API: reg := registry.NewBuilder(). SetTools(tools). WithReadOnly(true). WithToolsets([]string{"repos"}). Build() --- cmd/github-mcp-server/generate_docs.go | 16 +- e2e/e2e_test.go | 2 +- internal/ghmcp/server.go | 197 ++-- pkg/github/actions.go | 34 +- pkg/github/actions_test.go | 48 +- pkg/github/code_scanning.go | 6 +- pkg/github/code_scanning_test.go | 8 +- pkg/github/context_tools.go | 8 +- pkg/github/context_tools_test.go | 279 +++--- pkg/github/dependabot.go | 6 +- pkg/github/dependabot_test.go | 4 +- pkg/github/dependencies.go | 112 ++- pkg/github/discussions.go | 10 +- pkg/github/discussions_test.go | 8 +- pkg/github/dynamic_tools.go | 28 +- pkg/github/gists.go | 10 +- pkg/github/gists_test.go | 16 +- pkg/github/git.go | 4 +- pkg/github/git_test.go | 4 +- pkg/github/issues.go | 28 +- pkg/github/issues_test.go | 68 +- pkg/github/labels.go | 8 +- pkg/github/labels_test.go | 12 +- pkg/github/notifications.go | 14 +- pkg/github/notifications_test.go | 24 +- pkg/github/projects.go | 20 +- pkg/github/projects_test.go | 36 +- pkg/github/prompts.go | 6 +- pkg/github/pullrequests.go | 30 +- pkg/github/pullrequests_test.go | 82 +- pkg/github/{toolset_group.go => registry.go} | 6 +- pkg/github/repositories.go | 38 +- pkg/github/repositories_test.go | 74 +- pkg/github/repository_resource.go | 32 +- pkg/github/repository_resource_test.go | 48 +- pkg/github/resources.go | 6 +- pkg/github/search.go | 10 +- pkg/github/search_test.go | 20 +- pkg/github/secret_scanning.go | 6 +- pkg/github/secret_scanning_test.go | 8 +- pkg/github/security_advisories.go | 10 +- pkg/github/security_advisories_test.go | 8 +- pkg/github/server_test.go | 59 +- pkg/github/tools.go | 54 +- pkg/github/tools_validation_test.go | 4 +- pkg/github/workflow_prompts.go | 6 +- pkg/registry/builder.go | 241 +++++ pkg/registry/errors.go | 41 + pkg/registry/filters.go | 246 +++++ pkg/registry/prompts.go | 26 + pkg/registry/registry.go | 343 +++++++ .../registry_test.go} | 363 +++----- pkg/registry/resources.go | 48 + pkg/{toolsets => registry}/server_tool.go | 2 +- pkg/toolsets/toolsets.go | 866 ------------------ 55 files changed, 1900 insertions(+), 1793 deletions(-) rename pkg/github/{toolset_group.go => registry.go} (78%) create mode 100644 pkg/registry/builder.go create mode 100644 pkg/registry/errors.go create mode 100644 pkg/registry/filters.go create mode 100644 pkg/registry/prompts.go create mode 100644 pkg/registry/registry.go rename pkg/{toolsets/toolsets_test.go => registry/registry_test.go} (75%) create mode 100644 pkg/registry/resources.go rename pkg/{toolsets => registry}/server_tool.go (99%) delete mode 100644 pkg/toolsets/toolsets.go diff --git a/cmd/github-mcp-server/generate_docs.go b/cmd/github-mcp-server/generate_docs.go index ddfcd10ba..8760c3f32 100644 --- a/cmd/github-mcp-server/generate_docs.go +++ b/cmd/github-mcp-server/generate_docs.go @@ -8,7 +8,7 @@ import ( "strings" "github.com/github/github-mcp-server/pkg/github" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -50,8 +50,8 @@ func generateReadmeDocs(readmePath string) error { // Create translation helper t, _ := translations.TranslationHelper() - // Create toolset group - stateless, no dependencies needed for doc generation - r := github.NewRegistry(t) + // Build registry - stateless, no dependencies needed for doc generation + r := github.NewRegistry(t).Build() // Generate toolsets documentation toolsetsDoc := generateToolsetsDoc(r) @@ -104,7 +104,7 @@ func generateRemoteServerDocs(docsPath string) error { return os.WriteFile(docsPath, []byte(updatedContent), 0600) //#nosec G306 } -func generateToolsetsDoc(r *toolsets.Registry) string { +func generateToolsetsDoc(r *registry.Registry) string { var buf strings.Builder // Add table header and separator @@ -123,7 +123,7 @@ func generateToolsetsDoc(r *toolsets.Registry) string { return strings.TrimSuffix(buf.String(), "\n") } -func generateToolsDoc(r *toolsets.Registry) string { +func generateToolsDoc(r *registry.Registry) string { // AllTools() returns tools sorted by toolset ID then tool name. // We iterate once, grouping by toolset as we encounter them. tools := r.AllTools() @@ -133,7 +133,7 @@ func generateToolsDoc(r *toolsets.Registry) string { var buf strings.Builder var toolBuf strings.Builder - var currentToolsetID toolsets.ToolsetID + var currentToolsetID registry.ToolsetID firstSection := true writeSection := func() { @@ -299,8 +299,8 @@ func generateRemoteToolsetsDoc() string { // Create translation helper t, _ := translations.TranslationHelper() - // Create toolset group - stateless - r := github.NewRegistry(t) + // Build registry - stateless + r := github.NewRegistry(t).Build() // Generate table header buf.WriteString("| Name | Description | API URL | 1-Click Install (VS Code) | Read-only Link | 1-Click Read-only Install (VS Code) |\n") diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index ad9ebb190..e286930a2 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -178,7 +178,7 @@ func setupMCPClient(t *testing.T, options ...clientOption) *mcp.ClientSession { // so that there is a shared setup mechanism, but let's wait till we feel more friction. enabledToolsets := opts.enabledToolsets if enabledToolsets == nil { - enabledToolsets = github.NewRegistry(translations.NullTranslationHelper).DefaultToolsetIDs() + enabledToolsets = github.NewRegistry(translations.NullTranslationHelper).Build().DefaultToolsetIDs() } ghServer, err := ghmcp.NewMCPServer(ghmcp.MCPServerConfig{ diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 67fcad4a7..1f924e482 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -18,7 +18,7 @@ import ( "github.com/github/github-mcp-server/pkg/lockdown" mcplog "github.com/github/github-mcp-server/pkg/log" "github.com/github/github-mcp-server/pkg/raw" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" gogithub "github.com/google/go-github/v79/github" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -69,146 +69,153 @@ type MCPServerConfig struct { RepoAccessTTL *time.Duration } -func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { - apiHost, err := parseAPIHost(cfg.Host) - if err != nil { - return nil, fmt.Errorf("failed to parse API host: %w", err) - } +// githubClients holds all the GitHub API clients created for a server instance. +type githubClients struct { + rest *gogithub.Client + gql *githubv4.Client + gqlHTTP *http.Client // retained for middleware to modify transport + raw *raw.Client + repoAccess *lockdown.RepoAccessCache +} - // Construct our REST client +// createGitHubClients creates all the GitHub API clients needed by the server. +func createGitHubClients(cfg MCPServerConfig, apiHost apiHost) (*githubClients, error) { + // Construct REST client restClient := gogithub.NewClient(nil).WithAuthToken(cfg.Token) restClient.UserAgent = fmt.Sprintf("github-mcp-server/%s", cfg.Version) restClient.BaseURL = apiHost.baseRESTURL restClient.UploadURL = apiHost.uploadURL - // Construct our GraphQL client - // We're using NewEnterpriseClient here unconditionally as opposed to NewClient because we already - // did the necessary API host parsing so that github.com will return the correct URL anyway. + // Construct GraphQL client + // We use NewEnterpriseClient unconditionally since we already parsed the API host gqlHTTPClient := &http.Client{ Transport: &bearerAuthTransport{ transport: http.DefaultTransport, token: cfg.Token, }, - } // We're going to wrap the Transport later in beforeInit - gqlClient := githubv4.NewEnterpriseClient(apiHost.graphqlURL.String(), gqlHTTPClient) - repoAccessOpts := []lockdown.RepoAccessOption{} - if cfg.RepoAccessTTL != nil { - repoAccessOpts = append(repoAccessOpts, lockdown.WithTTL(*cfg.RepoAccessTTL)) } + gqlClient := githubv4.NewEnterpriseClient(apiHost.graphqlURL.String(), gqlHTTPClient) - repoAccessLogger := cfg.Logger.With("component", "lockdown") - repoAccessOpts = append(repoAccessOpts, lockdown.WithLogger(repoAccessLogger)) + // Create raw content client (shares REST client's HTTP transport) + rawClient := raw.NewClient(restClient, apiHost.rawURL) + + // Set up repo access cache for lockdown mode var repoAccessCache *lockdown.RepoAccessCache if cfg.LockdownMode { - repoAccessCache = lockdown.GetInstance(gqlClient, repoAccessOpts...) + opts := []lockdown.RepoAccessOption{ + lockdown.WithLogger(cfg.Logger.With("component", "lockdown")), + } + if cfg.RepoAccessTTL != nil { + opts = append(opts, lockdown.WithTTL(*cfg.RepoAccessTTL)) + } + repoAccessCache = lockdown.GetInstance(gqlClient, opts...) } - // Determine enabled toolsets based on configuration: - // - nil means "use defaults" (unless dynamic mode without explicit toolsets) - // - empty slice means "no toolsets" (for dynamic mode to enable on demand) - // - explicit list means "use these toolsets" - var enabledToolsets []string + return &githubClients{ + rest: restClient, + gql: gqlClient, + gqlHTTP: gqlHTTPClient, + raw: rawClient, + repoAccess: repoAccessCache, + }, nil +} + +// resolveEnabledToolsets determines which toolsets should be enabled based on config. +// Returns nil for "use defaults", empty slice for "none", or explicit list. +func resolveEnabledToolsets(cfg MCPServerConfig) []string { if cfg.EnabledToolsets != nil { - enabledToolsets = cfg.EnabledToolsets - } else if cfg.DynamicToolsets { - // Dynamic mode with no toolsets specified: start with no toolsets enabled - // so users can enable them on demand via the dynamic tools - enabledToolsets = []string{} + return cfg.EnabledToolsets } - // else: enabledToolsets stays nil, which means "use defaults" in WithToolsets - - // Generate instructions based on enabled toolsets - instructions := github.GenerateInstructions(enabledToolsets) - - getClient := func(_ context.Context) (*gogithub.Client, error) { - return restClient, nil // closing over client + if cfg.DynamicToolsets { + // Dynamic mode with no toolsets specified: start empty so users enable on demand + return []string{} } + // nil means "use defaults" in WithToolsets + return nil +} - getGQLClient := func(_ context.Context) (*githubv4.Client, error) { - return gqlClient, nil // closing over client +func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { + apiHost, err := parseAPIHost(cfg.Host) + if err != nil { + return nil, fmt.Errorf("failed to parse API host: %w", err) } - getRawClient := func(ctx context.Context) (*raw.Client, error) { - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - return raw.NewClient(client, apiHost.rawURL), nil // closing over client + clients, err := createGitHubClients(cfg, apiHost) + if err != nil { + return nil, fmt.Errorf("failed to create GitHub clients: %w", err) } + enabledToolsets := resolveEnabledToolsets(cfg) + + // Create the MCP server ghServer := github.NewServer(cfg.Version, &mcp.ServerOptions{ - Instructions: instructions, - Logger: cfg.Logger, - CompletionHandler: github.CompletionsHandler(getClient), + Instructions: github.GenerateInstructions(enabledToolsets), + Logger: cfg.Logger, + CompletionHandler: github.CompletionsHandler(func(_ context.Context) (*gogithub.Client, error) { + return clients.rest, nil + }), }) // Add middlewares ghServer.AddReceivingMiddleware(addGitHubAPIErrorToContext) - ghServer.AddReceivingMiddleware(addUserAgentsMiddleware(cfg, restClient, gqlHTTPClient)) - - // Create the dependencies struct for tool handlers - deps := github.ToolDependencies{ - GetClient: getClient, - GetGQLClient: getGQLClient, - GetRawClient: getRawClient, - RepoAccessCache: repoAccessCache, - T: cfg.Translator, - Flags: github.FeatureFlags{LockdownMode: cfg.LockdownMode}, - ContentWindowSize: cfg.ContentWindowSize, - } - - // Create toolset group with all tools, resources, and prompts (stateless) - r := github.NewRegistry(cfg.Translator) - - // Clean tool names (WithTools will resolve any deprecated aliases) - enabledTools := github.CleanTools(cfg.EnabledTools) - - // Apply filters based on configuration - // - WithDeprecatedToolAliases: adds backward compatibility aliases - // - WithReadOnly: filters out write tools when true - // - WithToolsets: nil=defaults, empty=none, handles "all"/"default" keywords - // - WithTools: additional tools that bypass toolset filtering (additive, resolves aliases) - // - WithFeatureChecker: filters based on feature flags - filteredReg := r. - WithDeprecatedToolAliases(github.DeprecatedToolAliases). + ghServer.AddReceivingMiddleware(addUserAgentsMiddleware(cfg, clients.rest, clients.gqlHTTP)) + + // Create dependencies for tool handlers + deps := github.NewBaseDeps( + clients.rest, + clients.gql, + clients.raw, + clients.repoAccess, + cfg.Translator, + github.FeatureFlags{LockdownMode: cfg.LockdownMode}, + cfg.ContentWindowSize, + ) + + // Build and register the tool/resource/prompt registry + registry := github.NewRegistry(cfg.Translator). + WithDeprecatedAliases(github.DeprecatedToolAliases). WithReadOnly(cfg.ReadOnly). WithToolsets(enabledToolsets). - WithTools(enabledTools). - WithFeatureChecker(createFeatureChecker(cfg.EnabledFeatures)) + WithTools(github.CleanTools(cfg.EnabledTools)). + WithFeatureChecker(createFeatureChecker(cfg.EnabledFeatures)). + Build() - // Warn about unrecognized toolset names (likely typos) - if unrecognized := filteredReg.UnrecognizedToolsets(); len(unrecognized) > 0 { + if unrecognized := registry.UnrecognizedToolsets(); len(unrecognized) > 0 { fmt.Fprintf(os.Stderr, "Warning: unrecognized toolsets ignored: %s\n", strings.Join(unrecognized, ", ")) } - // Register all mcp functionality with the server - // Use background context for local server (no per-request actor context) - filteredReg.RegisterAll(context.Background(), ghServer, deps) + // Register GitHub tools/resources/prompts from the registry. + // In dynamic mode with no explicit toolsets, this is a no-op since enabledToolsets + // is empty - users enable toolsets at runtime via the dynamic tools below (but can + // enable toolsets or tools explicitly that do need registration). + registry.RegisterAll(context.Background(), ghServer, deps) - // Register dynamic toolset management if configured - // Dynamic tools get access to the filtered toolset group which tracks enabled state. - // ToolsForToolset() returns all tools for a toolset regardless of enabled status, - // so dynamic tools can enable any toolset at runtime. + // Register dynamic toolset management tools (enable/disable) - these are separate + // meta-tools that control the registry, not part of the registry itself if cfg.DynamicToolsets { - dynamicDeps := github.DynamicToolDependencies{ - Server: ghServer, - Registry: filteredReg, - ToolDeps: deps, - T: cfg.Translator, - } - dynamicTools := github.DynamicTools(filteredReg) - for _, tool := range dynamicTools { - tool.RegisterFunc(ghServer, dynamicDeps) - } + registerDynamicTools(ghServer, registry, deps, cfg.Translator) } return ghServer, nil } +// registerDynamicTools adds the dynamic toolset enable/disable tools to the server. +func registerDynamicTools(server *mcp.Server, registry *registry.Registry, deps *github.BaseDeps, t translations.TranslationHelperFunc) { + dynamicDeps := github.DynamicToolDependencies{ + Server: server, + Registry: registry, + ToolDeps: deps, + T: t, + } + for _, tool := range github.DynamicTools(registry) { + tool.RegisterFunc(server, dynamicDeps) + } +} + // createFeatureChecker returns a FeatureFlagChecker that checks if a flag name // is present in the provided list of enabled features. For the local server, // this is populated from the --features CLI flag. -func createFeatureChecker(enabledFeatures []string) toolsets.FeatureFlagChecker { +func createFeatureChecker(enabledFeatures []string) registry.FeatureFlagChecker { // Build a set for O(1) lookup featureSet := make(map[string]bool, len(enabledFeatures)) for _, f := range enabledFeatures { diff --git a/pkg/github/actions.go b/pkg/github/actions.go index f29f75e99..584a23200 100644 --- a/pkg/github/actions.go +++ b/pkg/github/actions.go @@ -11,7 +11,7 @@ import ( "github.com/github/github-mcp-server/internal/profiler" buffer "github.com/github/github-mcp-server/pkg/buffer" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -25,7 +25,7 @@ const ( ) // ListWorkflows creates a tool to list workflows in a repository -func ListWorkflows(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListWorkflows(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -96,7 +96,7 @@ func ListWorkflows(t translations.TranslationHelperFunc) toolsets.ServerTool { } // ListWorkflowRuns creates a tool to list workflow runs for a specific workflow -func ListWorkflowRuns(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListWorkflowRuns(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -250,7 +250,7 @@ func ListWorkflowRuns(t translations.TranslationHelperFunc) toolsets.ServerTool } // RunWorkflow creates a tool to run an Actions workflow -func RunWorkflow(t translations.TranslationHelperFunc) toolsets.ServerTool { +func RunWorkflow(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -362,7 +362,7 @@ func RunWorkflow(t translations.TranslationHelperFunc) toolsets.ServerTool { } // GetWorkflowRun creates a tool to get details of a specific workflow run -func GetWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetWorkflowRun(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -430,7 +430,7 @@ func GetWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool { } // GetWorkflowRunLogs creates a tool to download logs for a specific workflow run -func GetWorkflowRunLogs(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetWorkflowRunLogs(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -508,7 +508,7 @@ func GetWorkflowRunLogs(t translations.TranslationHelperFunc) toolsets.ServerToo } // ListWorkflowJobs creates a tool to list jobs for a specific workflow run -func ListWorkflowJobs(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListWorkflowJobs(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -608,7 +608,7 @@ func ListWorkflowJobs(t translations.TranslationHelperFunc) toolsets.ServerTool } // GetJobLogs creates a tool to download logs for a specific workflow job or efficiently get all failed job logs for a workflow run -func GetJobLogs(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetJobLogs(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -706,10 +706,10 @@ func GetJobLogs(t translations.TranslationHelperFunc) toolsets.ServerTool { if failedOnly && runID > 0 { // Handle failed-only mode: get logs for all failed jobs in the workflow run - return handleFailedJobLogs(ctx, client, owner, repo, int64(runID), returnContent, tailLines, deps.ContentWindowSize) + return handleFailedJobLogs(ctx, client, owner, repo, int64(runID), returnContent, tailLines, deps.GetContentWindowSize()) } else if jobID > 0 { // Handle single job mode - return handleSingleJobLogs(ctx, client, owner, repo, int64(jobID), returnContent, tailLines, deps.ContentWindowSize) + return handleSingleJobLogs(ctx, client, owner, repo, int64(jobID), returnContent, tailLines, deps.GetContentWindowSize()) } return utils.NewToolResultError("Either job_id must be provided for single job logs, or run_id with failed_only=true for failed job logs"), nil, nil @@ -873,7 +873,7 @@ func downloadLogContent(ctx context.Context, logURL string, tailLines int, maxLi } // RerunWorkflowRun creates a tool to re-run an entire workflow run -func RerunWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool { +func RerunWorkflowRun(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -948,7 +948,7 @@ func RerunWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool } // RerunFailedJobs creates a tool to re-run only the failed jobs in a workflow run -func RerunFailedJobs(t translations.TranslationHelperFunc) toolsets.ServerTool { +func RerunFailedJobs(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -1023,7 +1023,7 @@ func RerunFailedJobs(t translations.TranslationHelperFunc) toolsets.ServerTool { } // CancelWorkflowRun creates a tool to cancel a workflow run -func CancelWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool { +func CancelWorkflowRun(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -1100,7 +1100,7 @@ func CancelWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool } // ListWorkflowRunArtifacts creates a tool to list artifacts for a workflow run -func ListWorkflowRunArtifacts(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListWorkflowRunArtifacts(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -1180,7 +1180,7 @@ func ListWorkflowRunArtifacts(t translations.TranslationHelperFunc) toolsets.Ser } // DownloadWorkflowRunArtifact creates a tool to download a workflow run artifact -func DownloadWorkflowRunArtifact(t translations.TranslationHelperFunc) toolsets.ServerTool { +func DownloadWorkflowRunArtifact(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -1257,7 +1257,7 @@ func DownloadWorkflowRunArtifact(t translations.TranslationHelperFunc) toolsets. } // DeleteWorkflowRunLogs creates a tool to delete logs for a workflow run -func DeleteWorkflowRunLogs(t translations.TranslationHelperFunc) toolsets.ServerTool { +func DeleteWorkflowRunLogs(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ @@ -1333,7 +1333,7 @@ func DeleteWorkflowRunLogs(t translations.TranslationHelperFunc) toolsets.Server } // GetWorkflowRunUsage creates a tool to get usage metrics for a workflow run -func GetWorkflowRunUsage(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetWorkflowRunUsage(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataActions, mcp.Tool{ diff --git a/pkg/github/actions_test.go b/pkg/github/actions_test.go index 09ab3b2cc..4d56f01aa 100644 --- a/pkg/github/actions_test.go +++ b/pkg/github/actions_test.go @@ -105,8 +105,8 @@ func Test_ListWorkflows(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -194,8 +194,8 @@ func Test_RunWorkflow(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -290,8 +290,8 @@ func Test_RunWorkflow_WithFilename(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -398,8 +398,8 @@ func Test_CancelWorkflowRun(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -528,8 +528,8 @@ func Test_ListWorkflowRunArtifacts(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -618,8 +618,8 @@ func Test_DownloadWorkflowRunArtifact(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -704,8 +704,8 @@ func Test_DeleteWorkflowRunLogs(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -808,8 +808,8 @@ func Test_GetWorkflowRunUsage(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -1072,8 +1072,8 @@ func Test_GetJobLogs(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, ContentWindowSize: 5000, } handler := toolDef.Handler(deps) @@ -1136,8 +1136,8 @@ func Test_GetJobLogs_WithContentReturn(t *testing.T) { client := github.NewClient(mockedClient) toolDef := GetJobLogs(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, ContentWindowSize: 5000, } handler := toolDef.Handler(deps) @@ -1188,8 +1188,8 @@ func Test_GetJobLogs_WithContentReturnAndTailLines(t *testing.T) { client := github.NewClient(mockedClient) toolDef := GetJobLogs(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, ContentWindowSize: 5000, } handler := toolDef.Handler(deps) @@ -1240,8 +1240,8 @@ func Test_GetJobLogs_WithContentReturnAndLargeTailLines(t *testing.T) { client := github.NewClient(mockedClient) toolDef := GetJobLogs(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, ContentWindowSize: 5000, } handler := toolDef.Handler(deps) diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index 888ad4fd2..8826e4cf6 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -8,7 +8,7 @@ import ( "net/http" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -16,7 +16,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func GetCodeScanningAlert(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetCodeScanningAlert(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataCodeSecurity, mcp.Tool{ @@ -94,7 +94,7 @@ func GetCodeScanningAlert(t translations.TranslationHelperFunc) toolsets.ServerT ) } -func ListCodeScanningAlerts(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListCodeScanningAlerts(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataCodeSecurity, mcp.Tool{ diff --git a/pkg/github/code_scanning_test.go b/pkg/github/code_scanning_test.go index 5e56e6788..44c7a7e95 100644 --- a/pkg/github/code_scanning_test.go +++ b/pkg/github/code_scanning_test.go @@ -88,8 +88,8 @@ func Test_GetCodeScanningAlert(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -220,8 +220,8 @@ func Test_ListCodeScanningAlerts(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) diff --git a/pkg/github/context_tools.go b/pkg/github/context_tools.go index d5e0cfee9..837de00f7 100644 --- a/pkg/github/context_tools.go +++ b/pkg/github/context_tools.go @@ -6,7 +6,7 @@ import ( "time" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/jsonschema-go/jsonschema" @@ -37,7 +37,7 @@ type UserDetails struct { } // GetMe creates a tool to get details of the authenticated user. -func GetMe(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetMe(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataContext, mcp.Tool{ @@ -111,7 +111,7 @@ type OrganizationTeams struct { Teams []TeamInfo `json:"teams"` } -func GetTeams(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetTeams(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataContext, mcp.Tool{ @@ -210,7 +210,7 @@ func GetTeams(t translations.TranslationHelperFunc) toolsets.ServerTool { ) } -func GetTeamMembers(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetTeamMembers(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataContext, mcp.Tool{ diff --git a/pkg/github/context_tools_test.go b/pkg/github/context_tools_test.go index 0e28aad49..e9faefc40 100644 --- a/pkg/github/context_tools_test.go +++ b/pkg/github/context_tools_test.go @@ -3,7 +3,7 @@ package github import ( "context" "encoding/json" - "fmt" + "net/http" "testing" "time" @@ -48,7 +48,8 @@ func Test_GetMe(t *testing.T) { tests := []struct { name string - stubbedGetClientFn GetClientFn + mockedClient *http.Client + clientErr string // if set, GetClient returns this error requestArgs map[string]any expectToolError bool expectedUser *github.User @@ -56,12 +57,10 @@ func Test_GetMe(t *testing.T) { }{ { name: "successful get user", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetUser, - mockUser, - ), + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetUser, + mockUser, ), ), requestArgs: map[string]any{}, @@ -70,12 +69,10 @@ func Test_GetMe(t *testing.T) { }, { name: "successful get user with reason", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetUser, - mockUser, - ), + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetUser, + mockUser, ), ), requestArgs: map[string]any{ @@ -86,19 +83,17 @@ func Test_GetMe(t *testing.T) { }, { name: "getting client fails", - stubbedGetClientFn: stubGetClientFnErr("expected test error"), + clientErr: "expected test error", requestArgs: map[string]any{}, expectToolError: true, expectedToolErrMsg: "failed to get GitHub client: expected test error", }, { name: "get user fails", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetUser, - badRequestHandler("expected test failure"), - ), + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetUser, + badRequestHandler("expected test failure"), ), ), requestArgs: map[string]any{}, @@ -109,8 +104,11 @@ func Test_GetMe(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - deps := ToolDependencies{ - GetClient: tc.stubbedGetClientFn, + var deps ToolDependencies + if tc.clientErr != "" { + deps = stubDeps{clientFn: stubClientFnErr(tc.clientErr)} + } else { + deps = BaseDeps{Client: github.NewClient(tc.mockedClient)} } handler := serverTool.Handler(deps) @@ -223,49 +221,83 @@ func Test_GetTeams(t *testing.T) { }, }) + // Create GQL clients for different test scenarios - these are factory functions + // to ensure each test gets a fresh client + gqlClientForTestuser := func() *githubv4.Client { + queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" + vars := map[string]interface{}{ + "login": "testuser", + } + matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamsResponse) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + return githubv4.NewClient(httpClient) + } + + gqlClientForSpecificuser := func() *githubv4.Client { + queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" + vars := map[string]interface{}{ + "login": "specificuser", + } + matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamsResponse) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + return githubv4.NewClient(httpClient) + } + + gqlClientNoTeams := func() *githubv4.Client { + queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" + vars := map[string]interface{}{ + "login": "testuser", + } + matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockNoTeamsResponse) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + return githubv4.NewClient(httpClient) + } + + // Factory function for mock HTTP clients with user response + httpClientWithUser := func() *http.Client { + return mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetUser, + mockUser, + ), + ) + } + + httpClientUserFails := func() *http.Client { + return mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetUser, + badRequestHandler("expected test failure"), + ), + ) + } + tests := []struct { - name string - stubbedGetClientFn GetClientFn - stubbedGetGQLClientFn GetGQLClientFn - requestArgs map[string]any - expectToolError bool - expectedToolErrMsg string - expectedTeamsCount int + name string + makeDeps func() ToolDependencies + requestArgs map[string]any + expectToolError bool + expectedToolErrMsg string + expectedTeamsCount int }{ { name: "successful get teams", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetUser, - mockUser, - ), - ), - ), - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" - vars := map[string]interface{}{ - "login": "testuser", + makeDeps: func() ToolDependencies { + return BaseDeps{ + Client: github.NewClient(httpClientWithUser()), + GQLClient: gqlClientForTestuser(), } - matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamsResponse) - httpClient := githubv4mock.NewMockedHTTPClient(matcher) - return githubv4.NewClient(httpClient), nil }, requestArgs: map[string]any{}, expectToolError: false, expectedTeamsCount: 2, }, { - name: "successful get teams for specific user", - stubbedGetClientFn: nil, - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" - vars := map[string]interface{}{ - "login": "specificuser", + name: "successful get teams for specific user", + makeDeps: func() ToolDependencies { + return BaseDeps{ + GQLClient: gqlClientForSpecificuser(), } - matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamsResponse) - httpClient := githubv4mock.NewMockedHTTPClient(matcher) - return githubv4.NewClient(httpClient), nil }, requestArgs: map[string]any{ "user": "specificuser", @@ -275,62 +307,43 @@ func Test_GetTeams(t *testing.T) { }, { name: "no teams found", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetUser, - mockUser, - ), - ), - ), - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" - vars := map[string]interface{}{ - "login": "testuser", + makeDeps: func() ToolDependencies { + return BaseDeps{ + Client: github.NewClient(httpClientWithUser()), + GQLClient: gqlClientNoTeams(), } - matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockNoTeamsResponse) - httpClient := githubv4mock.NewMockedHTTPClient(matcher) - return githubv4.NewClient(httpClient), nil }, requestArgs: map[string]any{}, expectToolError: false, expectedTeamsCount: 0, }, { - name: "getting client fails", - stubbedGetClientFn: stubGetClientFnErr("expected test error"), - stubbedGetGQLClientFn: nil, - requestArgs: map[string]any{}, - expectToolError: true, - expectedToolErrMsg: "failed to get GitHub client: expected test error", + name: "getting client fails", + makeDeps: func() ToolDependencies { + return stubDeps{clientFn: stubClientFnErr("expected test error")} + }, + requestArgs: map[string]any{}, + expectToolError: true, + expectedToolErrMsg: "failed to get GitHub client: expected test error", }, { name: "get user fails", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetUser, - badRequestHandler("expected test failure"), - ), - ), - ), - stubbedGetGQLClientFn: nil, - requestArgs: map[string]any{}, - expectToolError: true, - expectedToolErrMsg: "expected test failure", + makeDeps: func() ToolDependencies { + return BaseDeps{ + Client: github.NewClient(httpClientUserFails()), + } + }, + requestArgs: map[string]any{}, + expectToolError: true, + expectedToolErrMsg: "expected test failure", }, { name: "getting GraphQL client fails", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetUser, - mockUser, - ), - ), - ), - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - return nil, fmt.Errorf("GraphQL client error") + makeDeps: func() ToolDependencies { + return stubDeps{ + clientFn: stubClientFnFromHTTP(httpClientWithUser()), + gqlClientFn: stubGQLClientFnErr("GraphQL client error"), + } }, requestArgs: map[string]any{}, expectToolError: true, @@ -340,11 +353,7 @@ func Test_GetTeams(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - deps := ToolDependencies{ - GetClient: tc.stubbedGetClientFn, - GetGQLClient: tc.stubbedGetGQLClientFn, - } - handler := serverTool.Handler(deps) + handler := serverTool.Handler(tc.makeDeps()) request := createMCPRequest(tc.requestArgs) result, err := handler(context.Background(), &request) @@ -422,26 +431,40 @@ func Test_GetTeamMembers(t *testing.T) { }, }) + // Create GQL clients for different test scenarios + gqlClientWithMembers := func() *githubv4.Client { + queryStr := "query($org:String!$teamSlug:String!){organization(login: $org){team(slug: $teamSlug){members(first: 100){nodes{login}}}}}" + vars := map[string]interface{}{ + "org": "testorg", + "teamSlug": "testteam", + } + matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamMembersResponse) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + return githubv4.NewClient(httpClient) + } + + gqlClientNoMembers := func() *githubv4.Client { + queryStr := "query($org:String!$teamSlug:String!){organization(login: $org){team(slug: $teamSlug){members(first: 100){nodes{login}}}}}" + vars := map[string]interface{}{ + "org": "testorg", + "teamSlug": "emptyteam", + } + matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockNoMembersResponse) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + return githubv4.NewClient(httpClient) + } + tests := []struct { - name string - stubbedGetGQLClientFn GetGQLClientFn - requestArgs map[string]any - expectToolError bool - expectedToolErrMsg string - expectedMembersCount int + name string + deps ToolDependencies + requestArgs map[string]any + expectToolError bool + expectedToolErrMsg string + expectedMembersCount int }{ { name: "successful get team members", - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - queryStr := "query($org:String!$teamSlug:String!){organization(login: $org){team(slug: $teamSlug){members(first: 100){nodes{login}}}}}" - vars := map[string]interface{}{ - "org": "testorg", - "teamSlug": "testteam", - } - matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamMembersResponse) - httpClient := githubv4mock.NewMockedHTTPClient(matcher) - return githubv4.NewClient(httpClient), nil - }, + deps: BaseDeps{GQLClient: gqlClientWithMembers()}, requestArgs: map[string]any{ "org": "testorg", "team_slug": "testteam", @@ -451,16 +474,7 @@ func Test_GetTeamMembers(t *testing.T) { }, { name: "team with no members", - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - queryStr := "query($org:String!$teamSlug:String!){organization(login: $org){team(slug: $teamSlug){members(first: 100){nodes{login}}}}}" - vars := map[string]interface{}{ - "org": "testorg", - "teamSlug": "emptyteam", - } - matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockNoMembersResponse) - httpClient := githubv4mock.NewMockedHTTPClient(matcher) - return githubv4.NewClient(httpClient), nil - }, + deps: BaseDeps{GQLClient: gqlClientNoMembers()}, requestArgs: map[string]any{ "org": "testorg", "team_slug": "emptyteam", @@ -470,9 +484,7 @@ func Test_GetTeamMembers(t *testing.T) { }, { name: "getting GraphQL client fails", - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - return nil, fmt.Errorf("GraphQL client error") - }, + deps: stubDeps{gqlClientFn: stubGQLClientFnErr("GraphQL client error")}, requestArgs: map[string]any{ "org": "testorg", "team_slug": "testteam", @@ -484,10 +496,7 @@ func Test_GetTeamMembers(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - deps := ToolDependencies{ - GetGQLClient: tc.stubbedGetGQLClientFn, - } - handler := serverTool.Handler(deps) + handler := serverTool.Handler(tc.deps) request := createMCPRequest(tc.requestArgs) result, err := handler(context.Background(), &request) diff --git a/pkg/github/dependabot.go b/pkg/github/dependabot.go index 1508d1382..daa2a124a 100644 --- a/pkg/github/dependabot.go +++ b/pkg/github/dependabot.go @@ -8,7 +8,7 @@ import ( "net/http" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -16,7 +16,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func GetDependabotAlert(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetDependabotAlert(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataDependabot, mcp.Tool{ @@ -94,7 +94,7 @@ func GetDependabotAlert(t translations.TranslationHelperFunc) toolsets.ServerToo ) } -func ListDependabotAlerts(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListDependabotAlerts(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataDependabot, mcp.Tool{ diff --git a/pkg/github/dependabot_test.go b/pkg/github/dependabot_test.go index ace0eb07a..614c6f383 100644 --- a/pkg/github/dependabot_test.go +++ b/pkg/github/dependabot_test.go @@ -81,7 +81,7 @@ func Test_GetDependabotAlert(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{GetClient: stubGetClientFn(client)} + deps := BaseDeps{Client: client} handler := toolDef.Handler(deps) // Create call request @@ -232,7 +232,7 @@ func Test_ListDependabotAlerts(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{GetClient: stubGetClientFn(client)} + deps := BaseDeps{Client: client} handler := toolDef.Handler(deps) request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/dependencies.go b/pkg/github/dependencies.go index 7dcc33f75..040e61883 100644 --- a/pkg/github/dependencies.go +++ b/pkg/github/dependencies.go @@ -1,53 +1,123 @@ package github import ( + "context" + "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/raw" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" + gogithub "github.com/google/go-github/v79/github" "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/shurcooL/githubv4" ) -// ToolDependencies contains all dependencies that tool handlers might need. -// This is a properly-typed struct that lives in pkg/github to avoid circular -// dependencies. The toolsets package uses `any` for deps and tool handlers -// type-assert to this struct. -type ToolDependencies struct { +// ToolDependencies defines the interface for dependencies that tool handlers need. +// This is an interface to allow different implementations: +// - Local server: stores closures that create clients on demand +// - Remote server: can store pre-created clients per-request for efficiency +// +// The toolsets package uses `any` for deps and tool handlers type-assert to this interface. +type ToolDependencies interface { // GetClient returns a GitHub REST API client - GetClient GetClientFn + GetClient(ctx context.Context) (*gogithub.Client, error) // GetGQLClient returns a GitHub GraphQL client - GetGQLClient GetGQLClientFn + GetGQLClient(ctx context.Context) (*githubv4.Client, error) + + // GetRawClient returns a raw content client for GitHub + GetRawClient(ctx context.Context) (*raw.Client, error) - // GetRawClient returns a raw HTTP client for GitHub - GetRawClient raw.GetRawClientFn + // GetRepoAccessCache returns the lockdown mode repo access cache + GetRepoAccessCache() *lockdown.RepoAccessCache - // RepoAccessCache is the lockdown mode repo access cache - RepoAccessCache *lockdown.RepoAccessCache + // GetT returns the translation helper function + GetT() translations.TranslationHelperFunc - // T is the translation helper function - T translations.TranslationHelperFunc + // GetFlags returns feature flags + GetFlags() FeatureFlags + + // GetContentWindowSize returns the content window size for log truncation + GetContentWindowSize() int +} - // Flags are feature flags - Flags FeatureFlags +// BaseDeps is the standard implementation of ToolDependencies for the local server. +// It stores pre-created clients. The remote server can create its own struct +// implementing ToolDependencies with different client creation strategies. +type BaseDeps struct { + // Pre-created clients + Client *gogithub.Client + GQLClient *githubv4.Client + RawClient *raw.Client - // ContentWindowSize is the size of the content window for log truncation + // Static dependencies + RepoAccessCache *lockdown.RepoAccessCache + T translations.TranslationHelperFunc + Flags FeatureFlags ContentWindowSize int } +// NewBaseDeps creates a BaseDeps with the provided clients and configuration. +func NewBaseDeps( + client *gogithub.Client, + gqlClient *githubv4.Client, + rawClient *raw.Client, + repoAccessCache *lockdown.RepoAccessCache, + t translations.TranslationHelperFunc, + flags FeatureFlags, + contentWindowSize int, +) *BaseDeps { + return &BaseDeps{ + Client: client, + GQLClient: gqlClient, + RawClient: rawClient, + RepoAccessCache: repoAccessCache, + T: t, + Flags: flags, + ContentWindowSize: contentWindowSize, + } +} + +// GetClient implements ToolDependencies. +func (d BaseDeps) GetClient(_ context.Context) (*gogithub.Client, error) { + return d.Client, nil +} + +// GetGQLClient implements ToolDependencies. +func (d BaseDeps) GetGQLClient(_ context.Context) (*githubv4.Client, error) { + return d.GQLClient, nil +} + +// GetRawClient implements ToolDependencies. +func (d BaseDeps) GetRawClient(_ context.Context) (*raw.Client, error) { + return d.RawClient, nil +} + +// GetRepoAccessCache implements ToolDependencies. +func (d BaseDeps) GetRepoAccessCache() *lockdown.RepoAccessCache { return d.RepoAccessCache } + +// GetT implements ToolDependencies. +func (d BaseDeps) GetT() translations.TranslationHelperFunc { return d.T } + +// GetFlags implements ToolDependencies. +func (d BaseDeps) GetFlags() FeatureFlags { return d.Flags } + +// GetContentWindowSize implements ToolDependencies. +func (d BaseDeps) GetContentWindowSize() int { return d.ContentWindowSize } + // NewTool creates a ServerTool with fully-typed ToolDependencies and toolset metadata. // This helper isolates the type assertion from `any` to `ToolDependencies`, // so tool implementations remain fully typed without assertions scattered throughout. -func NewTool[In, Out any](toolset toolsets.ToolsetMetadata, tool mcp.Tool, handler func(deps ToolDependencies) mcp.ToolHandlerFor[In, Out]) toolsets.ServerTool { - return toolsets.NewServerTool(tool, toolset, func(d any) mcp.ToolHandlerFor[In, Out] { +func NewTool[In, Out any](toolset registry.ToolsetMetadata, tool mcp.Tool, handler func(deps ToolDependencies) mcp.ToolHandlerFor[In, Out]) registry.ServerTool { + return registry.NewServerTool(tool, toolset, func(d any) mcp.ToolHandlerFor[In, Out] { return handler(d.(ToolDependencies)) }) } // NewToolFromHandler creates a ServerTool with fully-typed ToolDependencies and toolset metadata // for handlers that conform to mcp.ToolHandler directly. -func NewToolFromHandler(toolset toolsets.ToolsetMetadata, tool mcp.Tool, handler func(deps ToolDependencies) mcp.ToolHandler) toolsets.ServerTool { - return toolsets.NewServerToolFromHandler(tool, toolset, func(d any) mcp.ToolHandler { +func NewToolFromHandler(toolset registry.ToolsetMetadata, tool mcp.Tool, handler func(deps ToolDependencies) mcp.ToolHandler) registry.ServerTool { + return registry.NewServerToolFromHandler(tool, toolset, func(d any) mcp.ToolHandler { return handler(d.(ToolDependencies)) }) } diff --git a/pkg/github/discussions.go b/pkg/github/discussions.go index 5bbdb2b5f..50364fa58 100644 --- a/pkg/github/discussions.go +++ b/pkg/github/discussions.go @@ -5,7 +5,7 @@ import ( "encoding/json" "fmt" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/go-viper/mapstructure/v2" @@ -122,7 +122,7 @@ func getQueryType(useOrdering bool, categoryID *githubv4.ID) any { return &BasicNoOrder{} } -func ListDiscussions(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListDiscussions(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataDiscussions, mcp.Tool{ @@ -276,7 +276,7 @@ func ListDiscussions(t translations.TranslationHelperFunc) toolsets.ServerTool { ) } -func GetDiscussion(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetDiscussion(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataDiscussions, mcp.Tool{ @@ -381,7 +381,7 @@ func GetDiscussion(t translations.TranslationHelperFunc) toolsets.ServerTool { ) } -func GetDiscussionComments(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetDiscussionComments(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataDiscussions, mcp.Tool{ @@ -509,7 +509,7 @@ func GetDiscussionComments(t translations.TranslationHelperFunc) toolsets.Server ) } -func ListDiscussionCategories(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListDiscussionCategories(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataDiscussions, mcp.Tool{ diff --git a/pkg/github/discussions_test.go b/pkg/github/discussions_test.go index 758c82200..73ae66748 100644 --- a/pkg/github/discussions_test.go +++ b/pkg/github/discussions_test.go @@ -447,7 +447,7 @@ func Test_ListDiscussions(t *testing.T) { } gqlClient := githubv4.NewClient(httpClient) - deps := ToolDependencies{GetGQLClient: stubGetGQLClientFn(gqlClient)} + deps := BaseDeps{GQLClient: gqlClient} handler := toolDef.Handler(deps) req := createMCPRequest(tc.reqParams) @@ -559,7 +559,7 @@ func Test_GetDiscussion(t *testing.T) { matcher := githubv4mock.NewQueryMatcher(qGetDiscussion, vars, tc.response) httpClient := githubv4mock.NewMockedHTTPClient(matcher) gqlClient := githubv4.NewClient(httpClient) - deps := ToolDependencies{GetGQLClient: stubGetGQLClientFn(gqlClient)} + deps := BaseDeps{GQLClient: gqlClient} handler := toolDef.Handler(deps) reqParams := map[string]interface{}{"owner": "owner", "repo": "repo", "discussionNumber": int32(1)} @@ -639,7 +639,7 @@ func Test_GetDiscussionComments(t *testing.T) { matcher := githubv4mock.NewQueryMatcher(qGetComments, vars, mockResponse) httpClient := githubv4mock.NewMockedHTTPClient(matcher) gqlClient := githubv4.NewClient(httpClient) - deps := ToolDependencies{GetGQLClient: stubGetGQLClientFn(gqlClient)} + deps := BaseDeps{GQLClient: gqlClient} handler := toolDef.Handler(deps) reqParams := map[string]interface{}{ @@ -791,7 +791,7 @@ func Test_ListDiscussionCategories(t *testing.T) { httpClient := githubv4mock.NewMockedHTTPClient(matcher) gqlClient := githubv4.NewClient(httpClient) - deps := ToolDependencies{GetGQLClient: stubGetGQLClientFn(gqlClient)} + deps := BaseDeps{GQLClient: gqlClient} handler := toolDef.Handler(deps) req := createMCPRequest(tc.reqParams) diff --git a/pkg/github/dynamic_tools.go b/pkg/github/dynamic_tools.go index 93c24a07b..a749ecd1b 100644 --- a/pkg/github/dynamic_tools.go +++ b/pkg/github/dynamic_tools.go @@ -5,7 +5,7 @@ import ( "encoding/json" "fmt" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/jsonschema-go/jsonschema" @@ -19,7 +19,7 @@ type DynamicToolDependencies struct { // Server is the MCP server to register tools with Server *mcp.Server // Registry contains all available tools that can be enabled dynamically - Registry *toolsets.Registry + Registry *registry.Registry // ToolDeps are the dependencies passed to tools when they are registered ToolDeps any // T is the translation helper function @@ -27,14 +27,14 @@ type DynamicToolDependencies struct { } // NewDynamicTool creates a ServerTool with fully-typed DynamicToolDependencies. -func NewDynamicTool(toolset toolsets.ToolsetMetadata, tool mcp.Tool, handler func(deps DynamicToolDependencies) mcp.ToolHandlerFor[map[string]any, any]) toolsets.ServerTool { - return toolsets.NewServerTool(tool, toolset, func(d any) mcp.ToolHandlerFor[map[string]any, any] { +func NewDynamicTool(toolset registry.ToolsetMetadata, tool mcp.Tool, handler func(deps DynamicToolDependencies) mcp.ToolHandlerFor[map[string]any, any]) registry.ServerTool { + return registry.NewServerTool(tool, toolset, func(d any) mcp.ToolHandlerFor[map[string]any, any] { return handler(d.(DynamicToolDependencies)) }) } // toolsetIDsEnum returns the list of toolset IDs as an enum for JSON Schema. -func toolsetIDsEnum(r *toolsets.Registry) []any { +func toolsetIDsEnum(r *registry.Registry) []any { toolsetIDs := r.ToolsetIDs() result := make([]any, len(toolsetIDs)) for i, id := range toolsetIDs { @@ -44,10 +44,10 @@ func toolsetIDsEnum(r *toolsets.Registry) []any { } // DynamicTools returns the tools for dynamic toolset management. -// These tools allow runtime discovery and enablement of toolsets. +// These tools allow runtime discovery and enablement of registry. // The r parameter provides the available toolset IDs for JSON Schema enums. -func DynamicTools(r *toolsets.Registry) []toolsets.ServerTool { - return []toolsets.ServerTool{ +func DynamicTools(r *registry.Registry) []registry.ServerTool { + return []registry.ServerTool{ ListAvailableToolsets(), GetToolsetsTools(r), EnableToolset(r), @@ -55,7 +55,7 @@ func DynamicTools(r *toolsets.Registry) []toolsets.ServerTool { } // EnableToolset creates a tool that enables a toolset at runtime. -func EnableToolset(r *toolsets.Registry) toolsets.ServerTool { +func EnableToolset(r *registry.Registry) registry.ServerTool { return NewDynamicTool( ToolsetMetadataDynamic, mcp.Tool{ @@ -84,7 +84,7 @@ func EnableToolset(r *toolsets.Registry) toolsets.ServerTool { return utils.NewToolResultError(err.Error()), nil, nil } - toolsetID := toolsets.ToolsetID(toolsetName) + toolsetID := registry.ToolsetID(toolsetName) if !deps.Registry.HasToolset(toolsetID) { return utils.NewToolResultError(fmt.Sprintf("Toolset %s not found", toolsetName)), nil, nil @@ -109,8 +109,8 @@ func EnableToolset(r *toolsets.Registry) toolsets.ServerTool { ) } -// ListAvailableToolsets creates a tool that lists all available toolsets. -func ListAvailableToolsets() toolsets.ServerTool { +// ListAvailableToolsets creates a tool that lists all available registry. +func ListAvailableToolsets() registry.ServerTool { return NewDynamicTool( ToolsetMetadataDynamic, mcp.Tool{ @@ -153,7 +153,7 @@ func ListAvailableToolsets() toolsets.ServerTool { } // GetToolsetsTools creates a tool that lists all tools in a specific toolset. -func GetToolsetsTools(r *toolsets.Registry) toolsets.ServerTool { +func GetToolsetsTools(r *registry.Registry) registry.ServerTool { return NewDynamicTool( ToolsetMetadataDynamic, mcp.Tool{ @@ -182,7 +182,7 @@ func GetToolsetsTools(r *toolsets.Registry) toolsets.ServerTool { return utils.NewToolResultError(err.Error()), nil, nil } - toolsetID := toolsets.ToolsetID(toolsetName) + toolsetID := registry.ToolsetID(toolsetName) if !deps.Registry.HasToolset(toolsetID) { return utils.NewToolResultError(fmt.Sprintf("Toolset %s not found", toolsetName)), nil, nil diff --git a/pkg/github/gists.go b/pkg/github/gists.go index 03e5e1bc8..7b8313f37 100644 --- a/pkg/github/gists.go +++ b/pkg/github/gists.go @@ -7,7 +7,7 @@ import ( "io" "net/http" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -16,7 +16,7 @@ import ( ) // ListGists creates a tool to list gists for a user -func ListGists(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListGists(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataGists, mcp.Tool{ @@ -104,7 +104,7 @@ func ListGists(t translations.TranslationHelperFunc) toolsets.ServerTool { } // GetGist creates a tool to get the content of a gist -func GetGist(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetGist(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataGists, mcp.Tool{ @@ -163,7 +163,7 @@ func GetGist(t translations.TranslationHelperFunc) toolsets.ServerTool { } // CreateGist creates a tool to create a new gist -func CreateGist(t translations.TranslationHelperFunc) toolsets.ServerTool { +func CreateGist(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataGists, mcp.Tool{ @@ -267,7 +267,7 @@ func CreateGist(t translations.TranslationHelperFunc) toolsets.ServerTool { } // UpdateGist creates a tool to edit an existing gist -func UpdateGist(t translations.TranslationHelperFunc) toolsets.ServerTool { +func UpdateGist(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataGists, mcp.Tool{ diff --git a/pkg/github/gists_test.go b/pkg/github/gists_test.go index 44b294eb6..7c6f69833 100644 --- a/pkg/github/gists_test.go +++ b/pkg/github/gists_test.go @@ -158,8 +158,8 @@ func Test_ListGists(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -275,8 +275,8 @@ func Test_GetGist(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -421,8 +421,8 @@ func Test_CreateGist(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -580,8 +580,8 @@ func Test_UpdateGist(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) diff --git a/pkg/github/git.go b/pkg/github/git.go index e619afc34..4755e2eb0 100644 --- a/pkg/github/git.go +++ b/pkg/github/git.go @@ -7,7 +7,7 @@ import ( "strings" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -38,7 +38,7 @@ type TreeResponse struct { } // GetRepositoryTree creates a tool to get the tree structure of a GitHub repository. -func GetRepositoryTree(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetRepositoryTree(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataGit, mcp.Tool{ diff --git a/pkg/github/git_test.go b/pkg/github/git_test.go index 69442e312..c971995b2 100644 --- a/pkg/github/git_test.go +++ b/pkg/github/git_test.go @@ -148,8 +148,8 @@ func Test_GetRepositoryTree(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 1d0e3b2d5..3e449f8c5 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -11,8 +11,8 @@ import ( ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/lockdown" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/sanitize" - "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/go-viper/mapstructure/v2" @@ -230,7 +230,7 @@ func fragmentToIssue(fragment IssueFragment) *github.Issue { } // IssueRead creates a tool to get details of a specific issue in a GitHub repository. -func IssueRead(t translations.TranslationHelperFunc) toolsets.ServerTool { +func IssueRead(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -310,13 +310,13 @@ Options are: switch method { case "get": - result, err := GetIssue(ctx, client, deps.RepoAccessCache, owner, repo, issueNumber, deps.Flags) + result, err := GetIssue(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, deps.GetFlags()) return result, nil, err case "get_comments": - result, err := GetIssueComments(ctx, client, deps.RepoAccessCache, owner, repo, issueNumber, pagination, deps.Flags) + result, err := GetIssueComments(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, pagination, deps.GetFlags()) return result, nil, err case "get_sub_issues": - result, err := GetSubIssues(ctx, client, deps.RepoAccessCache, owner, repo, issueNumber, pagination, deps.Flags) + result, err := GetSubIssues(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, pagination, deps.GetFlags()) return result, nil, err case "get_labels": result, err := GetIssueLabels(ctx, gqlClient, owner, repo, issueNumber) @@ -545,7 +545,7 @@ func GetIssueLabels(ctx context.Context, client *githubv4.Client, owner string, } // ListIssueTypes creates a tool to list defined issue types for an organization. This can be used to understand supported issue type values for creating or updating issues. -func ListIssueTypes(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListIssueTypes(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataIssues, mcp.Tool{ @@ -602,7 +602,7 @@ func ListIssueTypes(t translations.TranslationHelperFunc) toolsets.ServerTool { } // AddIssueComment creates a tool to add a comment to an issue. -func AddIssueComment(t translations.TranslationHelperFunc) toolsets.ServerTool { +func AddIssueComment(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataIssues, mcp.Tool{ @@ -687,7 +687,7 @@ func AddIssueComment(t translations.TranslationHelperFunc) toolsets.ServerTool { } // SubIssueWrite creates a tool to add a sub-issue to a parent issue. -func SubIssueWrite(t translations.TranslationHelperFunc) toolsets.ServerTool { +func SubIssueWrite(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataIssues, mcp.Tool{ @@ -916,7 +916,7 @@ func ReprioritizeSubIssue(ctx context.Context, client *github.Client, owner stri } // SearchIssues creates a tool to search for issues. -func SearchIssues(t translations.TranslationHelperFunc) toolsets.ServerTool { +func SearchIssues(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -979,7 +979,7 @@ func SearchIssues(t translations.TranslationHelperFunc) toolsets.ServerTool { } // IssueWrite creates a tool to create a new or update an existing issue in a GitHub repository. -func IssueWrite(t translations.TranslationHelperFunc) toolsets.ServerTool { +func IssueWrite(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataIssues, mcp.Tool{ @@ -1338,7 +1338,7 @@ func UpdateIssue(ctx context.Context, client *github.Client, gqlClient *githubv4 } // ListIssues creates a tool to list and filter repository issues -func ListIssues(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListIssues(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1606,7 +1606,7 @@ func (d *mvpDescription) String() string { return sb.String() } -func AssignCopilotToIssue(t translations.TranslationHelperFunc) toolsets.ServerTool { +func AssignCopilotToIssue(t translations.TranslationHelperFunc) registry.ServerTool { description := mvpDescription{ summary: "Assign Copilot to a specific issue in a GitHub repository.", outcomes: []string{ @@ -1805,8 +1805,8 @@ func parseISOTimestamp(timestamp string) (time.Time, error) { return time.Time{}, fmt.Errorf("invalid ISO 8601 timestamp: %s (supported formats: YYYY-MM-DDThh:mm:ssZ or YYYY-MM-DD)", timestamp) } -func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) toolsets.ServerPrompt { - return toolsets.NewServerPrompt( +func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) registry.ServerPrompt { + return registry.NewServerPrompt( ToolsetMetadataIssues, mcp.Prompt{ Name: "AssignCodingAgent", diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index c832f031a..4c686cc57 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -330,9 +330,9 @@ func Test_GetIssue(t *testing.T) { } flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, RepoAccessCache: cache, Flags: flags, } @@ -447,8 +447,8 @@ func Test_AddIssueComment(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -781,8 +781,8 @@ func Test_SearchIssues(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -952,9 +952,9 @@ func Test_CreateIssue(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(nil) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, } handler := serverTool.Handler(deps) @@ -1268,8 +1268,8 @@ func Test_ListIssues(t *testing.T) { } gqlClient := githubv4.NewClient(httpClient) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + GQLClient: gqlClient, } handler := serverTool.Handler(deps) @@ -1769,9 +1769,9 @@ func Test_UpdateIssue(t *testing.T) { // Setup clients with mocks restClient := github.NewClient(tc.mockedRESTClient) gqlClient := githubv4.NewClient(tc.mockedGQLClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(restClient), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: restClient, + GQLClient: gqlClient, } handler := serverTool.Handler(deps) @@ -2016,9 +2016,9 @@ func Test_GetIssueComments(t *testing.T) { } cache := stubRepoAccessCache(gqlClient, 15*time.Minute) flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, RepoAccessCache: cache, Flags: flags, } @@ -2136,9 +2136,9 @@ func Test_GetIssueLabels(t *testing.T) { t.Run(tc.name, func(t *testing.T) { gqlClient := githubv4.NewClient(tc.mockedClient) client := github.NewClient(nil) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, RepoAccessCache: stubRepoAccessCache(gqlClient, 15*time.Minute), Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), } @@ -2560,8 +2560,8 @@ func TestAssignCopilotToIssue(t *testing.T) { t.Parallel() // Setup client with mock client := githubv4.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) @@ -2791,8 +2791,8 @@ func Test_AddSubIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -3035,9 +3035,9 @@ func Test_GetSubIssues(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(nil) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, RepoAccessCache: stubRepoAccessCache(gqlClient, 15*time.Minute), Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), } @@ -3275,8 +3275,8 @@ func Test_RemoveSubIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -3564,8 +3564,8 @@ func Test_ReprioritizeSubIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -3698,8 +3698,8 @@ func Test_ListIssueTypes(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) diff --git a/pkg/github/labels.go b/pkg/github/labels.go index a98468fae..90bf8066b 100644 --- a/pkg/github/labels.go +++ b/pkg/github/labels.go @@ -7,7 +7,7 @@ import ( "strings" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/jsonschema-go/jsonschema" @@ -16,7 +16,7 @@ import ( ) // GetLabel retrieves a specific label by name from a GitHub repository -func GetLabel(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetLabel(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetLabels, mcp.Tool{ @@ -111,7 +111,7 @@ func GetLabel(t translations.TranslationHelperFunc) toolsets.ServerTool { } // ListLabels lists labels from a repository -func ListLabels(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListLabels(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetLabels, mcp.Tool{ @@ -203,7 +203,7 @@ func ListLabels(t translations.TranslationHelperFunc) toolsets.ServerTool { } // LabelWrite handles create, update, and delete operations for GitHub labels -func LabelWrite(t translations.TranslationHelperFunc) toolsets.ServerTool { +func LabelWrite(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetLabels, mcp.Tool{ diff --git a/pkg/github/labels_test.go b/pkg/github/labels_test.go index 980395ff7..fa646e884 100644 --- a/pkg/github/labels_test.go +++ b/pkg/github/labels_test.go @@ -114,8 +114,8 @@ func TestGetLabel(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := githubv4.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) @@ -212,8 +212,8 @@ func TestListLabels(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := githubv4.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) @@ -463,8 +463,8 @@ func TestWriteLabel(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := githubv4.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go index 4eb2d7b5b..569bef002 100644 --- a/pkg/github/notifications.go +++ b/pkg/github/notifications.go @@ -10,7 +10,7 @@ import ( "time" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -25,7 +25,7 @@ const ( ) // ListNotifications creates a tool to list notifications for the current user. -func ListNotifications(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListNotifications(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataNotifications, mcp.Tool{ @@ -163,7 +163,7 @@ func ListNotifications(t translations.TranslationHelperFunc) toolsets.ServerTool } // DismissNotification creates a tool to mark a notification as read/done. -func DismissNotification(t translations.TranslationHelperFunc) toolsets.ServerTool { +func DismissNotification(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataNotifications, mcp.Tool{ @@ -246,7 +246,7 @@ func DismissNotification(t translations.TranslationHelperFunc) toolsets.ServerTo } // MarkAllNotificationsRead creates a tool to mark all notifications as read. -func MarkAllNotificationsRead(t translations.TranslationHelperFunc) toolsets.ServerTool { +func MarkAllNotificationsRead(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataNotifications, mcp.Tool{ @@ -339,7 +339,7 @@ func MarkAllNotificationsRead(t translations.TranslationHelperFunc) toolsets.Ser } // GetNotificationDetails creates a tool to get details for a specific notification. -func GetNotificationDetails(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetNotificationDetails(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataNotifications, mcp.Tool{ @@ -409,7 +409,7 @@ const ( ) // ManageNotificationSubscription creates a tool to manage a notification subscription (ignore, watch, delete) -func ManageNotificationSubscription(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ManageNotificationSubscription(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataNotifications, mcp.Tool{ @@ -506,7 +506,7 @@ const ( ) // ManageRepositoryNotificationSubscription creates a tool to manage a repository notification subscription (ignore, watch, delete) -func ManageRepositoryNotificationSubscription(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ManageRepositoryNotificationSubscription(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataNotifications, mcp.Tool{ diff --git a/pkg/github/notifications_test.go b/pkg/github/notifications_test.go index 0a330c316..f730654db 100644 --- a/pkg/github/notifications_test.go +++ b/pkg/github/notifications_test.go @@ -125,8 +125,8 @@ func Test_ListNotifications(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -258,8 +258,8 @@ func Test_ManageNotificationSubscription(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -421,8 +421,8 @@ func Test_ManageRepositoryNotificationSubscription(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -563,8 +563,8 @@ func Test_DismissNotification(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -688,8 +688,8 @@ func Test_MarkAllNotificationsRead(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -772,8 +772,8 @@ func Test_GetNotificationDetails(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/projects.go b/pkg/github/projects.go index a12aca7be..4c0b5c09d 100644 --- a/pkg/github/projects.go +++ b/pkg/github/projects.go @@ -9,7 +9,7 @@ import ( "strings" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -25,7 +25,7 @@ const ( MaxProjectsPerPage = 50 ) -func ListProjects(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListProjects(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataProjects, mcp.Tool{ @@ -144,7 +144,7 @@ func ListProjects(t translations.TranslationHelperFunc) toolsets.ServerTool { ) } -func GetProject(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetProject(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataProjects, mcp.Tool{ @@ -234,7 +234,7 @@ func GetProject(t translations.TranslationHelperFunc) toolsets.ServerTool { ) } -func ListProjectFields(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListProjectFields(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataProjects, mcp.Tool{ @@ -342,7 +342,7 @@ func ListProjectFields(t translations.TranslationHelperFunc) toolsets.ServerTool ) } -func GetProjectField(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetProjectField(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataProjects, mcp.Tool{ @@ -436,7 +436,7 @@ func GetProjectField(t translations.TranslationHelperFunc) toolsets.ServerTool { ) } -func ListProjectItems(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListProjectItems(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataProjects, mcp.Tool{ @@ -574,7 +574,7 @@ func ListProjectItems(t translations.TranslationHelperFunc) toolsets.ServerTool ) } -func GetProjectItem(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetProjectItem(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataProjects, mcp.Tool{ @@ -682,7 +682,7 @@ func GetProjectItem(t translations.TranslationHelperFunc) toolsets.ServerTool { ) } -func AddProjectItem(t translations.TranslationHelperFunc) toolsets.ServerTool { +func AddProjectItem(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataProjects, mcp.Tool{ @@ -795,7 +795,7 @@ func AddProjectItem(t translations.TranslationHelperFunc) toolsets.ServerTool { ) } -func UpdateProjectItem(t translations.TranslationHelperFunc) toolsets.ServerTool { +func UpdateProjectItem(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataProjects, mcp.Tool{ @@ -909,7 +909,7 @@ func UpdateProjectItem(t translations.TranslationHelperFunc) toolsets.ServerTool ) } -func DeleteProjectItem(t translations.TranslationHelperFunc) toolsets.ServerTool { +func DeleteProjectItem(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataProjects, mcp.Tool{ diff --git a/pkg/github/projects_test.go b/pkg/github/projects_test.go index 0c2e2ab52..67ecd8800 100644 --- a/pkg/github/projects_test.go +++ b/pkg/github/projects_test.go @@ -141,8 +141,8 @@ func Test_ListProjects(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -280,8 +280,8 @@ func Test_GetProject(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -432,8 +432,8 @@ func Test_ListProjectFields(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -592,8 +592,8 @@ func Test_GetProjectField(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -798,8 +798,8 @@ func Test_ListProjectItems(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -995,8 +995,8 @@ func Test_GetProjectItem(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -1224,8 +1224,8 @@ func Test_AddProjectItem(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -1509,8 +1509,8 @@ func Test_UpdateProjectItem(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -1676,8 +1676,8 @@ func Test_DeleteProjectItem(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/prompts.go b/pkg/github/prompts.go index 82d7bf514..229902d90 100644 --- a/pkg/github/prompts.go +++ b/pkg/github/prompts.go @@ -1,14 +1,14 @@ package github import ( - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" ) // AllPrompts returns all prompts with their embedded toolset metadata. // Prompt functions return ServerPrompt directly with toolset info. -func AllPrompts(t translations.TranslationHelperFunc) []toolsets.ServerPrompt { - return []toolsets.ServerPrompt{ +func AllPrompts(t translations.TranslationHelperFunc) []registry.ServerPrompt { + return []registry.ServerPrompt{ // Issue prompts AssignCodingAgentPrompt(t), IssueToFixWorkflowPrompt(t), diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 229e20e57..4e7ede755 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -15,14 +15,14 @@ import ( ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/lockdown" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/sanitize" - "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" ) // PullRequestRead creates a tool to get details of a specific pull request. -func PullRequestRead(t translations.TranslationHelperFunc) toolsets.ServerTool { +func PullRequestRead(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -99,7 +99,7 @@ Possible options: switch method { case "get": - result, err := GetPullRequest(ctx, client, deps.RepoAccessCache, owner, repo, pullNumber, deps.Flags) + result, err := GetPullRequest(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, deps.GetFlags()) return result, nil, err case "get_diff": result, err := GetPullRequestDiff(ctx, client, owner, repo, pullNumber) @@ -111,13 +111,13 @@ Possible options: result, err := GetPullRequestFiles(ctx, client, owner, repo, pullNumber, pagination) return result, nil, err case "get_review_comments": - result, err := GetPullRequestReviewComments(ctx, client, deps.RepoAccessCache, owner, repo, pullNumber, pagination, deps.Flags) + result, err := GetPullRequestReviewComments(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, pagination, deps.GetFlags()) return result, nil, err case "get_reviews": - result, err := GetPullRequestReviews(ctx, client, deps.RepoAccessCache, owner, repo, pullNumber, deps.Flags) + result, err := GetPullRequestReviews(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, deps.GetFlags()) return result, nil, err case "get_comments": - result, err := GetIssueComments(ctx, client, deps.RepoAccessCache, owner, repo, pullNumber, pagination, deps.Flags) + result, err := GetIssueComments(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, pagination, deps.GetFlags()) return result, nil, err default: return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil @@ -390,7 +390,7 @@ func GetPullRequestReviews(ctx context.Context, client *github.Client, cache *lo } // CreatePullRequest creates a tool to create a new pull request. -func CreatePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool { +func CreatePullRequest(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -531,7 +531,7 @@ func CreatePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool } // UpdatePullRequest creates a tool to update an existing pull request. -func UpdatePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool { +func UpdatePullRequest(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -826,7 +826,7 @@ func UpdatePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool } // ListPullRequests creates a tool to list and filter repository pull requests. -func ListPullRequests(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListPullRequests(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -970,7 +970,7 @@ func ListPullRequests(t translations.TranslationHelperFunc) toolsets.ServerTool } // MergePullRequest creates a tool to merge a pull request. -func MergePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool { +func MergePullRequest(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1079,7 +1079,7 @@ func MergePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool } // SearchPullRequests creates a tool to search for pull requests. -func SearchPullRequests(t translations.TranslationHelperFunc) toolsets.ServerTool { +func SearchPullRequests(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1142,7 +1142,7 @@ func SearchPullRequests(t translations.TranslationHelperFunc) toolsets.ServerToo } // UpdatePullRequestBranch creates a tool to update a pull request branch with the latest changes from the base branch. -func UpdatePullRequestBranch(t translations.TranslationHelperFunc) toolsets.ServerTool { +func UpdatePullRequestBranch(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1247,7 +1247,7 @@ type PullRequestReviewWriteParams struct { CommitID *string } -func PullRequestReviewWrite(t translations.TranslationHelperFunc) toolsets.ServerTool { +func PullRequestReviewWrite(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1559,7 +1559,7 @@ func DeletePendingPullRequestReview(ctx context.Context, client *githubv4.Client } // AddCommentToPendingReview creates a tool to add a comment to a pull request review. -func AddCommentToPendingReview(t translations.TranslationHelperFunc) toolsets.ServerTool { +func AddCommentToPendingReview(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1752,7 +1752,7 @@ func AddCommentToPendingReview(t translations.TranslationHelperFunc) toolsets.Se // RequestCopilotReview creates a tool to request a Copilot review for a pull request. // Note that this tool will not work on GHES where this feature is unsupported. In future, we should not expose this // tool if the configured host does not support it. -func RequestCopilotReview(t translations.TranslationHelperFunc) toolsets.ServerTool { +func RequestCopilotReview(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 7531edf6d..a22245700 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -105,9 +105,9 @@ func Test_GetPullRequest(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(githubv4mock.NewMockedHTTPClient()) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, RepoAccessCache: stubRepoAccessCache(gqlClient, 5*time.Minute), Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), } @@ -370,9 +370,9 @@ func Test_UpdatePullRequest(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(nil) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, } handler := serverTool.Handler(deps) @@ -561,9 +561,9 @@ func Test_UpdatePullRequest_Draft(t *testing.T) { gqlClient := githubv4.NewClient(tc.mockedClient) serverTool := UpdatePullRequest(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(restClient), - GetGQLClient: stubGetGQLClientFn(gqlClient), + deps := BaseDeps{ + Client: restClient, + GQLClient: gqlClient, } handler := serverTool.Handler(deps) @@ -693,8 +693,8 @@ func Test_ListPullRequests(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := ListPullRequests(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -817,8 +817,8 @@ func Test_MergePullRequest(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := MergePullRequest(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -1123,8 +1123,8 @@ func Test_SearchPullRequests(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := SearchPullRequests(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -1276,8 +1276,8 @@ func Test_GetPullRequestFiles(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := PullRequestRead(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, RepoAccessCache: stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), } @@ -1451,8 +1451,8 @@ func Test_GetPullRequestStatus(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := PullRequestRead(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, RepoAccessCache: stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), } @@ -1589,8 +1589,8 @@ func Test_UpdatePullRequestBranch(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := UpdatePullRequestBranch(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -1763,8 +1763,8 @@ func Test_GetPullRequestComments(t *testing.T) { cache := stubRepoAccessCache(gqlClient, 5*time.Minute) flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) serverTool := PullRequestRead(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, RepoAccessCache: cache, Flags: flags, } @@ -1951,8 +1951,8 @@ func Test_GetPullRequestReviews(t *testing.T) { cache := stubRepoAccessCache(gqlClient, 5*time.Minute) flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) serverTool := PullRequestRead(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, RepoAccessCache: cache, Flags: flags, } @@ -2115,8 +2115,8 @@ func Test_CreatePullRequest(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := CreatePullRequest(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -2331,8 +2331,8 @@ func TestCreateAndSubmitPullRequestReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) @@ -2447,8 +2447,8 @@ func Test_RequestCopilotReview(t *testing.T) { client := github.NewClient(tc.mockedClient) serverTool := RequestCopilotReview(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -2641,8 +2641,8 @@ func TestCreatePendingPullRequestReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) @@ -2824,8 +2824,8 @@ func TestAddPullRequestReviewCommentToPendingReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) serverTool := AddCommentToPendingReview(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) @@ -2929,8 +2929,8 @@ func TestSubmitPendingPullRequestReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) @@ -3028,8 +3028,8 @@ func TestDeletePendingPullRequestReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetGQLClient: stubGetGQLClientFn(client), + deps := BaseDeps{ + GQLClient: client, } handler := serverTool.Handler(deps) @@ -3119,8 +3119,8 @@ index 5d6e7b2..8a4f5c3 100644 // Setup client with mock client := github.NewClient(tc.mockedClient) serverTool := PullRequestRead(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, RepoAccessCache: stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), } diff --git a/pkg/github/toolset_group.go b/pkg/github/registry.go similarity index 78% rename from pkg/github/toolset_group.go rename to pkg/github/registry.go index 7330e08d3..88795ee1e 100644 --- a/pkg/github/toolset_group.go +++ b/pkg/github/registry.go @@ -1,7 +1,7 @@ package github import ( - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" ) @@ -10,8 +10,8 @@ import ( // This function is stateless - no dependencies are captured. // Handlers are generated on-demand during registration via RegisterAll(ctx, server, deps). // The "default" keyword in WithToolsets will expand to toolsets marked with Default: true. -func NewRegistry(t translations.TranslationHelperFunc) *toolsets.Registry { - return toolsets.NewRegistry(). +func NewRegistry(t translations.TranslationHelperFunc) *registry.Builder { + return registry.NewBuilder(). SetTools(AllTools(t)). SetResources(AllResources(t)). SetPrompts(AllPrompts(t)) diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index 81e5c3a8c..854da679a 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -11,7 +11,7 @@ import ( ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/raw" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -19,7 +19,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func GetCommit(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetCommit(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -118,7 +118,7 @@ func GetCommit(t translations.TranslationHelperFunc) toolsets.ServerTool { } // ListCommits creates a tool to get commits of a branch in a repository. -func ListCommits(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListCommits(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -227,7 +227,7 @@ func ListCommits(t translations.TranslationHelperFunc) toolsets.ServerTool { } // ListBranches creates a tool to list branches in a GitHub repository. -func ListBranches(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListBranches(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -315,7 +315,7 @@ func ListBranches(t translations.TranslationHelperFunc) toolsets.ServerTool { } // CreateOrUpdateFile creates a tool to create or update a file in a GitHub repository. -func CreateOrUpdateFile(t translations.TranslationHelperFunc) toolsets.ServerTool { +func CreateOrUpdateFile(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -443,7 +443,7 @@ func CreateOrUpdateFile(t translations.TranslationHelperFunc) toolsets.ServerToo } // CreateRepository creates a tool to create a new GitHub repository. -func CreateRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { +func CreateRepository(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -550,7 +550,7 @@ func CreateRepository(t translations.TranslationHelperFunc) toolsets.ServerTool } // GetFileContents creates a tool to get the contents of a file or directory from a GitHub repository. -func GetFileContents(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetFileContents(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -770,7 +770,7 @@ func GetFileContents(t translations.TranslationHelperFunc) toolsets.ServerTool { } // ForkRepository creates a tool to fork a repository. -func ForkRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ForkRepository(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -869,7 +869,7 @@ func ForkRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { // unlike how the endpoint backing the create_or_update_files tool does. This appears to be a quirk of the API. // The approach implemented here gets automatic commit signing when used with either the github-actions user or as an app, // both of which suit an LLM well. -func DeleteFile(t translations.TranslationHelperFunc) toolsets.ServerTool { +func DeleteFile(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -1055,7 +1055,7 @@ func DeleteFile(t translations.TranslationHelperFunc) toolsets.ServerTool { } // CreateBranch creates a tool to create a new branch. -func CreateBranch(t translations.TranslationHelperFunc) toolsets.ServerTool { +func CreateBranch(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -1169,7 +1169,7 @@ func CreateBranch(t translations.TranslationHelperFunc) toolsets.ServerTool { } // PushFiles creates a tool to push multiple files in a single commit to a GitHub repository. -func PushFiles(t translations.TranslationHelperFunc) toolsets.ServerTool { +func PushFiles(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -1354,7 +1354,7 @@ func PushFiles(t translations.TranslationHelperFunc) toolsets.ServerTool { } // ListTags creates a tool to list tags in a GitHub repository. -func ListTags(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListTags(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -1434,7 +1434,7 @@ func ListTags(t translations.TranslationHelperFunc) toolsets.ServerTool { } // GetTag creates a tool to get details about a specific tag in a GitHub repository. -func GetTag(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetTag(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -1533,7 +1533,7 @@ func GetTag(t translations.TranslationHelperFunc) toolsets.ServerTool { } // ListReleases creates a tool to list releases in a GitHub repository. -func ListReleases(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListReleases(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -1609,7 +1609,7 @@ func ListReleases(t translations.TranslationHelperFunc) toolsets.ServerTool { } // GetLatestRelease creates a tool to get the latest release in a GitHub repository. -func GetLatestRelease(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetLatestRelease(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -1675,7 +1675,7 @@ func GetLatestRelease(t translations.TranslationHelperFunc) toolsets.ServerTool ) } -func GetReleaseByTag(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetReleaseByTag(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataRepos, mcp.Tool{ @@ -1889,7 +1889,7 @@ func resolveGitReference(ctx context.Context, githubClient *github.Client, owner } // ListStarredRepositories creates a tool to list starred repositories for the authenticated user or a specified user. -func ListStarredRepositories(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListStarredRepositories(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataStargazers, mcp.Tool{ @@ -2022,7 +2022,7 @@ func ListStarredRepositories(t translations.TranslationHelperFunc) toolsets.Serv } // StarRepository creates a tool to star a repository. -func StarRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { +func StarRepository(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataStargazers, mcp.Tool{ @@ -2088,7 +2088,7 @@ func StarRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { } // UnstarRepository creates a tool to unstar a repository. -func UnstarRepository(t translations.TranslationHelperFunc) toolsets.ServerTool { +func UnstarRepository(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataStargazers, mcp.Tool{ diff --git a/pkg/github/repositories_test.go b/pkg/github/repositories_test.go index 949686d92..55f0866cb 100644 --- a/pkg/github/repositories_test.go +++ b/pkg/github/repositories_test.go @@ -286,9 +286,9 @@ func Test_GetFileContents(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) mockRawClient := raw.NewClient(client, &url.URL{Scheme: "https", Host: "raw.example.com", Path: "/"}) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), - GetRawClient: stubGetRawClientFn(mockRawClient), + deps := BaseDeps{ + Client: client, + RawClient: mockRawClient, } handler := serverTool.Handler(deps) @@ -410,8 +410,8 @@ func Test_ForkRepository(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -606,8 +606,8 @@ func Test_CreateBranch(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -738,8 +738,8 @@ func Test_GetCommit(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -964,8 +964,8 @@ func Test_ListCommits(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -1143,8 +1143,8 @@ func Test_CreateOrUpdateFile(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -1329,8 +1329,8 @@ func Test_CreateRepository(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -1668,8 +1668,8 @@ func Test_PushFiles(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -1789,8 +1789,8 @@ func Test_ListBranches(t *testing.T) { t.Run(tt.name, func(t *testing.T) { // Create mock client mockClient := github.NewClient(mock.NewMockedHTTPClient(tt.mockResponses...)) - deps := ToolDependencies{ - GetClient: stubGetClientFn(mockClient), + deps := BaseDeps{ + Client: mockClient, } handler := serverTool.Handler(deps) @@ -1977,8 +1977,8 @@ func Test_DeleteFile(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -2104,8 +2104,8 @@ func Test_ListTags(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -2264,8 +2264,8 @@ func Test_GetTag(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -2377,8 +2377,8 @@ func Test_ListReleases(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -2468,8 +2468,8 @@ func Test_GetLatestRelease(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -2616,8 +2616,8 @@ func Test_GetReleaseByTag(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -3128,8 +3128,8 @@ func Test_ListStarredRepositories(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -3229,8 +3229,8 @@ func Test_StarRepository(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -3320,8 +3320,8 @@ func Test_UnstarRepository(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) diff --git a/pkg/github/repository_resource.go b/pkg/github/repository_resource.go index 6dbbe90ec..af001af6f 100644 --- a/pkg/github/repository_resource.go +++ b/pkg/github/repository_resource.go @@ -14,7 +14,7 @@ import ( "strings" "github.com/github/github-mcp-server/pkg/raw" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -30,8 +30,8 @@ var ( ) // GetRepositoryResourceContent defines the resource template for getting repository content. -func GetRepositoryResourceContent(t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { - return toolsets.NewServerResourceTemplate( +func GetRepositoryResourceContent(t translations.TranslationHelperFunc) registry.ServerResourceTemplate { + return registry.NewServerResourceTemplate( ToolsetMetadataRepos, mcp.ResourceTemplate{ Name: "repository_content", @@ -43,8 +43,8 @@ func GetRepositoryResourceContent(t translations.TranslationHelperFunc) toolsets } // GetRepositoryResourceBranchContent defines the resource template for getting repository content for a branch. -func GetRepositoryResourceBranchContent(t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { - return toolsets.NewServerResourceTemplate( +func GetRepositoryResourceBranchContent(t translations.TranslationHelperFunc) registry.ServerResourceTemplate { + return registry.NewServerResourceTemplate( ToolsetMetadataRepos, mcp.ResourceTemplate{ Name: "repository_content_branch", @@ -56,8 +56,8 @@ func GetRepositoryResourceBranchContent(t translations.TranslationHelperFunc) to } // GetRepositoryResourceCommitContent defines the resource template for getting repository content for a commit. -func GetRepositoryResourceCommitContent(t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { - return toolsets.NewServerResourceTemplate( +func GetRepositoryResourceCommitContent(t translations.TranslationHelperFunc) registry.ServerResourceTemplate { + return registry.NewServerResourceTemplate( ToolsetMetadataRepos, mcp.ResourceTemplate{ Name: "repository_content_commit", @@ -69,8 +69,8 @@ func GetRepositoryResourceCommitContent(t translations.TranslationHelperFunc) to } // GetRepositoryResourceTagContent defines the resource template for getting repository content for a tag. -func GetRepositoryResourceTagContent(t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { - return toolsets.NewServerResourceTemplate( +func GetRepositoryResourceTagContent(t translations.TranslationHelperFunc) registry.ServerResourceTemplate { + return registry.NewServerResourceTemplate( ToolsetMetadataRepos, mcp.ResourceTemplate{ Name: "repository_content_tag", @@ -82,8 +82,8 @@ func GetRepositoryResourceTagContent(t translations.TranslationHelperFunc) tools } // GetRepositoryResourcePrContent defines the resource template for getting repository content for a pull request. -func GetRepositoryResourcePrContent(t translations.TranslationHelperFunc) toolsets.ServerResourceTemplate { - return toolsets.NewServerResourceTemplate( +func GetRepositoryResourcePrContent(t translations.TranslationHelperFunc) registry.ServerResourceTemplate { + return registry.NewServerResourceTemplate( ToolsetMetadataRepos, mcp.ResourceTemplate{ Name: "repository_content_pr", @@ -95,15 +95,15 @@ func GetRepositoryResourcePrContent(t translations.TranslationHelperFunc) toolse } // repositoryResourceContentsHandlerFunc returns a ResourceHandlerFunc that creates handlers on-demand. -func repositoryResourceContentsHandlerFunc(resourceURITemplate *uritemplate.Template) toolsets.ResourceHandlerFunc { +func repositoryResourceContentsHandlerFunc(resourceURITemplate *uritemplate.Template) registry.ResourceHandlerFunc { return func(deps any) mcp.ResourceHandler { d := deps.(ToolDependencies) - return RepositoryResourceContentsHandler(d.GetClient, d.GetRawClient, resourceURITemplate) + return RepositoryResourceContentsHandler(d, resourceURITemplate) } } // RepositoryResourceContentsHandler returns a handler function for repository content requests. -func RepositoryResourceContentsHandler(getClient GetClientFn, getRawClient raw.GetRawClientFn, resourceURITemplate *uritemplate.Template) mcp.ResourceHandler { +func RepositoryResourceContentsHandler(deps ToolDependencies, resourceURITemplate *uritemplate.Template) mcp.ResourceHandler { return func(ctx context.Context, request *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { // Match the URI to extract parameters uriValues := resourceURITemplate.Match(request.Params.URI) @@ -157,7 +157,7 @@ func RepositoryResourceContentsHandler(getClient GetClientFn, getRawClient raw.G prNumber := uriValues.Get("prNumber").String() if prNumber != "" { // fetch the PR from the API to get the latest commit and use SHA - githubClient, err := getClient(ctx) + githubClient, err := deps.GetClient(ctx) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -177,7 +177,7 @@ func RepositoryResourceContentsHandler(getClient GetClientFn, getRawClient raw.G if path == "" || strings.HasSuffix(path, "/") { return nil, fmt.Errorf("directories are not supported: %s", path) } - rawClient, err := getRawClient(ctx) + rawClient, err := deps.GetRawClient(ctx) if err != nil { return nil, fmt.Errorf("failed to get GitHub raw content client: %w", err) diff --git a/pkg/github/repository_resource_test.go b/pkg/github/repository_resource_test.go index f938a57f5..99c06cdd6 100644 --- a/pkg/github/repository_resource_test.go +++ b/pkg/github/repository_resource_test.go @@ -27,7 +27,7 @@ func Test_repositoryResourceContents(t *testing.T) { name string mockedClient *http.Client uri string - handlerFn func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler + handlerFn func(deps ToolDependencies) mcp.ResourceHandler expectedResponseType resourceResponseType expectError string expectedResult *mcp.ReadResourceResult @@ -45,8 +45,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo:///repo/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { - return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceContentURITemplate) + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeText, // Ignored as error is expected expectError: "owner is required", @@ -64,8 +64,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner//refs/heads/main/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { - return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceBranchContentURITemplate) + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceBranchContentURITemplate) }, expectedResponseType: resourceResponseTypeText, // Ignored as error is expected expectError: "repo is required", @@ -83,8 +83,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/contents/data.png", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { - return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceContentURITemplate) + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeBlob, expectedResult: &mcp.ReadResourceResult{ @@ -107,8 +107,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { - return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceContentURITemplate) + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -133,8 +133,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/contents/pkg/github/actions.go", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { - return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceContentURITemplate) + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -157,8 +157,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/refs/heads/main/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { - return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceBranchContentURITemplate) + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceBranchContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -181,8 +181,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/refs/tags/v1.0.0/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { - return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceTagContentURITemplate) + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceTagContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -205,8 +205,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/sha/abc123/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { - return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceCommitContentURITemplate) + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceCommitContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -237,8 +237,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/refs/pull/42/head/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { - return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourcePrContentURITemplate) + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourcePrContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -260,8 +260,8 @@ func Test_repositoryResourceContents(t *testing.T) { ), ), uri: "repo://owner/repo/contents/nonexistent.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn) mcp.ResourceHandler { - return RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceContentURITemplate) + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeText, // Ignored as error is expected expectError: "404 Not Found", @@ -272,7 +272,11 @@ func Test_repositoryResourceContents(t *testing.T) { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) mockRawClient := raw.NewClient(client, base) - handler := tc.handlerFn(stubGetClientFn(client), stubGetRawClientFn(mockRawClient)) + deps := BaseDeps{ + Client: client, + RawClient: mockRawClient, + } + handler := tc.handlerFn(deps) request := &mcp.ReadResourceRequest{ Params: &mcp.ReadResourceParams{ diff --git a/pkg/github/resources.go b/pkg/github/resources.go index 253c4bc11..6acf5eb6a 100644 --- a/pkg/github/resources.go +++ b/pkg/github/resources.go @@ -1,14 +1,14 @@ package github import ( - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" ) // AllResources returns all resource templates with their embedded toolset metadata. // Resource definitions are stateless - handlers are generated on-demand during registration. -func AllResources(t translations.TranslationHelperFunc) []toolsets.ServerResourceTemplate { - return []toolsets.ServerResourceTemplate{ +func AllResources(t translations.TranslationHelperFunc) []registry.ServerResourceTemplate { + return []registry.ServerResourceTemplate{ // Repository resources GetRepositoryResourceContent(t), GetRepositoryResourceBranchContent(t), diff --git a/pkg/github/search.go b/pkg/github/search.go index 730435eba..4b35f3f0d 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -8,7 +8,7 @@ import ( "net/http" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -17,7 +17,7 @@ import ( ) // SearchRepositories creates a tool to search for GitHub repositories. -func SearchRepositories(t translations.TranslationHelperFunc) toolsets.ServerTool { +func SearchRepositories(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -167,7 +167,7 @@ func SearchRepositories(t translations.TranslationHelperFunc) toolsets.ServerToo } // SearchCode creates a tool to search for code across GitHub repositories. -func SearchCode(t translations.TranslationHelperFunc) toolsets.ServerTool { +func SearchCode(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -351,7 +351,7 @@ func userOrOrgHandler(accountType string, deps ToolDependencies) mcp.ToolHandler } // SearchUsers creates a tool to search for GitHub users. -func SearchUsers(t translations.TranslationHelperFunc) toolsets.ServerTool { +func SearchUsers(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -392,7 +392,7 @@ func SearchUsers(t translations.TranslationHelperFunc) toolsets.ServerTool { } // SearchOrgs creates a tool to search for GitHub organizations. -func SearchOrgs(t translations.TranslationHelperFunc) toolsets.ServerTool { +func SearchOrgs(t translations.TranslationHelperFunc) registry.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ diff --git a/pkg/github/search_test.go b/pkg/github/search_test.go index 41d12df1b..707b55349 100644 --- a/pkg/github/search_test.go +++ b/pkg/github/search_test.go @@ -134,8 +134,8 @@ func Test_SearchRepositories(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -209,8 +209,8 @@ func Test_SearchRepositories_FullOutput(t *testing.T) { client := github.NewClient(mockedClient) serverTool := SearchRepositories(translations.NullTranslationHelper) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -358,8 +358,8 @@ func Test_SearchCode(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -558,8 +558,8 @@ func Test_SearchUsers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) @@ -733,8 +733,8 @@ func Test_SearchOrgs(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := serverTool.Handler(deps) diff --git a/pkg/github/secret_scanning.go b/pkg/github/secret_scanning.go index 7e842ded1..e840072b0 100644 --- a/pkg/github/secret_scanning.go +++ b/pkg/github/secret_scanning.go @@ -8,7 +8,7 @@ import ( "net/http" ghErrors "github.com/github/github-mcp-server/pkg/errors" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -16,7 +16,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func GetSecretScanningAlert(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetSecretScanningAlert(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataSecretProtection, mcp.Tool{ @@ -94,7 +94,7 @@ func GetSecretScanningAlert(t translations.TranslationHelperFunc) toolsets.Serve ) } -func ListSecretScanningAlerts(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListSecretScanningAlerts(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataSecretProtection, mcp.Tool{ diff --git a/pkg/github/secret_scanning_test.go b/pkg/github/secret_scanning_test.go index 83de16409..b63617a46 100644 --- a/pkg/github/secret_scanning_test.go +++ b/pkg/github/secret_scanning_test.go @@ -87,8 +87,8 @@ func Test_GetSecretScanningAlert(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) @@ -228,8 +228,8 @@ func Test_ListSecretScanningAlerts(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{ - GetClient: stubGetClientFn(client), + deps := BaseDeps{ + Client: client, } handler := toolDef.Handler(deps) diff --git a/pkg/github/security_advisories.go b/pkg/github/security_advisories.go index cf507d17a..28acb8156 100644 --- a/pkg/github/security_advisories.go +++ b/pkg/github/security_advisories.go @@ -7,7 +7,7 @@ import ( "io" "net/http" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -15,7 +15,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func ListGlobalSecurityAdvisories(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListGlobalSecurityAdvisories(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataSecurityAdvisories, mcp.Tool{ @@ -207,7 +207,7 @@ func ListGlobalSecurityAdvisories(t translations.TranslationHelperFunc) toolsets ) } -func ListRepositorySecurityAdvisories(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListRepositorySecurityAdvisories(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataSecurityAdvisories, mcp.Tool{ @@ -312,7 +312,7 @@ func ListRepositorySecurityAdvisories(t translations.TranslationHelperFunc) tool ) } -func GetGlobalSecurityAdvisory(t translations.TranslationHelperFunc) toolsets.ServerTool { +func GetGlobalSecurityAdvisory(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataSecurityAdvisories, mcp.Tool{ @@ -370,7 +370,7 @@ func GetGlobalSecurityAdvisory(t translations.TranslationHelperFunc) toolsets.Se ) } -func ListOrgRepositorySecurityAdvisories(t translations.TranslationHelperFunc) toolsets.ServerTool { +func ListOrgRepositorySecurityAdvisories(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( ToolsetMetadataSecurityAdvisories, mcp.Tool{ diff --git a/pkg/github/security_advisories_test.go b/pkg/github/security_advisories_test.go index 16506a3e8..3970949ec 100644 --- a/pkg/github/security_advisories_test.go +++ b/pkg/github/security_advisories_test.go @@ -103,7 +103,7 @@ func Test_ListGlobalSecurityAdvisories(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{GetClient: stubGetClientFn(client)} + deps := BaseDeps{Client: client} handler := toolDef.Handler(deps) // Create call request @@ -224,7 +224,7 @@ func Test_GetGlobalSecurityAdvisory(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{GetClient: stubGetClientFn(client)} + deps := BaseDeps{Client: client} handler := toolDef.Handler(deps) // Create call request @@ -372,7 +372,7 @@ func Test_ListRepositorySecurityAdvisories(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{GetClient: stubGetClientFn(client)} + deps := BaseDeps{Client: client} handler := toolDef.Handler(deps) request := createMCPRequest(tc.requestArgs) @@ -517,7 +517,7 @@ func Test_ListOrgRepositorySecurityAdvisories(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - deps := ToolDependencies{GetClient: stubGetClientFn(client)} + deps := BaseDeps{Client: client} handler := toolDef.Handler(deps) request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index 2e9ab43a3..a59cd9a93 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -11,32 +11,67 @@ import ( "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/raw" + "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" "github.com/shurcooL/githubv4" "github.com/stretchr/testify/assert" ) -func stubGetClientFn(client *github.Client) GetClientFn { - return func(_ context.Context) (*github.Client, error) { - return client, nil +// stubDeps is a test helper that implements ToolDependencies with configurable behavior. +// Use this when you need to test error paths or when you need closure-based client creation. +type stubDeps struct { + clientFn func(context.Context) (*github.Client, error) + gqlClientFn func(context.Context) (*githubv4.Client, error) + rawClientFn func(context.Context) (*raw.Client, error) + + repoAccessCache *lockdown.RepoAccessCache + t translations.TranslationHelperFunc + flags FeatureFlags + contentWindowSize int +} + +func (s stubDeps) GetClient(ctx context.Context) (*github.Client, error) { + if s.clientFn != nil { + return s.clientFn(ctx) } + return nil, nil } -func stubGetClientFromHTTPFn(client *http.Client) GetClientFn { +func (s stubDeps) GetGQLClient(ctx context.Context) (*githubv4.Client, error) { + if s.gqlClientFn != nil { + return s.gqlClientFn(ctx) + } + return nil, nil +} + +func (s stubDeps) GetRawClient(ctx context.Context) (*raw.Client, error) { + if s.rawClientFn != nil { + return s.rawClientFn(ctx) + } + return nil, nil +} + +func (s stubDeps) GetRepoAccessCache() *lockdown.RepoAccessCache { return s.repoAccessCache } +func (s stubDeps) GetT() translations.TranslationHelperFunc { return s.t } +func (s stubDeps) GetFlags() FeatureFlags { return s.flags } +func (s stubDeps) GetContentWindowSize() int { return s.contentWindowSize } + +// Helper functions to create stub client functions for error testing +func stubClientFnFromHTTP(httpClient *http.Client) func(context.Context) (*github.Client, error) { return func(_ context.Context) (*github.Client, error) { - return github.NewClient(client), nil + return github.NewClient(httpClient), nil } } -func stubGetClientFnErr(err string) GetClientFn { +func stubClientFnErr(errMsg string) func(context.Context) (*github.Client, error) { return func(_ context.Context) (*github.Client, error) { - return nil, errors.New(err) + return nil, errors.New(errMsg) } } -func stubGetGQLClientFn(client *githubv4.Client) GetGQLClientFn { +func stubGQLClientFnErr(errMsg string) func(context.Context) (*githubv4.Client, error) { return func(_ context.Context) (*githubv4.Client, error) { - return client, nil + return nil, errors.New(errMsg) } } @@ -51,12 +86,6 @@ func stubFeatureFlags(enabledFlags map[string]bool) FeatureFlags { } } -func stubGetRawClientFn(client *raw.Client) raw.GetRawClientFn { - return func(_ context.Context) (*raw.Client, error) { - return client, nil - } -} - func badRequestHandler(msg string) http.HandlerFunc { return func(w http.ResponseWriter, _ *http.Request) { structuredErrorResponse := github.ErrorResponse{ diff --git a/pkg/github/tools.go b/pkg/github/tools.go index dd2ad4ff4..714397b5d 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -5,7 +5,7 @@ import ( "fmt" "strings" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" "github.com/shurcooL/githubv4" @@ -17,96 +17,96 @@ type GetGQLClientFn func(context.Context) (*githubv4.Client, error) // Toolset metadata constants - these define all available toolsets and their descriptions. // Tools use these constants to declare which toolset they belong to. var ( - ToolsetMetadataAll = toolsets.ToolsetMetadata{ + ToolsetMetadataAll = registry.ToolsetMetadata{ ID: "all", Description: "Special toolset that enables all available toolsets", } - ToolsetMetadataDefault = toolsets.ToolsetMetadata{ + ToolsetMetadataDefault = registry.ToolsetMetadata{ ID: "default", Description: "Special toolset that enables the default toolset configuration. When no toolsets are specified, this is the set that is enabled", } - ToolsetMetadataContext = toolsets.ToolsetMetadata{ + ToolsetMetadataContext = registry.ToolsetMetadata{ ID: "context", Description: "Tools that provide context about the current user and GitHub context you are operating in", Default: true, } - ToolsetMetadataRepos = toolsets.ToolsetMetadata{ + ToolsetMetadataRepos = registry.ToolsetMetadata{ ID: "repos", Description: "GitHub Repository related tools", Default: true, } - ToolsetMetadataGit = toolsets.ToolsetMetadata{ + ToolsetMetadataGit = registry.ToolsetMetadata{ ID: "git", Description: "GitHub Git API related tools for low-level Git operations", } - ToolsetMetadataIssues = toolsets.ToolsetMetadata{ + ToolsetMetadataIssues = registry.ToolsetMetadata{ ID: "issues", Description: "GitHub Issues related tools", Default: true, } - ToolsetMetadataPullRequests = toolsets.ToolsetMetadata{ + ToolsetMetadataPullRequests = registry.ToolsetMetadata{ ID: "pull_requests", Description: "GitHub Pull Request related tools", Default: true, } - ToolsetMetadataUsers = toolsets.ToolsetMetadata{ + ToolsetMetadataUsers = registry.ToolsetMetadata{ ID: "users", Description: "GitHub User related tools", Default: true, } - ToolsetMetadataOrgs = toolsets.ToolsetMetadata{ + ToolsetMetadataOrgs = registry.ToolsetMetadata{ ID: "orgs", Description: "GitHub Organization related tools", } - ToolsetMetadataActions = toolsets.ToolsetMetadata{ + ToolsetMetadataActions = registry.ToolsetMetadata{ ID: "actions", Description: "GitHub Actions workflows and CI/CD operations", } - ToolsetMetadataCodeSecurity = toolsets.ToolsetMetadata{ + ToolsetMetadataCodeSecurity = registry.ToolsetMetadata{ ID: "code_security", Description: "Code security related tools, such as GitHub Code Scanning", } - ToolsetMetadataSecretProtection = toolsets.ToolsetMetadata{ + ToolsetMetadataSecretProtection = registry.ToolsetMetadata{ ID: "secret_protection", Description: "Secret protection related tools, such as GitHub Secret Scanning", } - ToolsetMetadataDependabot = toolsets.ToolsetMetadata{ + ToolsetMetadataDependabot = registry.ToolsetMetadata{ ID: "dependabot", Description: "Dependabot tools", } - ToolsetMetadataNotifications = toolsets.ToolsetMetadata{ + ToolsetMetadataNotifications = registry.ToolsetMetadata{ ID: "notifications", Description: "GitHub Notifications related tools", } - ToolsetMetadataExperiments = toolsets.ToolsetMetadata{ + ToolsetMetadataExperiments = registry.ToolsetMetadata{ ID: "experiments", Description: "Experimental features that are not considered stable yet", } - ToolsetMetadataDiscussions = toolsets.ToolsetMetadata{ + ToolsetMetadataDiscussions = registry.ToolsetMetadata{ ID: "discussions", Description: "GitHub Discussions related tools", } - ToolsetMetadataGists = toolsets.ToolsetMetadata{ + ToolsetMetadataGists = registry.ToolsetMetadata{ ID: "gists", Description: "GitHub Gist related tools", } - ToolsetMetadataSecurityAdvisories = toolsets.ToolsetMetadata{ + ToolsetMetadataSecurityAdvisories = registry.ToolsetMetadata{ ID: "security_advisories", Description: "Security advisories related tools", } - ToolsetMetadataProjects = toolsets.ToolsetMetadata{ + ToolsetMetadataProjects = registry.ToolsetMetadata{ ID: "projects", Description: "GitHub Projects related tools", } - ToolsetMetadataStargazers = toolsets.ToolsetMetadata{ + ToolsetMetadataStargazers = registry.ToolsetMetadata{ ID: "stargazers", Description: "GitHub Stargazers related tools", } - ToolsetMetadataDynamic = toolsets.ToolsetMetadata{ + ToolsetMetadataDynamic = registry.ToolsetMetadata{ ID: "dynamic", Description: "Discover GitHub MCP tools that can help achieve tasks by enabling additional sets of tools, you can control the enablement of any toolset to access its tools when this toolset is enabled.", } - ToolsetLabels = toolsets.ToolsetMetadata{ + ToolsetLabels = registry.ToolsetMetadata{ ID: "labels", Description: "GitHub Labels related tools", } @@ -114,8 +114,8 @@ var ( // AllTools returns all tools with their embedded toolset metadata. // Tool functions return ServerTool directly with toolset info. -func AllTools(t translations.TranslationHelperFunc) []toolsets.ServerTool { - return []toolsets.ServerTool{ +func AllTools(t translations.TranslationHelperFunc) []registry.ServerTool { + return []registry.ServerTool{ // Context tools GetMe(t), GetTeams(t), @@ -263,7 +263,7 @@ func ToStringPtr(s string) *string { // GenerateToolsetsHelp generates the help text for the toolsets flag func GenerateToolsetsHelp() string { // Get toolset group to derive defaults and available toolsets - r := NewRegistry(stubTranslator) + r := NewRegistry(stubTranslator).Build() // Format default tools from metadata defaultIDs := r.DefaultToolsetIDs() @@ -333,7 +333,7 @@ func AddDefaultToolset(result []string) []string { result = RemoveToolset(result, string(ToolsetMetadataDefault.ID)) // Get default toolset IDs from the Registry - r := NewRegistry(stubTranslator) + r := NewRegistry(stubTranslator).Build() for _, id := range r.DefaultToolsetIDs() { if !seen[string(id)] { result = append(result, string(id)) diff --git a/pkg/github/tools_validation_test.go b/pkg/github/tools_validation_test.go index d53243b42..1e74c2518 100644 --- a/pkg/github/tools_validation_test.go +++ b/pkg/github/tools_validation_test.go @@ -3,7 +3,7 @@ package github import ( "testing" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -153,7 +153,7 @@ func TestAllToolsHaveHandlerFunc(t *testing.T) { // TestToolsetMetadataConsistency ensures tools in the same toolset have consistent descriptions func TestToolsetMetadataConsistency(t *testing.T) { tools := AllTools(stubTranslation) - toolsetDescriptions := make(map[toolsets.ToolsetID]string) + toolsetDescriptions := make(map[registry.ToolsetID]string) for _, tool := range tools { id := tool.Toolset.ID diff --git a/pkg/github/workflow_prompts.go b/pkg/github/workflow_prompts.go index cf972020d..603a98087 100644 --- a/pkg/github/workflow_prompts.go +++ b/pkg/github/workflow_prompts.go @@ -4,14 +4,14 @@ import ( "context" "fmt" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/registry" "github.com/github/github-mcp-server/pkg/translations" "github.com/modelcontextprotocol/go-sdk/mcp" ) // IssueToFixWorkflowPrompt provides a guided workflow for creating an issue and then generating a PR to fix it -func IssueToFixWorkflowPrompt(t translations.TranslationHelperFunc) toolsets.ServerPrompt { - return toolsets.NewServerPrompt( +func IssueToFixWorkflowPrompt(t translations.TranslationHelperFunc) registry.ServerPrompt { + return registry.NewServerPrompt( ToolsetMetadataIssues, mcp.Prompt{ Name: "issue_to_fix_workflow", diff --git a/pkg/registry/builder.go b/pkg/registry/builder.go new file mode 100644 index 000000000..a8720a4f4 --- /dev/null +++ b/pkg/registry/builder.go @@ -0,0 +1,241 @@ +package registry + +import ( + "strings" +) + +// Builder builds a Registry with the specified configuration. +// Use NewBuilder to create a builder, chain configuration methods, +// then call Build() to create the final Registry. +// +// Example: +// +// reg := NewBuilder(). +// SetTools(tools). +// SetResources(resources). +// SetPrompts(prompts). +// WithDeprecatedAliases(aliases). +// WithReadOnly(true). +// WithToolsets([]string{"repos", "issues"}). +// WithFeatureChecker(checker). +// Build() +type Builder struct { + tools []ServerTool + resourceTemplates []ServerResourceTemplate + prompts []ServerPrompt + deprecatedAliases map[string]string + + // Configuration options (processed at Build time) + readOnly bool + toolsetIDs []string // raw input, processed at Build() + toolsetIDsIsNil bool // tracks if nil was passed (nil = defaults) + additionalTools []string // raw input, processed at Build() + featureChecker FeatureFlagChecker +} + +// NewBuilder creates a new Builder. +func NewBuilder() *Builder { + return &Builder{ + deprecatedAliases: make(map[string]string), + toolsetIDsIsNil: true, // default to nil (use defaults) + } +} + +// SetTools sets the tools for the registry. Returns self for chaining. +func (b *Builder) SetTools(tools []ServerTool) *Builder { + b.tools = tools + return b +} + +// SetResources sets the resource templates for the registry. Returns self for chaining. +func (b *Builder) SetResources(resources []ServerResourceTemplate) *Builder { + b.resourceTemplates = resources + return b +} + +// SetPrompts sets the prompts for the registry. Returns self for chaining. +func (b *Builder) SetPrompts(prompts []ServerPrompt) *Builder { + b.prompts = prompts + return b +} + +// WithDeprecatedAliases adds deprecated tool name aliases that map to canonical names. +// Returns self for chaining. +func (b *Builder) WithDeprecatedAliases(aliases map[string]string) *Builder { + for oldName, newName := range aliases { + b.deprecatedAliases[oldName] = newName + } + return b +} + +// WithReadOnly sets whether only read-only tools should be available. +// When true, write tools are filtered out. Returns self for chaining. +func (b *Builder) WithReadOnly(readOnly bool) *Builder { + b.readOnly = readOnly + return b +} + +// WithToolsets specifies which toolsets should be enabled. +// Special keywords: +// - "all": enables all toolsets +// - "default": expands to toolsets marked with Default: true in their metadata +// +// Input strings are trimmed of whitespace and duplicates are removed. +// Pass nil to use default toolsets. Pass an empty slice to disable all toolsets +// (useful for dynamic toolsets mode where tools are enabled on demand). +// Returns self for chaining. +func (b *Builder) WithToolsets(toolsetIDs []string) *Builder { + b.toolsetIDs = toolsetIDs + b.toolsetIDsIsNil = toolsetIDs == nil + return b +} + +// WithTools specifies additional tools that bypass toolset filtering. +// These tools are additive - they will be included even if their toolset is not enabled. +// Read-only filtering still applies to these tools. +// Deprecated tool aliases are automatically resolved to their canonical names during Build(). +// Returns self for chaining. +func (b *Builder) WithTools(toolNames []string) *Builder { + b.additionalTools = toolNames + return b +} + +// WithFeatureChecker sets the feature flag checker function. +// The checker receives a context (for actor extraction) and feature flag name, +// returns (enabled, error). If error occurs, it will be logged and treated as false. +// If checker is nil, all feature flag checks return false. +// Returns self for chaining. +func (b *Builder) WithFeatureChecker(checker FeatureFlagChecker) *Builder { + b.featureChecker = checker + return b +} + +// Build creates the final Registry with all configuration applied. +// This processes toolset filtering, tool name resolution, and sets up +// the registry for use. The returned Registry is ready for use with +// AvailableTools(), RegisterAll(), etc. +func (b *Builder) Build() *Registry { + r := &Registry{ + tools: b.tools, + resourceTemplates: b.resourceTemplates, + prompts: b.prompts, + deprecatedAliases: b.deprecatedAliases, + readOnly: b.readOnly, + featureChecker: b.featureChecker, + } + + // Process toolsets + r.enabledToolsets, r.unrecognizedToolsets = b.processToolsets() + + // Process additional tools (resolve aliases) + if len(b.additionalTools) > 0 { + r.additionalTools = make(map[string]bool, len(b.additionalTools)) + for _, name := range b.additionalTools { + // Resolve deprecated aliases to canonical names + if canonical, isAlias := b.deprecatedAliases[name]; isAlias { + r.additionalTools[canonical] = true + } else { + r.additionalTools[name] = true + } + } + } + + return r +} + +// processToolsets processes the toolsetIDs configuration and returns: +// - enabledToolsets map (nil means all enabled) +// - unrecognizedToolsets list for warnings +func (b *Builder) processToolsets() (map[ToolsetID]bool, []string) { + // Build a set of valid toolset IDs for validation + validIDs := make(map[ToolsetID]bool) + for _, t := range b.tools { + validIDs[t.Toolset.ID] = true + } + for _, r := range b.resourceTemplates { + validIDs[r.Toolset.ID] = true + } + for _, p := range b.prompts { + validIDs[p.Toolset.ID] = true + } + + toolsetIDs := b.toolsetIDs + + // Check for "all" keyword - enables all toolsets + for _, id := range toolsetIDs { + if strings.TrimSpace(id) == "all" { + return nil, nil // nil means all enabled + } + } + + // nil means use defaults, empty slice means no toolsets + if b.toolsetIDsIsNil { + toolsetIDs = []string{"default"} + } + + // Expand "default" keyword, trim whitespace, collect other IDs, and track unrecognized + seen := make(map[ToolsetID]bool) + expanded := make([]ToolsetID, 0, len(toolsetIDs)) + var unrecognized []string + + for _, id := range toolsetIDs { + trimmed := strings.TrimSpace(id) + if trimmed == "" { + continue + } + if trimmed == "default" { + for _, defaultID := range b.defaultToolsetIDs() { + if !seen[defaultID] { + seen[defaultID] = true + expanded = append(expanded, defaultID) + } + } + } else { + tsID := ToolsetID(trimmed) + if !seen[tsID] { + seen[tsID] = true + expanded = append(expanded, tsID) + // Track if this toolset doesn't exist + if !validIDs[tsID] { + unrecognized = append(unrecognized, trimmed) + } + } + } + } + + if len(expanded) == 0 { + return make(map[ToolsetID]bool), unrecognized + } + + enabledToolsets := make(map[ToolsetID]bool, len(expanded)) + for _, id := range expanded { + enabledToolsets[id] = true + } + return enabledToolsets, unrecognized +} + +// defaultToolsetIDs returns toolset IDs marked as Default in their metadata. +func (b *Builder) defaultToolsetIDs() []ToolsetID { + seen := make(map[ToolsetID]bool) + for i := range b.tools { + if b.tools[i].Toolset.Default { + seen[b.tools[i].Toolset.ID] = true + } + } + for i := range b.resourceTemplates { + if b.resourceTemplates[i].Toolset.Default { + seen[b.resourceTemplates[i].Toolset.ID] = true + } + } + for i := range b.prompts { + if b.prompts[i].Toolset.Default { + seen[b.prompts[i].Toolset.ID] = true + } + } + + ids := make([]ToolsetID, 0, len(seen)) + for id := range seen { + ids = append(ids, id) + } + return ids +} diff --git a/pkg/registry/errors.go b/pkg/registry/errors.go new file mode 100644 index 000000000..75cbb6f82 --- /dev/null +++ b/pkg/registry/errors.go @@ -0,0 +1,41 @@ +package registry + +import "fmt" + +// ToolsetDoesNotExistError is returned when a toolset is not found. +type ToolsetDoesNotExistError struct { + Name string +} + +func (e *ToolsetDoesNotExistError) Error() string { + return fmt.Sprintf("toolset %s does not exist", e.Name) +} + +func (e *ToolsetDoesNotExistError) Is(target error) bool { + if target == nil { + return false + } + if _, ok := target.(*ToolsetDoesNotExistError); ok { + return true + } + return false +} + +// NewToolsetDoesNotExistError creates a new ToolsetDoesNotExistError. +func NewToolsetDoesNotExistError(name string) *ToolsetDoesNotExistError { + return &ToolsetDoesNotExistError{Name: name} +} + +// ToolDoesNotExistError is returned when a tool is not found. +type ToolDoesNotExistError struct { + Name string +} + +func (e *ToolDoesNotExistError) Error() string { + return fmt.Sprintf("tool %s does not exist", e.Name) +} + +// NewToolDoesNotExistError creates a new ToolDoesNotExistError. +func NewToolDoesNotExistError(name string) *ToolDoesNotExistError { + return &ToolDoesNotExistError{Name: name} +} diff --git a/pkg/registry/filters.go b/pkg/registry/filters.go new file mode 100644 index 000000000..2d46aab98 --- /dev/null +++ b/pkg/registry/filters.go @@ -0,0 +1,246 @@ +package registry + +import ( + "context" + "fmt" + "os" + "sort" +) + +// FeatureFlagChecker is a function that checks if a feature flag is enabled. +// The context can be used to extract actor/user information for flag evaluation. +// Returns (enabled, error). If error occurs, the caller should log and treat as false. +type FeatureFlagChecker func(ctx context.Context, flagName string) (bool, error) + +// isToolsetEnabled checks if a toolset is enabled based on current filters. +func (r *Registry) isToolsetEnabled(toolsetID ToolsetID) bool { + // Check enabled toolsets filter + if r.enabledToolsets != nil { + return r.enabledToolsets[toolsetID] + } + return true +} + +// checkFeatureFlag checks a feature flag using the feature checker. +// Returns false if checker is nil or returns an error (errors are logged). +func (r *Registry) checkFeatureFlag(ctx context.Context, flagName string) bool { + if r.featureChecker == nil || flagName == "" { + return false + } + enabled, err := r.featureChecker(ctx, flagName) + if err != nil { + fmt.Fprintf(os.Stderr, "Feature flag check error for %q: %v\n", flagName, err) + return false + } + return enabled +} + +// isFeatureFlagAllowed checks if an item passes feature flag filtering. +// - If FeatureFlagEnable is set, the item is only allowed if the flag is enabled +// - If FeatureFlagDisable is set, the item is excluded if the flag is enabled +func (r *Registry) isFeatureFlagAllowed(ctx context.Context, enableFlag, disableFlag string) bool { + // Check enable flag - item requires this flag to be on + if enableFlag != "" && !r.checkFeatureFlag(ctx, enableFlag) { + return false + } + // Check disable flag - item is excluded if this flag is on + if disableFlag != "" && r.checkFeatureFlag(ctx, disableFlag) { + return false + } + return true +} + +// isToolEnabled checks if a specific tool is enabled based on current filters. +func (r *Registry) isToolEnabled(ctx context.Context, tool *ServerTool) bool { + // Check read-only filter first (applies to all tools) + if r.readOnly && !tool.IsReadOnly() { + return false + } + // Check feature flags + if !r.isFeatureFlagAllowed(ctx, tool.FeatureFlagEnable, tool.FeatureFlagDisable) { + return false + } + // Check if tool is in additionalTools (bypasses toolset filter) + if r.additionalTools != nil && r.additionalTools[tool.Tool.Name] { + return true + } + // Check toolset filter + if !r.isToolsetEnabled(tool.Toolset.ID) { + return false + } + return true +} + +// AvailableTools returns the tools that pass all current filters, +// sorted deterministically by toolset ID, then tool name. +// The context is used for feature flag evaluation. +func (r *Registry) AvailableTools(ctx context.Context) []ServerTool { + var result []ServerTool + for i := range r.tools { + tool := &r.tools[i] + if r.isToolEnabled(ctx, tool) { + result = append(result, *tool) + } + } + + // Sort deterministically: by toolset ID, then by tool name + sort.Slice(result, func(i, j int) bool { + if result[i].Toolset.ID != result[j].Toolset.ID { + return result[i].Toolset.ID < result[j].Toolset.ID + } + return result[i].Tool.Name < result[j].Tool.Name + }) + + return result +} + +// AvailableResourceTemplates returns resource templates that pass all current filters, +// sorted deterministically by toolset ID, then template name. +// The context is used for feature flag evaluation. +func (r *Registry) AvailableResourceTemplates(ctx context.Context) []ServerResourceTemplate { + var result []ServerResourceTemplate + for i := range r.resourceTemplates { + res := &r.resourceTemplates[i] + // Check feature flags + if !r.isFeatureFlagAllowed(ctx, res.FeatureFlagEnable, res.FeatureFlagDisable) { + continue + } + if r.isToolsetEnabled(res.Toolset.ID) { + result = append(result, *res) + } + } + + // Sort deterministically: by toolset ID, then by template name + sort.Slice(result, func(i, j int) bool { + if result[i].Toolset.ID != result[j].Toolset.ID { + return result[i].Toolset.ID < result[j].Toolset.ID + } + return result[i].Template.Name < result[j].Template.Name + }) + + return result +} + +// AvailablePrompts returns prompts that pass all current filters, +// sorted deterministically by toolset ID, then prompt name. +// The context is used for feature flag evaluation. +func (r *Registry) AvailablePrompts(ctx context.Context) []ServerPrompt { + var result []ServerPrompt + for i := range r.prompts { + prompt := &r.prompts[i] + // Check feature flags + if !r.isFeatureFlagAllowed(ctx, prompt.FeatureFlagEnable, prompt.FeatureFlagDisable) { + continue + } + if r.isToolsetEnabled(prompt.Toolset.ID) { + result = append(result, *prompt) + } + } + + // Sort deterministically: by toolset ID, then by prompt name + sort.Slice(result, func(i, j int) bool { + if result[i].Toolset.ID != result[j].Toolset.ID { + return result[i].Toolset.ID < result[j].Toolset.ID + } + return result[i].Prompt.Name < result[j].Prompt.Name + }) + + return result +} + +// filterToolsByName returns tools matching the given name, checking deprecated aliases. +// Returns from the current tools slice (respects existing filter chain). +func (r *Registry) filterToolsByName(name string) []ServerTool { + // First check for exact match + for i := range r.tools { + if r.tools[i].Tool.Name == name { + return []ServerTool{r.tools[i]} + } + } + // Check if name is a deprecated alias + if canonical, isAlias := r.deprecatedAliases[name]; isAlias { + for i := range r.tools { + if r.tools[i].Tool.Name == canonical { + return []ServerTool{r.tools[i]} + } + } + } + return []ServerTool{} +} + +// filterResourcesByURI returns resource templates matching the given URI pattern. +func (r *Registry) filterResourcesByURI(uri string) []ServerResourceTemplate { + for i := range r.resourceTemplates { + // Check if URI matches the template pattern (exact match on URITemplate string) + if r.resourceTemplates[i].Template.URITemplate == uri { + return []ServerResourceTemplate{r.resourceTemplates[i]} + } + } + return []ServerResourceTemplate{} +} + +// filterPromptsByName returns prompts matching the given name. +func (r *Registry) filterPromptsByName(name string) []ServerPrompt { + for i := range r.prompts { + if r.prompts[i].Prompt.Name == name { + return []ServerPrompt{r.prompts[i]} + } + } + return []ServerPrompt{} +} + +// ToolsForToolset returns all tools belonging to a specific toolset. +// This method bypasses the toolset enabled filter (for dynamic toolset registration), +// but still respects the read-only filter. +func (r *Registry) ToolsForToolset(toolsetID ToolsetID) []ServerTool { + var result []ServerTool + for i := range r.tools { + tool := &r.tools[i] + // Only check read-only filter, not toolset enabled filter + if tool.Toolset.ID == toolsetID { + if r.readOnly && !tool.IsReadOnly() { + continue + } + result = append(result, *tool) + } + } + + // Sort by tool name for deterministic order + sort.Slice(result, func(i, j int) bool { + return result[i].Tool.Name < result[j].Tool.Name + }) + + return result +} + +// IsToolsetEnabled checks if a toolset is currently enabled based on filters. +func (r *Registry) IsToolsetEnabled(toolsetID ToolsetID) bool { + return r.isToolsetEnabled(toolsetID) +} + +// EnableToolset marks a toolset as enabled in this group. +// This is used by dynamic toolset management to track which toolsets have been enabled. +func (r *Registry) EnableToolset(toolsetID ToolsetID) { + if r.enabledToolsets == nil { + // nil means all enabled, so nothing to do + return + } + r.enabledToolsets[toolsetID] = true +} + +// EnabledToolsetIDs returns the list of enabled toolset IDs based on current filters. +// Returns all toolset IDs if no filter is set. +func (r *Registry) EnabledToolsetIDs() []ToolsetID { + if r.enabledToolsets == nil { + return r.ToolsetIDs() + } + + ids := make([]ToolsetID, 0, len(r.enabledToolsets)) + for id := range r.enabledToolsets { + if r.HasToolset(id) { + ids = append(ids, id) + } + } + sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) + return ids +} diff --git a/pkg/registry/prompts.go b/pkg/registry/prompts.go new file mode 100644 index 000000000..02dda6c9c --- /dev/null +++ b/pkg/registry/prompts.go @@ -0,0 +1,26 @@ +package registry + +import "github.com/modelcontextprotocol/go-sdk/mcp" + +// ServerPrompt pairs a prompt with its toolset metadata. +type ServerPrompt struct { + Prompt mcp.Prompt + Handler mcp.PromptHandler + // Toolset identifies which toolset this prompt belongs to + Toolset ToolsetMetadata + // FeatureFlagEnable specifies a feature flag that must be enabled for this prompt + // to be available. If set and the flag is not enabled, the prompt is omitted. + FeatureFlagEnable string + // FeatureFlagDisable specifies a feature flag that, when enabled, causes this prompt + // to be omitted. Used to disable prompts when a feature flag is on. + FeatureFlagDisable string +} + +// NewServerPrompt creates a new ServerPrompt with toolset metadata. +func NewServerPrompt(toolset ToolsetMetadata, prompt mcp.Prompt, handler mcp.PromptHandler) ServerPrompt { + return ServerPrompt{ + Prompt: prompt, + Handler: handler, + Toolset: toolset, + } +} diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go new file mode 100644 index 000000000..7a73a6e5b --- /dev/null +++ b/pkg/registry/registry.go @@ -0,0 +1,343 @@ +package registry + +import ( + "context" + "fmt" + "os" + "slices" + "sort" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// Registry holds a collection of tools, resources, and prompts with filtering applied. +// Create a Registry using Builder: +// +// reg := NewBuilder(). +// SetTools(tools). +// WithReadOnly(true). +// WithToolsets([]string{"repos"}). +// Build() +// +// The Registry is configured at build time and provides: +// - Filtered access to tools/resources/prompts via Available* methods +// - Deterministic ordering for documentation generation +// - Lazy dependency injection during registration via RegisterAll() +// - Runtime toolset enabling for dynamic toolsets mode +type Registry struct { + // tools holds all tools in this group + tools []ServerTool + // resourceTemplates holds all resource templates in this group + resourceTemplates []ServerResourceTemplate + // prompts holds all prompts in this group + prompts []ServerPrompt + // deprecatedAliases maps old tool names to new canonical names + deprecatedAliases map[string]string + + // Filters - these control what's returned by Available* methods + // readOnly when true filters out write tools + readOnly bool + // enabledToolsets when non-nil, only include tools/resources/prompts from these toolsets + // when nil, all toolsets are enabled + enabledToolsets map[ToolsetID]bool + // additionalTools are specific tools that bypass toolset filtering (but still respect read-only) + // These are additive - a tool is included if it matches toolset filters OR is in this set + additionalTools map[string]bool + // featureChecker when non-nil, checks if a feature flag is enabled. + // Takes context and flag name, returns (enabled, error). If error, log and treat as false. + // If checker is nil, all flag checks return false. + featureChecker FeatureFlagChecker + // unrecognizedToolsets holds toolset IDs that were requested but don't match any registered toolsets + unrecognizedToolsets []string +} + +// UnrecognizedToolsets returns toolset IDs that were passed to WithToolsets but don't +// match any registered toolsets. This is useful for warning users about typos. +func (r *Registry) UnrecognizedToolsets() []string { + return r.unrecognizedToolsets +} + +// MCP method constants for use with ForMCPRequest. +const ( + MCPMethodInitialize = "initialize" + MCPMethodToolsList = "tools/list" + MCPMethodToolsCall = "tools/call" + MCPMethodResourcesList = "resources/list" + MCPMethodResourcesRead = "resources/read" + MCPMethodResourcesTemplatesList = "resources/templates/list" + MCPMethodPromptsList = "prompts/list" + MCPMethodPromptsGet = "prompts/get" +) + +// ForMCPRequest returns a Registry optimized for a specific MCP request. +// This is designed for servers that create a new instance per request (like the remote server), +// allowing them to only register the items needed for that specific request rather than all ~90 tools. +// +// Parameters: +// - method: The MCP method being called (use MCP* constants) +// - itemName: Name of specific item for call/get methods (tool name, resource URI, or prompt name) +// +// Returns a new Registry containing only the items relevant to the request: +// - MCPMethodInitialize: Empty (capabilities are set via ServerOptions, not registration) +// - MCPMethodToolsList: All available tools (no resources/prompts) +// - MCPMethodToolsCall: Only the named tool +// - MCPMethodResourcesList, MCPMethodResourcesTemplatesList: All available resources (no tools/prompts) +// - MCPMethodResourcesRead: Only the named resource template +// - MCPMethodPromptsList: All available prompts (no tools/resources) +// - MCPMethodPromptsGet: Only the named prompt +// - Unknown methods: Empty (no items registered) +// +// All existing filters (read-only, toolsets, etc.) still apply to the returned items. +func (r *Registry) ForMCPRequest(method string, itemName string) *Registry { + // Create a shallow copy with shared filter settings + result := &Registry{ + tools: r.tools, + resourceTemplates: r.resourceTemplates, + prompts: r.prompts, + deprecatedAliases: r.deprecatedAliases, + readOnly: r.readOnly, + enabledToolsets: r.enabledToolsets, // shared, not modified + additionalTools: r.additionalTools, // shared, not modified + featureChecker: r.featureChecker, + unrecognizedToolsets: r.unrecognizedToolsets, + } + + // Helper to clear all item types + clearAll := func() { + result.tools = []ServerTool{} + result.resourceTemplates = []ServerResourceTemplate{} + result.prompts = []ServerPrompt{} + } + + switch method { + case MCPMethodInitialize: + clearAll() + case MCPMethodToolsList: + result.resourceTemplates, result.prompts = nil, nil + case MCPMethodToolsCall: + result.resourceTemplates, result.prompts = nil, nil + if itemName != "" { + result.tools = r.filterToolsByName(itemName) + } + case MCPMethodResourcesList, MCPMethodResourcesTemplatesList: + result.tools, result.prompts = nil, nil + case MCPMethodResourcesRead: + result.tools, result.prompts = nil, nil + if itemName != "" { + result.resourceTemplates = r.filterResourcesByURI(itemName) + } + case MCPMethodPromptsList: + result.tools, result.resourceTemplates = nil, nil + case MCPMethodPromptsGet: + result.tools, result.resourceTemplates = nil, nil + if itemName != "" { + result.prompts = r.filterPromptsByName(itemName) + } + default: + clearAll() + } + + return result +} + +// ToolsetIDs returns a sorted list of unique toolset IDs from all tools in this group. +func (r *Registry) ToolsetIDs() []ToolsetID { + seen := make(map[ToolsetID]bool) + for i := range r.tools { + seen[r.tools[i].Toolset.ID] = true + } + for i := range r.resourceTemplates { + seen[r.resourceTemplates[i].Toolset.ID] = true + } + for i := range r.prompts { + seen[r.prompts[i].Toolset.ID] = true + } + + ids := make([]ToolsetID, 0, len(seen)) + for id := range seen { + ids = append(ids, id) + } + sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) + return ids +} + +// DefaultToolsetIDs returns the IDs of toolsets marked as Default in their metadata. +// The IDs are returned in sorted order for deterministic output. +func (r *Registry) DefaultToolsetIDs() []ToolsetID { + seen := make(map[ToolsetID]bool) + for i := range r.tools { + if r.tools[i].Toolset.Default { + seen[r.tools[i].Toolset.ID] = true + } + } + for i := range r.resourceTemplates { + if r.resourceTemplates[i].Toolset.Default { + seen[r.resourceTemplates[i].Toolset.ID] = true + } + } + for i := range r.prompts { + if r.prompts[i].Toolset.Default { + seen[r.prompts[i].Toolset.ID] = true + } + } + + ids := make([]ToolsetID, 0, len(seen)) + for id := range seen { + ids = append(ids, id) + } + sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) + return ids +} + +// ToolsetDescriptions returns a map of toolset ID to description for all toolsets. +func (r *Registry) ToolsetDescriptions() map[ToolsetID]string { + descriptions := make(map[ToolsetID]string) + for i := range r.tools { + t := &r.tools[i] + if t.Toolset.Description != "" { + descriptions[t.Toolset.ID] = t.Toolset.Description + } + } + for i := range r.resourceTemplates { + res := &r.resourceTemplates[i] + if res.Toolset.Description != "" { + descriptions[res.Toolset.ID] = res.Toolset.Description + } + } + for i := range r.prompts { + p := &r.prompts[i] + if p.Toolset.Description != "" { + descriptions[p.Toolset.ID] = p.Toolset.Description + } + } + return descriptions +} + +// RegisterTools registers all available tools with the server using the provided dependencies. +// The context is used for feature flag evaluation. +func (r *Registry) RegisterTools(ctx context.Context, s *mcp.Server, deps any) { + for _, tool := range r.AvailableTools(ctx) { + tool.RegisterFunc(s, deps) + } +} + +// RegisterResourceTemplates registers all available resource templates with the server. +// The context is used for feature flag evaluation. +func (r *Registry) RegisterResourceTemplates(ctx context.Context, s *mcp.Server, deps any) { + for _, res := range r.AvailableResourceTemplates(ctx) { + s.AddResourceTemplate(&res.Template, res.Handler(deps)) + } +} + +// RegisterPrompts registers all available prompts with the server. +// The context is used for feature flag evaluation. +func (r *Registry) RegisterPrompts(ctx context.Context, s *mcp.Server) { + for _, prompt := range r.AvailablePrompts(ctx) { + s.AddPrompt(&prompt.Prompt, prompt.Handler) + } +} + +// RegisterAll registers all available tools, resources, and prompts with the server. +// The context is used for feature flag evaluation. +func (r *Registry) RegisterAll(ctx context.Context, s *mcp.Server, deps any) { + r.RegisterTools(ctx, s, deps) + r.RegisterResourceTemplates(ctx, s, deps) + r.RegisterPrompts(ctx, s) +} + +// ResolveToolAliases resolves deprecated tool aliases to their canonical names. +// It logs a warning to stderr for each deprecated alias that is resolved. +// Returns: +// - resolved: tool names with aliases replaced by canonical names +// - aliasesUsed: map of oldName → newName for each alias that was resolved +func (r *Registry) ResolveToolAliases(toolNames []string) (resolved []string, aliasesUsed map[string]string) { + resolved = make([]string, 0, len(toolNames)) + aliasesUsed = make(map[string]string) + for _, toolName := range toolNames { + if canonicalName, isAlias := r.deprecatedAliases[toolName]; isAlias { + fmt.Fprintf(os.Stderr, "Warning: tool %q is deprecated, use %q instead\n", toolName, canonicalName) + aliasesUsed[toolName] = canonicalName + resolved = append(resolved, canonicalName) + } else { + resolved = append(resolved, toolName) + } + } + return resolved, aliasesUsed +} + +// FindToolByName searches all tools for one matching the given name. +// Returns the tool, its toolset ID, and an error if not found. +// This searches ALL tools regardless of filters. +func (r *Registry) FindToolByName(toolName string) (*ServerTool, ToolsetID, error) { + for i := range r.tools { + tool := &r.tools[i] + if tool.Tool.Name == toolName { + return tool, tool.Toolset.ID, nil + } + } + return nil, "", NewToolDoesNotExistError(toolName) +} + +// HasToolset checks if any tool/resource/prompt belongs to the given toolset. +func (r *Registry) HasToolset(toolsetID ToolsetID) bool { + for i := range r.tools { + if r.tools[i].Toolset.ID == toolsetID { + return true + } + } + for i := range r.resourceTemplates { + if r.resourceTemplates[i].Toolset.ID == toolsetID { + return true + } + } + for i := range r.prompts { + if r.prompts[i].Toolset.ID == toolsetID { + return true + } + } + return false +} + +// AllTools returns all tools without any filtering, sorted deterministically. +func (r *Registry) AllTools() []ServerTool { + result := slices.Clone(r.tools) + + // Sort deterministically: by toolset ID, then by tool name + sort.Slice(result, func(i, j int) bool { + if result[i].Toolset.ID != result[j].Toolset.ID { + return result[i].Toolset.ID < result[j].Toolset.ID + } + return result[i].Tool.Name < result[j].Tool.Name + }) + + return result +} + +// AvailableToolsets returns the unique toolsets that have tools, in sorted order. +// This is the ordered intersection of toolsets with reality - only toolsets that +// actually contain tools are returned, sorted by toolset ID. +// Optional exclude parameter filters out specific toolset IDs from the result. +func (r *Registry) AvailableToolsets(exclude ...ToolsetID) []ToolsetMetadata { + tools := r.AllTools() + if len(tools) == 0 { + return nil + } + + // Build exclude set for O(1) lookup + excludeSet := make(map[ToolsetID]bool, len(exclude)) + for _, id := range exclude { + excludeSet[id] = true + } + + var result []ToolsetMetadata + var lastID ToolsetID + for _, tool := range tools { + if tool.Toolset.ID != lastID { + lastID = tool.Toolset.ID + if !excludeSet[lastID] { + result = append(result, tool.Toolset) + } + } + } + return result +} diff --git a/pkg/toolsets/toolsets_test.go b/pkg/registry/registry_test.go similarity index 75% rename from pkg/toolsets/toolsets_test.go rename to pkg/registry/registry_test.go index 0d1c35e2e..4518ed8bf 100644 --- a/pkg/toolsets/toolsets_test.go +++ b/pkg/registry/registry_test.go @@ -1,4 +1,4 @@ -package toolsets +package registry import ( "context" @@ -65,15 +65,15 @@ func mockTool(name string, toolsetID string, readOnly bool) ServerTool { } func TestNewRegistryEmpty(t *testing.T) { - tsg := NewRegistry() - if len(tsg.tools) != 0 { - t.Fatalf("Expected tools to be empty, got %d items", len(tsg.tools)) + reg := NewBuilder().Build() + if len(reg.AvailableTools(context.Background())) != 0 { + t.Fatalf("Expected tools to be empty") } - if len(tsg.resourceTemplates) != 0 { - t.Fatalf("Expected resourceTemplates to be empty, got %d items", len(tsg.resourceTemplates)) + if len(reg.AvailableResourceTemplates(context.Background())) != 0 { + t.Fatalf("Expected resourceTemplates to be empty") } - if len(tsg.prompts) != 0 { - t.Fatalf("Expected prompts to be empty, got %d items", len(tsg.prompts)) + if len(reg.AvailablePrompts(context.Background())) != 0 { + t.Fatalf("Expected prompts to be empty") } } @@ -84,10 +84,10 @@ func TestNewRegistryWithTools(t *testing.T) { mockTool("tool3", "toolset2", true), } - tsg := NewRegistry().SetTools(tools) + reg := NewBuilder().SetTools(tools).Build() - if len(tsg.tools) != 3 { - t.Errorf("Expected 3 tools, got %d", len(tsg.tools)) + if len(reg.AllTools()) != 3 { + t.Errorf("Expected 3 tools, got %d", len(reg.AllTools())) } } @@ -98,8 +98,8 @@ func TestAvailableTools_NoFilters(t *testing.T) { mockTool("tool_c", "toolset2", true), } - tsg := NewRegistry().SetTools(tools) - available := tsg.AvailableTools(context.Background()) + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + available := reg.AvailableTools(context.Background()) if len(available) != 3 { t.Fatalf("Expected 3 available tools, got %d", len(available)) @@ -120,29 +120,22 @@ func TestWithReadOnly(t *testing.T) { mockTool("write_tool", "toolset1", false), } - tsg := NewRegistry().SetTools(tools) - - // Original should have both tools - allTools := tsg.AvailableTools(context.Background()) + // Build without read-only - should have both tools + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + allTools := reg.AvailableTools(context.Background()) if len(allTools) != 2 { - t.Fatalf("Expected 2 tools in original, got %d", len(allTools)) + t.Fatalf("Expected 2 tools without read-only, got %d", len(allTools)) } - // Read-only should filter out write tools - readOnlyTsg := tsg.WithReadOnly(true) - readOnlyTools := readOnlyTsg.AvailableTools(context.Background()) + // Build with read-only - should filter out write tools + readOnlyReg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithReadOnly(true).Build() + readOnlyTools := readOnlyReg.AvailableTools(context.Background()) if len(readOnlyTools) != 1 { t.Fatalf("Expected 1 tool in read-only, got %d", len(readOnlyTools)) } if readOnlyTools[0].Tool.Name != "read_tool" { t.Errorf("Expected read_tool, got %s", readOnlyTools[0].Tool.Name) } - - // Original should still have both (immutability test) - allTools = tsg.AvailableTools(context.Background()) - if len(allTools) != 2 { - t.Fatalf("Original was mutated! Expected 2 tools, got %d", len(allTools)) - } } func TestWithToolsets(t *testing.T) { @@ -152,10 +145,15 @@ func TestWithToolsets(t *testing.T) { mockTool("tool3", "toolset3", true), } - tsg := NewRegistry().SetTools(tools) + // Build with all toolsets + allReg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + allTools := allReg.AvailableTools(context.Background()) + if len(allTools) != 3 { + t.Fatalf("Expected 3 tools without filter, got %d", len(allTools)) + } - // Filter to specific toolsets - filteredReg := tsg.WithToolsets([]string{"toolset1", "toolset3"}) + // Build with specific toolsets + filteredReg := NewBuilder().SetTools(tools).WithToolsets([]string{"toolset1", "toolset3"}).Build() filteredTools := filteredReg.AvailableTools(context.Background()) if len(filteredTools) != 2 { @@ -170,12 +168,6 @@ func TestWithToolsets(t *testing.T) { if !toolNames["tool1"] || !toolNames["tool3"] { t.Errorf("Expected tool1 and tool3, got %v", toolNames) } - - // Original should still have all 3 (immutability test) - allTools := tsg.AvailableTools(context.Background()) - if len(allTools) != 3 { - t.Fatalf("Original was mutated! Expected 3 tools, got %d", len(allTools)) - } } func TestWithToolsetsTrimsWhitespace(t *testing.T) { @@ -184,10 +176,8 @@ func TestWithToolsetsTrimsWhitespace(t *testing.T) { mockTool("tool2", "toolset2", true), } - tsg := NewRegistry().SetTools(tools) - // Whitespace should be trimmed - filteredReg := tsg.WithToolsets([]string{" toolset1 ", " toolset2 "}) + filteredReg := NewBuilder().SetTools(tools).WithToolsets([]string{" toolset1 ", " toolset2 "}).Build() filteredTools := filteredReg.AvailableTools(context.Background()) if len(filteredTools) != 2 { @@ -200,10 +190,8 @@ func TestWithToolsetsDeduplicates(t *testing.T) { mockTool("tool1", "toolset1", true), } - tsg := NewRegistry().SetTools(tools) - // Duplicates should be removed - filteredReg := tsg.WithToolsets([]string{"toolset1", "toolset1", " toolset1 "}) + filteredReg := NewBuilder().SetTools(tools).WithToolsets([]string{"toolset1", "toolset1", " toolset1 "}).Build() filteredTools := filteredReg.AvailableTools(context.Background()) if len(filteredTools) != 1 { @@ -216,10 +204,8 @@ func TestWithToolsetsIgnoresEmptyStrings(t *testing.T) { mockTool("tool1", "toolset1", true), } - tsg := NewRegistry().SetTools(tools) - // Empty strings should be ignored - filteredReg := tsg.WithToolsets([]string{"", "toolset1", " ", ""}) + filteredReg := NewBuilder().SetTools(tools).WithToolsets([]string{"", "toolset1", " ", ""}).Build() filteredTools := filteredReg.AvailableTools(context.Background()) if len(filteredTools) != 1 { @@ -233,8 +219,6 @@ func TestUnrecognizedToolsets(t *testing.T) { mockTool("tool2", "toolset2", true), } - tsg := NewRegistry().SetTools(tools) - tests := []struct { name string input []string @@ -269,7 +253,7 @@ func TestUnrecognizedToolsets(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - filtered := tsg.WithToolsets(tt.input) + filtered := NewBuilder().SetTools(tools).WithToolsets(tt.input).Build() unrecognized := filtered.UnrecognizedToolsets() if len(unrecognized) != len(tt.expectedUnrecognized) { @@ -293,11 +277,9 @@ func TestWithTools(t *testing.T) { mockTool("tool3", "toolset2", true), } - tsg := NewRegistry().SetTools(tools) - // WithTools adds additional tools that bypass toolset filtering // When combined with WithToolsets([]), only the additional tools should be available - filteredReg := tsg.WithToolsets([]string{}).WithTools([]string{"tool1", "tool3"}) + filteredReg := NewBuilder().SetTools(tools).WithToolsets([]string{}).WithTools([]string{"tool1", "tool3"}).Build() filteredTools := filteredReg.AvailableTools(context.Background()) if len(filteredTools) != 2 { @@ -321,10 +303,8 @@ func TestChainedFilters(t *testing.T) { mockTool("write2", "toolset2", false), } - tsg := NewRegistry().SetTools(tools) - // Chain read-only and toolset filter - filtered := tsg.WithReadOnly(true).WithToolsets([]string{"toolset1"}) + filtered := NewBuilder().SetTools(tools).WithReadOnly(true).WithToolsets([]string{"toolset1"}).Build() result := filtered.AvailableTools(context.Background()) if len(result) != 1 { @@ -342,8 +322,8 @@ func TestToolsetIDs(t *testing.T) { mockTool("tool3", "toolset_b", true), // duplicate toolset } - tsg := NewRegistry().SetTools(tools) - ids := tsg.ToolsetIDs() + reg := NewBuilder().SetTools(tools).Build() + ids := reg.ToolsetIDs() if len(ids) != 2 { t.Fatalf("Expected 2 unique toolset IDs, got %d", len(ids)) @@ -361,8 +341,8 @@ func TestToolsetDescriptions(t *testing.T) { mockTool("tool2", "toolset2", true), } - tsg := NewRegistry().SetTools(tools) - descriptions := tsg.ToolsetDescriptions() + reg := NewBuilder().SetTools(tools).Build() + descriptions := reg.ToolsetDescriptions() if len(descriptions) != 2 { t.Fatalf("Expected 2 descriptions, got %d", len(descriptions)) @@ -380,35 +360,31 @@ func TestToolsForToolset(t *testing.T) { mockTool("tool3", "toolset2", true), } - tsg := NewRegistry().SetTools(tools) - toolset1Tools := tsg.ToolsForToolset("toolset1") + reg := NewBuilder().SetTools(tools).Build() + toolset1Tools := reg.ToolsForToolset("toolset1") if len(toolset1Tools) != 2 { t.Fatalf("Expected 2 tools for toolset1, got %d", len(toolset1Tools)) } } -func TestWithDeprecatedToolAliases(t *testing.T) { +func TestWithDeprecatedAliases(t *testing.T) { tools := []ServerTool{ mockTool("new_name", "toolset1", true), } - tsg := NewRegistry().SetTools(tools) - tsgWithAliases := tsg.WithDeprecatedToolAliases(map[string]string{ + reg := NewBuilder().SetTools(tools).WithDeprecatedAliases(map[string]string{ "old_name": "new_name", "get_issue": "issue_read", - }) - - // Original should be unchanged (immutable) - if len(tsg.deprecatedAliases) != 0 { - t.Errorf("original should have 0 aliases, got %d", len(tsg.deprecatedAliases)) - } + }).Build() - if len(tsgWithAliases.deprecatedAliases) != 2 { - t.Errorf("expected 2 aliases, got %d", len(tsgWithAliases.deprecatedAliases)) + // Test resolving aliases + resolved, aliasesUsed := reg.ResolveToolAliases([]string{"old_name"}) + if len(resolved) != 1 || resolved[0] != "new_name" { + t.Errorf("expected alias to resolve to 'new_name', got %v", resolved) } - if tsgWithAliases.deprecatedAliases["old_name"] != "new_name" { - t.Errorf("expected alias 'old_name' -> 'new_name', got '%s'", tsgWithAliases.deprecatedAliases["old_name"]) + if len(aliasesUsed) != 1 || aliasesUsed["old_name"] != "new_name" { + t.Errorf("expected alias mapping, got %v", aliasesUsed) } } @@ -418,14 +394,14 @@ func TestResolveToolAliases(t *testing.T) { mockTool("some_tool", "toolset1", true), } - tsg := NewRegistry().SetTools(tools). - WithDeprecatedToolAliases(map[string]string{ + reg := NewBuilder().SetTools(tools). + WithDeprecatedAliases(map[string]string{ "get_issue": "issue_read", - }) + }).Build() // Test resolving a mix of aliases and canonical names input := []string{"get_issue", "some_tool"} - resolved, aliasesUsed := tsg.ResolveToolAliases(input) + resolved, aliasesUsed := reg.ResolveToolAliases(input) if len(resolved) != 2 { t.Fatalf("expected 2 resolved names, got %d", len(resolved)) @@ -450,10 +426,10 @@ func TestFindToolByName(t *testing.T) { mockTool("issue_read", "toolset1", true), } - tsg := NewRegistry().SetTools(tools) + reg := NewBuilder().SetTools(tools).Build() // Find by name - tool, toolsetID, err := tsg.FindToolByName("issue_read") + tool, toolsetID, err := reg.FindToolByName("issue_read") if err != nil { t.Fatalf("expected no error, got %v", err) } @@ -465,7 +441,7 @@ func TestFindToolByName(t *testing.T) { } // Non-existent tool - _, _, err = tsg.FindToolByName("nonexistent") + _, _, err = reg.FindToolByName("nonexistent") if err == nil { t.Error("expected error for non-existent tool") } @@ -478,11 +454,9 @@ func TestWithToolsAdditive(t *testing.T) { mockTool("repo_read", "toolset2", true), } - tsg := NewRegistry().SetTools(tools) - // Test WithTools bypasses toolset filtering // Enable only toolset2, but add issue_read as additional tool - filtered := tsg.WithToolsets([]string{"toolset2"}).WithTools([]string{"issue_read"}) + filtered := NewBuilder().SetTools(tools).WithToolsets([]string{"toolset2"}).WithTools([]string{"issue_read"}).Build() available := filtered.AvailableTools(context.Background()) if len(available) != 2 { @@ -502,7 +476,7 @@ func TestWithToolsAdditive(t *testing.T) { } // Test WithTools respects read-only mode - readOnlyFiltered := tsg.WithReadOnly(true).WithTools([]string{"issue_write"}) + readOnlyFiltered := NewBuilder().SetTools(tools).WithReadOnly(true).WithTools([]string{"issue_write"}).Build() available = readOnlyFiltered.AvailableTools(context.Background()) // issue_write should be excluded because read-only applies to additional tools too @@ -513,7 +487,7 @@ func TestWithToolsAdditive(t *testing.T) { } // Test WithTools with non-existent tool (should not error, just won't match anything) - nonexistent := tsg.WithToolsets([]string{}).WithTools([]string{"nonexistent"}) + nonexistent := NewBuilder().SetTools(tools).WithToolsets([]string{}).WithTools([]string{"nonexistent"}).Build() available = nonexistent.AvailableTools(context.Background()) if len(available) != 0 { t.Errorf("expected 0 tools for non-existent additional tool, got %d", len(available)) @@ -525,13 +499,14 @@ func TestWithToolsResolvesAliases(t *testing.T) { mockTool("issue_read", "toolset1", true), } - tsg := NewRegistry().SetTools(tools). - WithDeprecatedToolAliases(map[string]string{ - "get_issue": "issue_read", - }) - // Using deprecated alias should resolve to canonical name - filtered := tsg.WithToolsets([]string{}).WithTools([]string{"get_issue"}) + filtered := NewBuilder().SetTools(tools). + WithDeprecatedAliases(map[string]string{ + "get_issue": "issue_read", + }). + WithToolsets([]string{}). + WithTools([]string{"get_issue"}). + Build() available := filtered.AvailableTools(context.Background()) if len(available) != 1 { @@ -547,12 +522,12 @@ func TestHasToolset(t *testing.T) { mockTool("tool1", "toolset1", true), } - tsg := NewRegistry().SetTools(tools) + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() - if !tsg.HasToolset("toolset1") { + if !reg.HasToolset("toolset1") { t.Error("expected HasToolset to return true for existing toolset") } - if tsg.HasToolset("nonexistent") { + if reg.HasToolset("nonexistent") { t.Error("expected HasToolset to return false for non-existent toolset") } } @@ -563,16 +538,15 @@ func TestEnabledToolsetIDs(t *testing.T) { mockTool("tool2", "toolset2", true), } - tsg := NewRegistry().SetTools(tools) - // Without filter, all toolsets are enabled - ids := tsg.EnabledToolsetIDs() + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + ids := reg.EnabledToolsetIDs() if len(ids) != 2 { t.Fatalf("Expected 2 enabled toolset IDs, got %d", len(ids)) } // With filter - filtered := tsg.WithToolsets([]string{"toolset1"}) + filtered := NewBuilder().SetTools(tools).WithToolsets([]string{"toolset1"}).Build() filteredIDs := filtered.EnabledToolsetIDs() if len(filteredIDs) != 1 { t.Fatalf("Expected 1 enabled toolset ID, got %d", len(filteredIDs)) @@ -588,18 +562,16 @@ func TestAllTools(t *testing.T) { mockTool("write_tool", "toolset1", false), } - tsg := NewRegistry().SetTools(tools) - // Even with read-only filter, AllTools returns everything - readOnlyTsg := tsg.WithReadOnly(true) + readOnlyReg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithReadOnly(true).Build() - allTools := readOnlyTsg.AllTools() + allTools := readOnlyReg.AllTools() if len(allTools) != 2 { t.Fatalf("Expected 2 tools from AllTools, got %d", len(allTools)) } // But AvailableTools respects the filter - availableTools := readOnlyTsg.AvailableTools(context.Background()) + availableTools := readOnlyReg.AvailableTools(context.Background()) if len(availableTools) != 1 { t.Fatalf("Expected 1 tool from AvailableTools, got %d", len(availableTools)) } @@ -656,8 +628,8 @@ func TestForMCPRequest_Initialize(t *testing.T) { mockPrompt("prompt1", "repos"), } - tsg := NewRegistry().SetTools(tools).SetResources(resources).SetPrompts(prompts) - filtered := tsg.ForMCPRequest(MCPMethodInitialize, "") + reg := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodInitialize, "") // Initialize should return empty - capabilities come from ServerOptions if len(filtered.AvailableTools(context.Background())) != 0 { @@ -683,8 +655,8 @@ func TestForMCPRequest_ToolsList(t *testing.T) { mockPrompt("prompt1", "repos"), } - tsg := NewRegistry().SetTools(tools).SetResources(resources).SetPrompts(prompts) - filtered := tsg.ForMCPRequest(MCPMethodToolsList, "") + reg := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodToolsList, "") // tools/list should return all tools, no resources or prompts if len(filtered.AvailableTools(context.Background())) != 2 { @@ -705,8 +677,8 @@ func TestForMCPRequest_ToolsCall(t *testing.T) { mockTool("list_repos", "repos", true), } - tsg := NewRegistry().SetTools(tools) - filtered := tsg.ForMCPRequest(MCPMethodToolsCall, "get_me") + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodToolsCall, "get_me") available := filtered.AvailableTools(context.Background()) if len(available) != 1 { @@ -722,8 +694,8 @@ func TestForMCPRequest_ToolsCall_NotFound(t *testing.T) { mockTool("get_me", "context", true), } - tsg := NewRegistry().SetTools(tools) - filtered := tsg.ForMCPRequest(MCPMethodToolsCall, "nonexistent") + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodToolsCall, "nonexistent") if len(filtered.AvailableTools(context.Background())) != 0 { t.Errorf("Expected 0 tools for nonexistent tool, got %d", len(filtered.AvailableTools(context.Background()))) @@ -736,13 +708,14 @@ func TestForMCPRequest_ToolsCall_DeprecatedAlias(t *testing.T) { mockTool("list_commits", "repos", true), } - tsg := NewRegistry().SetTools(tools). - WithDeprecatedToolAliases(map[string]string{ + reg := NewBuilder().SetTools(tools). + WithToolsets([]string{"all"}). + WithDeprecatedAliases(map[string]string{ "old_get_me": "get_me", - }) + }).Build() // Request using the deprecated alias - filtered := tsg.ForMCPRequest(MCPMethodToolsCall, "old_get_me") + filtered := reg.ForMCPRequest(MCPMethodToolsCall, "old_get_me") available := filtered.AvailableTools(context.Background()) if len(available) != 1 { @@ -758,9 +731,9 @@ func TestForMCPRequest_ToolsCall_RespectsFilters(t *testing.T) { mockTool("create_issue", "issues", false), // write tool } - tsg := NewRegistry().SetTools(tools) - // Apply read-only filter, then ForMCPRequest - filtered := tsg.WithReadOnly(true).ForMCPRequest(MCPMethodToolsCall, "create_issue") + // Apply read-only filter at build time, then ForMCPRequest + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithReadOnly(true).Build() + filtered := reg.ForMCPRequest(MCPMethodToolsCall, "create_issue") // The tool exists in the filtered group, but AvailableTools respects read-only available := filtered.AvailableTools(context.Background()) @@ -781,8 +754,8 @@ func TestForMCPRequest_ResourcesList(t *testing.T) { mockPrompt("prompt1", "repos"), } - tsg := NewRegistry().SetTools(tools).SetResources(resources).SetPrompts(prompts) - filtered := tsg.ForMCPRequest(MCPMethodResourcesList, "") + reg := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodResourcesList, "") if len(filtered.AvailableTools(context.Background())) != 0 { t.Errorf("Expected 0 tools for resources/list, got %d", len(filtered.AvailableTools(context.Background()))) @@ -801,8 +774,8 @@ func TestForMCPRequest_ResourcesRead(t *testing.T) { mockResource("res2", "repos", "branch://{owner}/{repo}/{branch}"), } - tsg := NewRegistry().SetResources(resources) - filtered := tsg.ForMCPRequest(MCPMethodResourcesRead, "repo://{owner}/{repo}") + reg := NewBuilder().SetResources(resources).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodResourcesRead, "repo://{owner}/{repo}") available := filtered.AvailableResourceTemplates(context.Background()) if len(available) != 1 { @@ -825,8 +798,8 @@ func TestForMCPRequest_PromptsList(t *testing.T) { mockPrompt("prompt2", "issues"), } - tsg := NewRegistry().SetTools(tools).SetResources(resources).SetPrompts(prompts) - filtered := tsg.ForMCPRequest(MCPMethodPromptsList, "") + reg := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodPromptsList, "") if len(filtered.AvailableTools(context.Background())) != 0 { t.Errorf("Expected 0 tools for prompts/list, got %d", len(filtered.AvailableTools(context.Background()))) @@ -845,8 +818,8 @@ func TestForMCPRequest_PromptsGet(t *testing.T) { mockPrompt("prompt2", "issues"), } - tsg := NewRegistry().SetPrompts(prompts) - filtered := tsg.ForMCPRequest(MCPMethodPromptsGet, "prompt1") + reg := NewBuilder().SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodPromptsGet, "prompt1") available := filtered.AvailablePrompts(context.Background()) if len(available) != 1 { @@ -868,8 +841,8 @@ func TestForMCPRequest_UnknownMethod(t *testing.T) { mockPrompt("prompt1", "repos"), } - tsg := NewRegistry().SetTools(tools).SetResources(resources).SetPrompts(prompts) - filtered := tsg.ForMCPRequest("unknown/method", "") + reg := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest("unknown/method", "") // Unknown methods should return empty if len(filtered.AvailableTools(context.Background())) != 0 { @@ -883,7 +856,7 @@ func TestForMCPRequest_UnknownMethod(t *testing.T) { } } -func TestForMCPRequest_Immutability(t *testing.T) { +func TestForMCPRequest_DoesNotMutateOriginal(t *testing.T) { tools := []ServerTool{ mockTool("tool1", "repos", true), mockTool("tool2", "issues", true), @@ -895,7 +868,7 @@ func TestForMCPRequest_Immutability(t *testing.T) { mockPrompt("prompt1", "repos"), } - original := NewRegistry().SetTools(tools).SetResources(resources).SetPrompts(prompts) + original := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() filtered := original.ForMCPRequest(MCPMethodToolsCall, "tool1") // Original should be unchanged @@ -929,13 +902,12 @@ func TestForMCPRequest_ChainedWithOtherFilters(t *testing.T) { mockToolWithDefault("delete_repo", "repos", false, true), // default but write } - tsg := NewRegistry().SetTools(tools) - // Chain: default toolsets -> read-only -> specific method - filtered := tsg. + reg := NewBuilder().SetTools(tools). WithToolsets([]string{"default"}). WithReadOnly(true). - ForMCPRequest(MCPMethodToolsList, "") + Build() + filtered := reg.ForMCPRequest(MCPMethodToolsList, "") available := filtered.AvailableTools(context.Background()) @@ -972,8 +944,8 @@ func TestForMCPRequest_ResourcesTemplatesList(t *testing.T) { mockResource("res1", "repos", "repo://{owner}/{repo}"), } - tsg := NewRegistry().SetTools(tools).SetResources(resources) - filtered := tsg.ForMCPRequest(MCPMethodResourcesTemplatesList, "") + reg := NewBuilder().SetTools(tools).SetResources(resources).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodResourcesTemplatesList, "") // Same behavior as resources/list if len(filtered.AvailableTools(context.Background())) != 0 { @@ -1021,10 +993,9 @@ func TestFeatureFlagEnable(t *testing.T) { mockToolWithFlags("needs_flag", "toolset1", true, "my_feature", ""), } - tsg := NewRegistry().SetTools(tools) - // Without feature checker, tool with FeatureFlagEnable should be excluded - available := tsg.AvailableTools(context.Background()) + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + available := reg.AvailableTools(context.Background()) if len(available) != 1 { t.Fatalf("Expected 1 tool without feature checker, got %d", len(available)) } @@ -1034,8 +1005,8 @@ func TestFeatureFlagEnable(t *testing.T) { // With feature checker returning false, tool should still be excluded checkerFalse := func(_ context.Context, _ string) (bool, error) { return false, nil } - filteredFalse := tsg.WithFeatureChecker(checkerFalse) - availableFalse := filteredFalse.AvailableTools(context.Background()) + regFalse := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checkerFalse).Build() + availableFalse := regFalse.AvailableTools(context.Background()) if len(availableFalse) != 1 { t.Fatalf("Expected 1 tool with false checker, got %d", len(availableFalse)) } @@ -1044,8 +1015,8 @@ func TestFeatureFlagEnable(t *testing.T) { checkerTrue := func(_ context.Context, flag string) (bool, error) { return flag == "my_feature", nil } - filteredTrue := tsg.WithFeatureChecker(checkerTrue) - availableTrue := filteredTrue.AvailableTools(context.Background()) + regTrue := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checkerTrue).Build() + availableTrue := regTrue.AvailableTools(context.Background()) if len(availableTrue) != 2 { t.Fatalf("Expected 2 tools with true checker, got %d", len(availableTrue)) } @@ -1057,10 +1028,9 @@ func TestFeatureFlagDisable(t *testing.T) { mockToolWithFlags("disabled_by_flag", "toolset1", true, "", "kill_switch"), } - tsg := NewRegistry().SetTools(tools) - // Without feature checker, tool with FeatureFlagDisable should be included (flag is false) - available := tsg.AvailableTools(context.Background()) + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + available := reg.AvailableTools(context.Background()) if len(available) != 2 { t.Fatalf("Expected 2 tools without feature checker, got %d", len(available)) } @@ -1069,8 +1039,8 @@ func TestFeatureFlagDisable(t *testing.T) { checkerTrue := func(_ context.Context, flag string) (bool, error) { return flag == "kill_switch", nil } - filtered := tsg.WithFeatureChecker(checkerTrue) - availableFiltered := filtered.AvailableTools(context.Background()) + regFiltered := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checkerTrue).Build() + availableFiltered := regFiltered.AvailableTools(context.Background()) if len(availableFiltered) != 1 { t.Fatalf("Expected 1 tool with kill_switch enabled, got %d", len(availableFiltered)) } @@ -1085,23 +1055,24 @@ func TestFeatureFlagBoth(t *testing.T) { mockToolWithFlags("complex_tool", "toolset1", true, "new_feature", "kill_switch"), } - tsg := NewRegistry().SetTools(tools) - // Enable flag not set -> excluded checker1 := func(_ context.Context, _ string) (bool, error) { return false, nil } - if len(tsg.WithFeatureChecker(checker1).AvailableTools(context.Background())) != 0 { + reg1 := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checker1).Build() + if len(reg1.AvailableTools(context.Background())) != 0 { t.Error("Tool should be excluded when enable flag is false") } // Enable flag set, disable flag not set -> included checker2 := func(_ context.Context, flag string) (bool, error) { return flag == "new_feature", nil } - if len(tsg.WithFeatureChecker(checker2).AvailableTools(context.Background())) != 1 { + reg2 := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checker2).Build() + if len(reg2.AvailableTools(context.Background())) != 1 { t.Error("Tool should be included when enable flag is true and disable flag is false") } // Enable flag set, disable flag also set -> excluded (disable wins) checker3 := func(_ context.Context, _ string) (bool, error) { return true, nil } - if len(tsg.WithFeatureChecker(checker3).AvailableTools(context.Background())) != 0 { + reg3 := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checker3).Build() + if len(reg3.AvailableTools(context.Background())) != 0 { t.Error("Tool should be excluded when both flags are true (disable wins)") } } @@ -1111,14 +1082,12 @@ func TestFeatureFlagError(t *testing.T) { mockToolWithFlags("needs_flag", "toolset1", true, "my_feature", ""), } - tsg := NewRegistry().SetTools(tools) - // Checker that returns error should treat as false (tool excluded) checkerError := func(_ context.Context, _ string) (bool, error) { return false, fmt.Errorf("simulated error") } - filtered := tsg.WithFeatureChecker(checkerError) - available := filtered.AvailableTools(context.Background()) + reg := NewBuilder().SetTools(tools).WithFeatureChecker(checkerError).Build() + available := reg.AvailableTools(context.Background()) if len(available) != 0 { t.Errorf("Expected 0 tools when checker errors, got %d", len(available)) } @@ -1134,19 +1103,18 @@ func TestFeatureFlagResources(t *testing.T) { }, } - tsg := NewRegistry().SetResources(resources) - // Without checker, resource with enable flag should be excluded - available := tsg.AvailableResourceTemplates(context.Background()) + reg := NewBuilder().SetResources(resources).WithToolsets([]string{"all"}).Build() + available := reg.AvailableResourceTemplates(context.Background()) if len(available) != 1 { t.Fatalf("Expected 1 resource without checker, got %d", len(available)) } // With checker returning true, both should be included checker := func(_ context.Context, _ string) (bool, error) { return true, nil } - filtered := tsg.WithFeatureChecker(checker) - if len(filtered.AvailableResourceTemplates(context.Background())) != 2 { - t.Errorf("Expected 2 resources with checker, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) + regWithChecker := NewBuilder().SetResources(resources).WithToolsets([]string{"all"}).WithFeatureChecker(checker).Build() + if len(regWithChecker.AvailableResourceTemplates(context.Background())) != 2 { + t.Errorf("Expected 2 resources with checker, got %d", len(regWithChecker.AvailableResourceTemplates(context.Background()))) } } @@ -1160,19 +1128,18 @@ func TestFeatureFlagPrompts(t *testing.T) { }, } - tsg := NewRegistry().SetPrompts(prompts) - // Without checker, prompt with enable flag should be excluded - available := tsg.AvailablePrompts(context.Background()) + reg := NewBuilder().SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + available := reg.AvailablePrompts(context.Background()) if len(available) != 1 { t.Fatalf("Expected 1 prompt without checker, got %d", len(available)) } // With checker returning true, both should be included checker := func(_ context.Context, _ string) (bool, error) { return true, nil } - filtered := tsg.WithFeatureChecker(checker) - if len(filtered.AvailablePrompts(context.Background())) != 2 { - t.Errorf("Expected 2 prompts with checker, got %d", len(filtered.AvailablePrompts(context.Background()))) + regWithChecker := NewBuilder().SetPrompts(prompts).WithToolsets([]string{"all"}).WithFeatureChecker(checker).Build() + if len(regWithChecker.AvailablePrompts(context.Background())) != 2 { + t.Errorf("Expected 2 prompts with checker, got %d", len(regWithChecker.AvailablePrompts(context.Background()))) } } @@ -1207,61 +1174,3 @@ func TestServerToolHandlerPanicOnNil(t *testing.T) { tool.Handler(nil) } - -// TestRegistryCopyCopiesAllFields ensures the copy() method stays in sync with the struct. -// If you add a new field to Registry, this test will fail until you update copy(). -func TestRegistryCopyCopiesAllFields(t *testing.T) { - // Create a Registry with non-zero/non-nil values for ALL fields - original := &Registry{ - tools: []ServerTool{mockTool("t1", "ts1", true)}, - resourceTemplates: []ServerResourceTemplate{{Template: mcp.ResourceTemplate{Name: "r1"}}}, - prompts: []ServerPrompt{{Prompt: mcp.Prompt{Name: "p1"}}}, - deprecatedAliases: map[string]string{"old": "new"}, - readOnly: true, - enabledToolsets: map[ToolsetID]bool{"ts1": true}, - additionalTools: map[string]bool{"extra": true}, - featureChecker: func(_ context.Context, _ string) (bool, error) { return true, nil }, - unrecognizedToolsets: []string{"unknown"}, - } - - copied := original.copy() - - // Verify all fields are copied correctly - if len(copied.tools) != len(original.tools) || (len(copied.tools) > 0 && copied.tools[0].Tool.Name != original.tools[0].Tool.Name) { - t.Error("tools not copied correctly") - } - if len(copied.resourceTemplates) != len(original.resourceTemplates) { - t.Error("resourceTemplates not copied correctly") - } - if len(copied.prompts) != len(original.prompts) { - t.Error("prompts not copied correctly") - } - if len(copied.deprecatedAliases) != len(original.deprecatedAliases) || copied.deprecatedAliases["old"] != "new" { - t.Error("deprecatedAliases not copied correctly") - } - if copied.readOnly != original.readOnly { - t.Error("readOnly not copied correctly") - } - if len(copied.enabledToolsets) != len(original.enabledToolsets) || !copied.enabledToolsets["ts1"] { - t.Error("enabledToolsets not copied correctly") - } - if len(copied.additionalTools) != len(original.additionalTools) || !copied.additionalTools["extra"] { - t.Error("additionalTools not copied correctly") - } - if copied.featureChecker == nil { - t.Error("featureChecker not copied correctly") - } - if len(copied.unrecognizedToolsets) != len(original.unrecognizedToolsets) || copied.unrecognizedToolsets[0] != "unknown" { - t.Error("unrecognizedToolsets not copied correctly") - } - - // Verify maps are deep copied (mutations don't affect original) - copied.enabledToolsets["ts2"] = true - if original.enabledToolsets["ts2"] { - t.Error("enabledToolsets should be deep copied, not shared") - } - copied.additionalTools["another"] = true - if original.additionalTools["another"] { - t.Error("additionalTools should be deep copied, not shared") - } -} diff --git a/pkg/registry/resources.go b/pkg/registry/resources.go new file mode 100644 index 000000000..99e0240c5 --- /dev/null +++ b/pkg/registry/resources.go @@ -0,0 +1,48 @@ +package registry + +import "github.com/modelcontextprotocol/go-sdk/mcp" + +// ResourceHandlerFunc is a function that takes dependencies and returns an MCP resource handler. +// This allows resources to be defined statically while their handlers are generated +// on-demand with the appropriate dependencies. +type ResourceHandlerFunc func(deps any) mcp.ResourceHandler + +// ServerResourceTemplate pairs a resource template with its toolset metadata. +type ServerResourceTemplate struct { + Template mcp.ResourceTemplate + // HandlerFunc generates the handler when given dependencies. + // This allows resources to be passed around without handlers being set up, + // and handlers are only created when needed. + HandlerFunc ResourceHandlerFunc + // Toolset identifies which toolset this resource belongs to + Toolset ToolsetMetadata + // FeatureFlagEnable specifies a feature flag that must be enabled for this resource + // to be available. If set and the flag is not enabled, the resource is omitted. + FeatureFlagEnable string + // FeatureFlagDisable specifies a feature flag that, when enabled, causes this resource + // to be omitted. Used to disable resources when a feature flag is on. + FeatureFlagDisable string +} + +// HasHandler returns true if this resource has a handler function. +func (sr *ServerResourceTemplate) HasHandler() bool { + return sr.HandlerFunc != nil +} + +// Handler returns a resource handler by calling HandlerFunc with the given dependencies. +// Panics if HandlerFunc is nil - all resources should have handlers. +func (sr *ServerResourceTemplate) Handler(deps any) mcp.ResourceHandler { + if sr.HandlerFunc == nil { + panic("HandlerFunc is nil for resource: " + sr.Template.Name) + } + return sr.HandlerFunc(deps) +} + +// NewServerResourceTemplate creates a new ServerResourceTemplate with toolset metadata. +func NewServerResourceTemplate(toolset ToolsetMetadata, resourceTemplate mcp.ResourceTemplate, handlerFn ResourceHandlerFunc) ServerResourceTemplate { + return ServerResourceTemplate{ + Template: resourceTemplate, + HandlerFunc: handlerFn, + Toolset: toolset, + } +} diff --git a/pkg/toolsets/server_tool.go b/pkg/registry/server_tool.go similarity index 99% rename from pkg/toolsets/server_tool.go rename to pkg/registry/server_tool.go index eb30f01f4..af00a2bed 100644 --- a/pkg/toolsets/server_tool.go +++ b/pkg/registry/server_tool.go @@ -1,4 +1,4 @@ -package toolsets +package registry import ( "context" diff --git a/pkg/toolsets/toolsets.go b/pkg/toolsets/toolsets.go deleted file mode 100644 index 34e5fa923..000000000 --- a/pkg/toolsets/toolsets.go +++ /dev/null @@ -1,866 +0,0 @@ -package toolsets - -import ( - "context" - "fmt" - "os" - "slices" - "sort" - "strings" - - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -type ToolsetDoesNotExistError struct { - Name string -} - -func (e *ToolsetDoesNotExistError) Error() string { - return fmt.Sprintf("toolset %s does not exist", e.Name) -} - -func (e *ToolsetDoesNotExistError) Is(target error) bool { - if target == nil { - return false - } - if _, ok := target.(*ToolsetDoesNotExistError); ok { - return true - } - return false -} - -func NewToolsetDoesNotExistError(name string) *ToolsetDoesNotExistError { - return &ToolsetDoesNotExistError{Name: name} -} - -// ToolDoesNotExistError is returned when a tool is not found. -type ToolDoesNotExistError struct { - Name string -} - -func (e *ToolDoesNotExistError) Error() string { - return fmt.Sprintf("tool %s does not exist", e.Name) -} - -// NewToolDoesNotExistError creates a new ToolDoesNotExistError. -func NewToolDoesNotExistError(name string) *ToolDoesNotExistError { - return &ToolDoesNotExistError{Name: name} -} - -// ServerTool is defined in server_tool.go - -// ResourceHandlerFunc is a function that takes dependencies and returns an MCP resource handler. -// This allows resources to be defined statically while their handlers are generated -// on-demand with the appropriate dependencies. -type ResourceHandlerFunc func(deps any) mcp.ResourceHandler - -// ServerResourceTemplate pairs a resource template with its toolset metadata. -type ServerResourceTemplate struct { - Template mcp.ResourceTemplate - // HandlerFunc generates the handler when given dependencies. - // This allows resources to be passed around without handlers being set up, - // and handlers are only created when needed. - HandlerFunc ResourceHandlerFunc - // Toolset identifies which toolset this resource belongs to - Toolset ToolsetMetadata - // FeatureFlagEnable specifies a feature flag that must be enabled for this resource - // to be available. If set and the flag is not enabled, the resource is omitted. - FeatureFlagEnable string - // FeatureFlagDisable specifies a feature flag that, when enabled, causes this resource - // to be omitted. Used to disable resources when a feature flag is on. - FeatureFlagDisable string -} - -// HasHandler returns true if this resource has a handler function. -func (sr *ServerResourceTemplate) HasHandler() bool { - return sr.HandlerFunc != nil -} - -// Handler returns a resource handler by calling HandlerFunc with the given dependencies. -// Panics if HandlerFunc is nil - all resources should have handlers. -func (sr *ServerResourceTemplate) Handler(deps any) mcp.ResourceHandler { - if sr.HandlerFunc == nil { - panic("HandlerFunc is nil for resource: " + sr.Template.Name) - } - return sr.HandlerFunc(deps) -} - -// NewServerResourceTemplate creates a new ServerResourceTemplate with toolset metadata. -func NewServerResourceTemplate(toolset ToolsetMetadata, resourceTemplate mcp.ResourceTemplate, handlerFn ResourceHandlerFunc) ServerResourceTemplate { - return ServerResourceTemplate{ - Template: resourceTemplate, - HandlerFunc: handlerFn, - Toolset: toolset, - } -} - -// ServerPrompt pairs a prompt with its toolset metadata. -type ServerPrompt struct { - Prompt mcp.Prompt - Handler mcp.PromptHandler - // Toolset identifies which toolset this prompt belongs to - Toolset ToolsetMetadata - // FeatureFlagEnable specifies a feature flag that must be enabled for this prompt - // to be available. If set and the flag is not enabled, the prompt is omitted. - FeatureFlagEnable string - // FeatureFlagDisable specifies a feature flag that, when enabled, causes this prompt - // to be omitted. Used to disable prompts when a feature flag is on. - FeatureFlagDisable string -} - -// NewServerPrompt creates a new ServerPrompt with toolset metadata. -func NewServerPrompt(toolset ToolsetMetadata, prompt mcp.Prompt, handler mcp.PromptHandler) ServerPrompt { - return ServerPrompt{ - Prompt: prompt, - Handler: handler, - Toolset: toolset, - } -} - -// Registry holds a collection of tools, resources, and prompts. -// It supports immutable filtering operations that return new Registrys -// without modifying the original. This design allows for: -// - Building a full set of tools/resources/prompts once -// - Applying filters (read-only, feature flags, enabled toolsets) without mutation -// - Deterministic ordering for documentation generation -// - Lazy dependency injection only when registering with a server -type Registry struct { - // tools holds all tools in this group - tools []ServerTool - // resourceTemplates holds all resource templates in this group - resourceTemplates []ServerResourceTemplate - // prompts holds all prompts in this group - prompts []ServerPrompt - // deprecatedAliases maps old tool names to new canonical names - deprecatedAliases map[string]string - - // Filters - these control what's returned by Available* methods - // readOnly when true filters out write tools - readOnly bool - // enabledToolsets when non-nil, only include tools/resources/prompts from these toolsets - // when nil, all toolsets are enabled - enabledToolsets map[ToolsetID]bool - // additionalTools are specific tools that bypass toolset filtering (but still respect read-only) - // These are additive - a tool is included if it matches toolset filters OR is in this set - additionalTools map[string]bool - // featureChecker when non-nil, checks if a feature flag is enabled. - // Takes context and flag name, returns (enabled, error). If error, log and treat as false. - // If checker is nil, all flag checks return false. - featureChecker FeatureFlagChecker - // unrecognizedToolsets holds toolset IDs that were requested but don't match any registered toolsets - unrecognizedToolsets []string -} - -// FeatureFlagChecker is a function that checks if a feature flag is enabled. -// The context can be used to extract actor/user information for flag evaluation. -// Returns (enabled, error). If error occurs, the caller should log and treat as false. -type FeatureFlagChecker func(ctx context.Context, flagName string) (bool, error) - -// NewRegistry creates a new empty Registry. -// Use SetTools, SetResources, SetPrompts to populate it. -func NewRegistry() *Registry { - return &Registry{ - deprecatedAliases: make(map[string]string), - } -} - -// SetTools sets the tools for this group. Returns self for chaining. -func (r *Registry) SetTools(tools []ServerTool) *Registry { - r.tools = tools - return r -} - -// SetResources sets the resource templates for this group. Returns self for chaining. -func (r *Registry) SetResources(resources []ServerResourceTemplate) *Registry { - r.resourceTemplates = resources - return r -} - -// SetPrompts sets the prompts for this group. Returns self for chaining. -func (r *Registry) SetPrompts(prompts []ServerPrompt) *Registry { - r.prompts = prompts - return r -} - -// copy creates a shallow copy of the Registry for immutable operations. -func (r *Registry) copy() *Registry { - newTG := &Registry{ - tools: r.tools, // slices are shared (immutable) - resourceTemplates: r.resourceTemplates, - prompts: r.prompts, - deprecatedAliases: r.deprecatedAliases, - readOnly: r.readOnly, - featureChecker: r.featureChecker, - } - - // Copy maps if they exist - if r.enabledToolsets != nil { - newTG.enabledToolsets = make(map[ToolsetID]bool, len(r.enabledToolsets)) - for k, v := range r.enabledToolsets { - newTG.enabledToolsets[k] = v - } - } - if r.additionalTools != nil { - newTG.additionalTools = make(map[string]bool, len(r.additionalTools)) - for k, v := range r.additionalTools { - newTG.additionalTools[k] = v - } - } - newTG.unrecognizedToolsets = r.unrecognizedToolsets - - return newTG -} - -// WithReadOnly returns a new Registry with read-only mode set. -// When true, write tools are filtered out from Available* methods. -func (r *Registry) WithReadOnly(readOnly bool) *Registry { - newTG := r.copy() - newTG.readOnly = readOnly - return newTG -} - -// WithToolsets returns a new Registry that only includes items from the specified toolsets. -// Special keywords: -// - "all": enables all toolsets -// - "default": expands to toolsets marked with Default: true in their metadata -// -// Input strings are trimmed of whitespace and duplicates are removed. -// Toolset IDs that don't match any registered toolsets are tracked and can be -// retrieved via UnrecognizedToolsets() for warning purposes. -// -// Pass nil to use default toolsets. Pass an empty slice to disable all toolsets -// (useful for dynamic toolsets mode where tools are enabled on demand). -func (r *Registry) WithToolsets(toolsetIDs []string) *Registry { - newTG := r.copy() - newTG.unrecognizedToolsets = nil // reset for fresh calculation - - // Build a set of valid toolset IDs for validation - validIDs := make(map[ToolsetID]bool) - for _, t := range r.tools { - validIDs[t.Toolset.ID] = true - } - for _, r := range r.resourceTemplates { - validIDs[r.Toolset.ID] = true - } - for _, p := range r.prompts { - validIDs[p.Toolset.ID] = true - } - - // Check for "all" keyword - enables all toolsets - for _, id := range toolsetIDs { - if strings.TrimSpace(id) == "all" { - newTG.enabledToolsets = nil - return newTG - } - } - - // nil means use defaults, empty slice means no toolsets - if toolsetIDs == nil { - toolsetIDs = []string{"default"} - } - - // Expand "default" keyword, trim whitespace, collect other IDs, and track unrecognized - seen := make(map[ToolsetID]bool) - expanded := make([]ToolsetID, 0, len(toolsetIDs)) - var unrecognized []string - - for _, id := range toolsetIDs { - trimmed := strings.TrimSpace(id) - if trimmed == "" { - continue - } - if trimmed == "default" { - for _, defaultID := range r.DefaultToolsetIDs() { - if !seen[defaultID] { - seen[defaultID] = true - expanded = append(expanded, defaultID) - } - } - } else { - tsID := ToolsetID(trimmed) - if !seen[tsID] { - seen[tsID] = true - expanded = append(expanded, tsID) - // Track if this toolset doesn't exist - if !validIDs[tsID] { - unrecognized = append(unrecognized, trimmed) - } - } - } - } - - newTG.unrecognizedToolsets = unrecognized - - if len(expanded) == 0 { - newTG.enabledToolsets = make(map[ToolsetID]bool) - return newTG - } - - newTG.enabledToolsets = make(map[ToolsetID]bool, len(expanded)) - for _, id := range expanded { - newTG.enabledToolsets[id] = true - } - return newTG -} - -// UnrecognizedToolsets returns toolset IDs that were passed to WithToolsets but don't -// match any registered toolsets. This is useful for warning users about typos. -func (r *Registry) UnrecognizedToolsets() []string { - return r.unrecognizedToolsets -} - -// WithTools returns a new Registry with additional tools that bypass toolset filtering. -// These tools are additive - they will be included even if their toolset is not enabled. -// Read-only filtering still applies to these tools. -// Deprecated tool aliases are automatically resolved to their canonical names. -// Pass nil or empty slice to clear additional tools. -func (r *Registry) WithTools(toolNames []string) *Registry { - newTG := r.copy() - if len(toolNames) == 0 { - newTG.additionalTools = nil - return newTG - } - newTG.additionalTools = make(map[string]bool, len(toolNames)) - for _, name := range toolNames { - // Resolve deprecated aliases to canonical names - if canonical, isAlias := r.deprecatedAliases[name]; isAlias { - newTG.additionalTools[canonical] = true - } else { - newTG.additionalTools[name] = true - } - } - return newTG -} - -// WithFeatureChecker returns a new Registry with a feature checker function. -// The checker receives a context (for actor extraction) and feature flag name, returns (enabled, error). -// If error occurs, it will be logged and treated as false. -// If checker is nil, all feature flag checks return false (items with FeatureFlagEnable are excluded, -// items with FeatureFlagDisable are included). -func (r *Registry) WithFeatureChecker(checker FeatureFlagChecker) *Registry { - newTG := r.copy() - newTG.featureChecker = checker - return newTG -} - -// MCP method constants for use with ForMCPRequest. -const ( - MCPMethodInitialize = "initialize" - MCPMethodToolsList = "tools/list" - MCPMethodToolsCall = "tools/call" - MCPMethodResourcesList = "resources/list" - MCPMethodResourcesRead = "resources/read" - MCPMethodResourcesTemplatesList = "resources/templates/list" - MCPMethodPromptsList = "prompts/list" - MCPMethodPromptsGet = "prompts/get" -) - -// ForMCPRequest returns a Registry optimized for a specific MCP request. -// This is designed for servers that create a new instance per request (like the remote server), -// allowing them to only register the items needed for that specific request rather than all ~90 tools. -// -// Parameters: -// - method: The MCP method being called (use MCP* constants) -// - itemName: Name of specific item for call/get methods (tool name, resource URI, or prompt name) -// -// Returns a new Registry containing only the items relevant to the request: -// - MCPMethodInitialize: Empty (capabilities are set via ServerOptions, not registration) -// - MCPMethodToolsList: All available tools (no resources/prompts) -// - MCPMethodToolsCall: Only the named tool -// - MCPMethodResourcesList, MCPMethodResourcesTemplatesList: All available resources (no tools/prompts) -// - MCPMethodResourcesRead: Only the named resource template -// - MCPMethodPromptsList: All available prompts (no tools/resources) -// - MCPMethodPromptsGet: Only the named prompt -// - Unknown methods: Empty (no items registered) -// -// All existing filters (read-only, toolsets, etc.) still apply to the returned items. -func (r *Registry) ForMCPRequest(method string, itemName string) *Registry { - result := r.copy() - - // Helper to clear all item types - clearAll := func() { - result.tools = []ServerTool{} - result.resourceTemplates = []ServerResourceTemplate{} - result.prompts = []ServerPrompt{} - } - - switch method { - case MCPMethodInitialize: - clearAll() - case MCPMethodToolsList: - result.resourceTemplates, result.prompts = nil, nil - case MCPMethodToolsCall: - result.resourceTemplates, result.prompts = nil, nil - if itemName != "" { - result.tools = r.filterToolsByName(itemName) - } - case MCPMethodResourcesList, MCPMethodResourcesTemplatesList: - result.tools, result.prompts = nil, nil - case MCPMethodResourcesRead: - result.tools, result.prompts = nil, nil - if itemName != "" { - result.resourceTemplates = r.filterResourcesByURI(itemName) - } - case MCPMethodPromptsList: - result.tools, result.resourceTemplates = nil, nil - case MCPMethodPromptsGet: - result.tools, result.resourceTemplates = nil, nil - if itemName != "" { - result.prompts = r.filterPromptsByName(itemName) - } - default: - clearAll() - } - - return result -} - -// filterToolsByName returns tools matching the given name, checking deprecated aliases. -// Returns from the current tools slice (respects existing filter chain). -func (r *Registry) filterToolsByName(name string) []ServerTool { - // First check for exact match - for i := range r.tools { - if r.tools[i].Tool.Name == name { - return []ServerTool{r.tools[i]} - } - } - // Check if name is a deprecated alias - if canonical, isAlias := r.deprecatedAliases[name]; isAlias { - for i := range r.tools { - if r.tools[i].Tool.Name == canonical { - return []ServerTool{r.tools[i]} - } - } - } - return []ServerTool{} -} - -// filterResourcesByURI returns resource templates matching the given URI pattern. -func (r *Registry) filterResourcesByURI(uri string) []ServerResourceTemplate { - for i := range r.resourceTemplates { - // Check if URI matches the template pattern (exact match on URITemplate string) - if r.resourceTemplates[i].Template.URITemplate == uri { - return []ServerResourceTemplate{r.resourceTemplates[i]} - } - } - return []ServerResourceTemplate{} -} - -// filterPromptsByName returns prompts matching the given name. -func (r *Registry) filterPromptsByName(name string) []ServerPrompt { - for i := range r.prompts { - if r.prompts[i].Prompt.Name == name { - return []ServerPrompt{r.prompts[i]} - } - } - return []ServerPrompt{} -} - -// WithDeprecatedToolAliases returns a new Registry with the given deprecated aliases added. -// Aliases map old tool names to new canonical names. -func (r *Registry) WithDeprecatedToolAliases(aliases map[string]string) *Registry { - newTG := r.copy() - // Ensure we have a fresh map - newTG.deprecatedAliases = make(map[string]string, len(r.deprecatedAliases)+len(aliases)) - for k, v := range r.deprecatedAliases { - newTG.deprecatedAliases[k] = v - } - for oldName, newName := range aliases { - newTG.deprecatedAliases[oldName] = newName - } - return newTG -} - -// isToolsetEnabled checks if a toolset is enabled based on current filters. -func (r *Registry) isToolsetEnabled(toolsetID ToolsetID) bool { - // Check enabled toolsets filter - if r.enabledToolsets != nil { - return r.enabledToolsets[toolsetID] - } - return true -} - -// checkFeatureFlag checks a feature flag using the feature checker. -// Returns false if checker is nil or returns an error (errors are logged). -func (r *Registry) checkFeatureFlag(ctx context.Context, flagName string) bool { - if r.featureChecker == nil || flagName == "" { - return false - } - enabled, err := r.featureChecker(ctx, flagName) - if err != nil { - fmt.Fprintf(os.Stderr, "Feature flag check error for %q: %v\n", flagName, err) - return false - } - return enabled -} - -// isFeatureFlagAllowed checks if an item passes feature flag filtering. -// - If FeatureFlagEnable is set, the item is only allowed if the flag is enabled -// - If FeatureFlagDisable is set, the item is excluded if the flag is enabled -func (r *Registry) isFeatureFlagAllowed(ctx context.Context, enableFlag, disableFlag string) bool { - // Check enable flag - item requires this flag to be on - if enableFlag != "" && !r.checkFeatureFlag(ctx, enableFlag) { - return false - } - // Check disable flag - item is excluded if this flag is on - if disableFlag != "" && r.checkFeatureFlag(ctx, disableFlag) { - return false - } - return true -} - -// isToolEnabled checks if a specific tool is enabled based on current filters. -func (r *Registry) isToolEnabled(ctx context.Context, tool *ServerTool) bool { - // Check read-only filter first (applies to all tools) - if r.readOnly && !tool.IsReadOnly() { - return false - } - // Check feature flags - if !r.isFeatureFlagAllowed(ctx, tool.FeatureFlagEnable, tool.FeatureFlagDisable) { - return false - } - // Check if tool is in additionalTools (bypasses toolset filter) - if r.additionalTools != nil && r.additionalTools[tool.Tool.Name] { - return true - } - // Check toolset filter - if !r.isToolsetEnabled(tool.Toolset.ID) { - return false - } - return true -} - -// AvailableTools returns the tools that pass all current filters, -// sorted deterministically by toolset ID, then tool name. -// The context is used for feature flag evaluation. -func (r *Registry) AvailableTools(ctx context.Context) []ServerTool { - var result []ServerTool - for i := range r.tools { - tool := &r.tools[i] - if r.isToolEnabled(ctx, tool) { - result = append(result, *tool) - } - } - - // Sort deterministically: by toolset ID, then by tool name - sort.Slice(result, func(i, j int) bool { - if result[i].Toolset.ID != result[j].Toolset.ID { - return result[i].Toolset.ID < result[j].Toolset.ID - } - return result[i].Tool.Name < result[j].Tool.Name - }) - - return result -} - -// AvailableResourceTemplates returns resource templates that pass all current filters, -// sorted deterministically by toolset ID, then template name. -// The context is used for feature flag evaluation. -func (r *Registry) AvailableResourceTemplates(ctx context.Context) []ServerResourceTemplate { - var result []ServerResourceTemplate - for i := range r.resourceTemplates { - res := &r.resourceTemplates[i] - // Check feature flags - if !r.isFeatureFlagAllowed(ctx, res.FeatureFlagEnable, res.FeatureFlagDisable) { - continue - } - if r.isToolsetEnabled(res.Toolset.ID) { - result = append(result, *res) - } - } - - // Sort deterministically: by toolset ID, then by template name - sort.Slice(result, func(i, j int) bool { - if result[i].Toolset.ID != result[j].Toolset.ID { - return result[i].Toolset.ID < result[j].Toolset.ID - } - return result[i].Template.Name < result[j].Template.Name - }) - - return result -} - -// AvailablePrompts returns prompts that pass all current filters, -// sorted deterministically by toolset ID, then prompt name. -// The context is used for feature flag evaluation. -func (r *Registry) AvailablePrompts(ctx context.Context) []ServerPrompt { - var result []ServerPrompt - for i := range r.prompts { - prompt := &r.prompts[i] - // Check feature flags - if !r.isFeatureFlagAllowed(ctx, prompt.FeatureFlagEnable, prompt.FeatureFlagDisable) { - continue - } - if r.isToolsetEnabled(prompt.Toolset.ID) { - result = append(result, *prompt) - } - } - - // Sort deterministically: by toolset ID, then by prompt name - sort.Slice(result, func(i, j int) bool { - if result[i].Toolset.ID != result[j].Toolset.ID { - return result[i].Toolset.ID < result[j].Toolset.ID - } - return result[i].Prompt.Name < result[j].Prompt.Name - }) - - return result -} - -// ToolsetIDs returns a sorted list of unique toolset IDs from all tools in this group. -func (r *Registry) ToolsetIDs() []ToolsetID { - seen := make(map[ToolsetID]bool) - for i := range r.tools { - seen[r.tools[i].Toolset.ID] = true - } - for i := range r.resourceTemplates { - seen[r.resourceTemplates[i].Toolset.ID] = true - } - for i := range r.prompts { - seen[r.prompts[i].Toolset.ID] = true - } - - ids := make([]ToolsetID, 0, len(seen)) - for id := range seen { - ids = append(ids, id) - } - sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) - return ids -} - -// DefaultToolsetIDs returns the IDs of toolsets marked as Default in their metadata. -// The IDs are returned in sorted order for deterministic output. -func (r *Registry) DefaultToolsetIDs() []ToolsetID { - seen := make(map[ToolsetID]bool) - for i := range r.tools { - if r.tools[i].Toolset.Default { - seen[r.tools[i].Toolset.ID] = true - } - } - for i := range r.resourceTemplates { - if r.resourceTemplates[i].Toolset.Default { - seen[r.resourceTemplates[i].Toolset.ID] = true - } - } - for i := range r.prompts { - if r.prompts[i].Toolset.Default { - seen[r.prompts[i].Toolset.ID] = true - } - } - - ids := make([]ToolsetID, 0, len(seen)) - for id := range seen { - ids = append(ids, id) - } - sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) - return ids -} - -// ToolsetDescriptions returns a map of toolset ID to description for all toolsets. -func (r *Registry) ToolsetDescriptions() map[ToolsetID]string { - descriptions := make(map[ToolsetID]string) - for i := range r.tools { - t := &r.tools[i] - if t.Toolset.Description != "" { - descriptions[t.Toolset.ID] = t.Toolset.Description - } - } - for i := range r.resourceTemplates { - r := &r.resourceTemplates[i] - if r.Toolset.Description != "" { - descriptions[r.Toolset.ID] = r.Toolset.Description - } - } - for i := range r.prompts { - p := &r.prompts[i] - if p.Toolset.Description != "" { - descriptions[p.Toolset.ID] = p.Toolset.Description - } - } - return descriptions -} - -// ToolsForToolset returns all tools belonging to a specific toolset. -// This method bypasses the toolset enabled filter (for dynamic toolset registration), -// but still respects the read-only filter. -func (r *Registry) ToolsForToolset(toolsetID ToolsetID) []ServerTool { - var result []ServerTool - for i := range r.tools { - tool := &r.tools[i] - // Only check read-only filter, not toolset enabled filter - if tool.Toolset.ID == toolsetID { - if r.readOnly && !tool.IsReadOnly() { - continue - } - result = append(result, *tool) - } - } - - // Sort by tool name for deterministic order - sort.Slice(result, func(i, j int) bool { - return result[i].Tool.Name < result[j].Tool.Name - }) - - return result -} - -// RegisterTools registers all available tools with the server using the provided dependencies. -// The context is used for feature flag evaluation. -func (r *Registry) RegisterTools(ctx context.Context, s *mcp.Server, deps any) { - for _, tool := range r.AvailableTools(ctx) { - tool.RegisterFunc(s, deps) - } -} - -// RegisterResourceTemplates registers all available resource templates with the server. -// The context is used for feature flag evaluation. -func (r *Registry) RegisterResourceTemplates(ctx context.Context, s *mcp.Server, deps any) { - for _, res := range r.AvailableResourceTemplates(ctx) { - s.AddResourceTemplate(&res.Template, res.Handler(deps)) - } -} - -// RegisterPrompts registers all available prompts with the server. -// The context is used for feature flag evaluation. -func (r *Registry) RegisterPrompts(ctx context.Context, s *mcp.Server) { - for _, prompt := range r.AvailablePrompts(ctx) { - s.AddPrompt(&prompt.Prompt, prompt.Handler) - } -} - -// RegisterAll registers all available tools, resources, and prompts with the server. -// The context is used for feature flag evaluation. -func (r *Registry) RegisterAll(ctx context.Context, s *mcp.Server, deps any) { - r.RegisterTools(ctx, s, deps) - r.RegisterResourceTemplates(ctx, s, deps) - r.RegisterPrompts(ctx, s) -} - -// ResolveToolAliases resolves deprecated tool aliases to their canonical names. -// It logs a warning to stderr for each deprecated alias that is resolved. -// Returns: -// - resolved: tool names with aliases replaced by canonical names -// - aliasesUsed: map of oldName → newName for each alias that was resolved -func (r *Registry) ResolveToolAliases(toolNames []string) (resolved []string, aliasesUsed map[string]string) { - resolved = make([]string, 0, len(toolNames)) - aliasesUsed = make(map[string]string) - for _, toolName := range toolNames { - if canonicalName, isAlias := r.deprecatedAliases[toolName]; isAlias { - fmt.Fprintf(os.Stderr, "Warning: tool %q is deprecated, use %q instead\n", toolName, canonicalName) - aliasesUsed[toolName] = canonicalName - resolved = append(resolved, canonicalName) - } else { - resolved = append(resolved, toolName) - } - } - return resolved, aliasesUsed -} - -// FindToolByName searches all tools for one matching the given name. -// Returns the tool, its toolset ID, and an error if not found. -// This searches ALL tools regardless of filters. -func (r *Registry) FindToolByName(toolName string) (*ServerTool, ToolsetID, error) { - for i := range r.tools { - tool := &r.tools[i] - if tool.Tool.Name == toolName { - return tool, tool.Toolset.ID, nil - } - } - return nil, "", NewToolDoesNotExistError(toolName) -} - -// HasToolset checks if any tool/resource/prompt belongs to the given toolset. -func (r *Registry) HasToolset(toolsetID ToolsetID) bool { - for i := range r.tools { - if r.tools[i].Toolset.ID == toolsetID { - return true - } - } - for i := range r.resourceTemplates { - if r.resourceTemplates[i].Toolset.ID == toolsetID { - return true - } - } - for i := range r.prompts { - if r.prompts[i].Toolset.ID == toolsetID { - return true - } - } - return false -} - -// EnabledToolsetIDs returns the list of enabled toolset IDs based on current filters. -// Returns all toolset IDs if no filter is set. -func (r *Registry) EnabledToolsetIDs() []ToolsetID { - if r.enabledToolsets == nil { - return r.ToolsetIDs() - } - - ids := make([]ToolsetID, 0, len(r.enabledToolsets)) - for id := range r.enabledToolsets { - if r.HasToolset(id) { - ids = append(ids, id) - } - } - sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) - return ids -} - -// IsToolsetEnabled checks if a toolset is currently enabled based on filters. -func (r *Registry) IsToolsetEnabled(toolsetID ToolsetID) bool { - return r.isToolsetEnabled(toolsetID) -} - -// EnableToolset marks a toolset as enabled in this group. -// This is used by dynamic toolset management to track which toolsets have been enabled. -func (r *Registry) EnableToolset(toolsetID ToolsetID) { - if r.enabledToolsets == nil { - // nil means all enabled, so nothing to do - return - } - r.enabledToolsets[toolsetID] = true -} - -// AllTools returns all tools without any filtering, sorted deterministically. -func (r *Registry) AllTools() []ServerTool { - result := slices.Clone(r.tools) - - // Sort deterministically: by toolset ID, then by tool name - sort.Slice(result, func(i, j int) bool { - if result[i].Toolset.ID != result[j].Toolset.ID { - return result[i].Toolset.ID < result[j].Toolset.ID - } - return result[i].Tool.Name < result[j].Tool.Name - }) - - return result -} - -// AvailableToolsets returns the unique toolsets that have tools, in sorted order. -// This is the ordered intersection of toolsets with reality - only toolsets that -// actually contain tools are returned, sorted by toolset ID. -// Optional exclude parameter filters out specific toolset IDs from the result. -func (r *Registry) AvailableToolsets(exclude ...ToolsetID) []ToolsetMetadata { - tools := r.AllTools() - if len(tools) == 0 { - return nil - } - - // Build exclude set for O(1) lookup - excludeSet := make(map[ToolsetID]bool, len(exclude)) - for _, id := range exclude { - excludeSet[id] = true - } - - var result []ToolsetMetadata - var lastID ToolsetID - for _, tool := range tools { - if tool.Toolset.ID != lastID { - lastID = tool.Toolset.ID - if !excludeSet[lastID] { - result = append(result, tool.Toolset) - } - } - } - return result -} From e943fd1a476d2e2bbb007c30d7cf004b62afbe7f Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Mon, 15 Dec 2025 15:38:39 +0100 Subject: [PATCH 12/27] fix: remove unnecessary type arguments in helper_test.go --- pkg/github/helper_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/github/helper_test.go b/pkg/github/helper_test.go index 9c55ba841..4fdf5f928 100644 --- a/pkg/github/helper_test.go +++ b/pkg/github/helper_test.go @@ -216,7 +216,7 @@ func TestOptionalParamOK(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Test with string type assertion if _, isString := tc.expectedVal.(string); isString || tc.errorMsg == "parameter myParam is not of type string, is bool" { - val, ok, err := OptionalParamOK[string, map[string]any](tc.args, tc.paramName) + val, ok, err := OptionalParamOK[string](tc.args, tc.paramName) if tc.expectError { require.Error(t, err) assert.Contains(t, err.Error(), tc.errorMsg) @@ -231,7 +231,7 @@ func TestOptionalParamOK(t *testing.T) { // Test with bool type assertion if _, isBool := tc.expectedVal.(bool); isBool || tc.errorMsg == "parameter myParam is not of type bool, is string" { - val, ok, err := OptionalParamOK[bool, map[string]any](tc.args, tc.paramName) + val, ok, err := OptionalParamOK[bool](tc.args, tc.paramName) if tc.expectError { require.Error(t, err) assert.Contains(t, err.Error(), tc.errorMsg) @@ -246,7 +246,7 @@ func TestOptionalParamOK(t *testing.T) { // Test with float64 type assertion (for number case) if _, isFloat := tc.expectedVal.(float64); isFloat { - val, ok, err := OptionalParamOK[float64, map[string]any](tc.args, tc.paramName) + val, ok, err := OptionalParamOK[float64](tc.args, tc.paramName) if tc.expectError { // This case shouldn't happen for float64 in the defined tests require.Fail(t, "Unexpected error case for float64") From a816ac104f9175e1f4bad8c8e09769214b48888b Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Mon, 15 Dec 2025 16:20:15 +0100 Subject: [PATCH 13/27] fix: restore correct behavior for --tools and --toolsets flags Two behavioral regressions were fixed in resolveEnabledToolsets(): 1. When --tools=X is used without --toolsets, the server should only register the specified tools, not the default toolsets. Now returns an empty slice instead of nil when EnabledTools is set. 2. When --toolsets=all --dynamic-toolsets is used, the 'all' and 'default' pseudo-toolsets should be removed so only the dynamic management tools are registered. This matches the original pre-refactor behavior. --- internal/ghmcp/server.go | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 1f924e482..32684ca11 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -123,13 +123,26 @@ func createGitHubClients(cfg MCPServerConfig, apiHost apiHost) (*githubClients, // resolveEnabledToolsets determines which toolsets should be enabled based on config. // Returns nil for "use defaults", empty slice for "none", or explicit list. func resolveEnabledToolsets(cfg MCPServerConfig) []string { - if cfg.EnabledToolsets != nil { - return cfg.EnabledToolsets + enabledToolsets := cfg.EnabledToolsets + + // In dynamic mode, remove "all" and "default" since users enable toolsets on demand + if cfg.DynamicToolsets && enabledToolsets != nil { + enabledToolsets = github.RemoveToolset(enabledToolsets, string(github.ToolsetMetadataAll.ID)) + enabledToolsets = github.RemoveToolset(enabledToolsets, string(github.ToolsetMetadataDefault.ID)) + } + + if enabledToolsets != nil { + return enabledToolsets } if cfg.DynamicToolsets { // Dynamic mode with no toolsets specified: start empty so users enable on demand return []string{} } + if len(cfg.EnabledTools) > 0 { + // When specific tools are requested but no toolsets, don't use default toolsets + // This matches the original behavior: --tools=X alone registers only X + return []string{} + } // nil means "use defaults" in WithToolsets return nil } From 8e803460a36760b9aca5a09e286d56217fd27168 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Mon, 15 Dec 2025 16:49:12 +0100 Subject: [PATCH 14/27] Move labels tools to issues toolset Labels are closely related to issues - you add labels to issues, search issues by label, etc. Keeping them in a separate toolset required users to explicitly enable 'labels' to get this functionality. Moving to issues toolset makes labels available by default since issues is a default toolset. --- README.md | 43 ++++++++++++++++++------------------------- docs/remote-server.md | 1 - pkg/github/labels.go | 6 +++--- pkg/github/tools.go | 4 ---- 4 files changed, 21 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index 117bacacd..9d9d7dc13 100644 --- a/README.md +++ b/README.md @@ -463,7 +463,6 @@ The following sets of tools are available: | `gists` | GitHub Gist related tools | | `git` | GitHub Git API related tools for low-level Git operations | | `issues` | GitHub Issues related tools | -| `labels` | GitHub Labels related tools | | `notifications` | GitHub Notifications related tools | | `orgs` | GitHub Organization related tools | | `projects` | GitHub Projects related tools | @@ -718,6 +717,11 @@ The following sets of tools are available: - `owner`: Repository owner (string, required) - `repo`: Repository name (string, required) +- **get_label** - Get a specific label from a repository. + - `name`: Label name. (string, required) + - `owner`: Repository owner (username or organization name) (string, required) + - `repo`: Repository name (string, required) + - **issue_read** - Get issue details - `issue_number`: The number of the issue (number, required) - `method`: The read operation to perform on a single issue. @@ -751,6 +755,15 @@ The following sets of tools are available: - `title`: Issue title (string, optional) - `type`: Type of this issue. Only use if the repository has issue types configured. Use list_issue_types tool to get valid type values for the organization. If the repository doesn't support issue types, omit this parameter. (string, optional) +- **label_write** - Write operations on repository labels. + - `color`: Label color as 6-character hex code without '#' prefix (e.g., 'f29513'). Required for 'create', optional for 'update'. (string, optional) + - `description`: Label description text. Optional for 'create' and 'update'. (string, optional) + - `method`: Operation to perform: 'create', 'update', or 'delete' (string, required) + - `name`: Label name - required for all operations (string, required) + - `new_name`: New name for the label (used only with 'update' method to rename) (string, optional) + - `owner`: Repository owner (username or organization name) (string, required) + - `repo`: Repository name (string, required) + - **list_issue_types** - List available issue types - `owner`: The organization owner of the repository (string, required) @@ -765,6 +778,10 @@ The following sets of tools are available: - `since`: Filter by date (ISO 8601 timestamp) (string, optional) - `state`: Filter by state, by default both open and closed issues are returned when not provided (string, optional) +- **list_label** - List labels from a repository + - `owner`: Repository owner (username or organization name) - required for all operations (string, required) + - `repo`: Repository name - required for all operations (string, required) + - **search_issues** - Search issues - `order`: Sort order (string, optional) - `owner`: Optional repository owner. If provided with repo, only issues for this repository are listed. (string, optional) @@ -793,30 +810,6 @@ The following sets of tools are available:
-Labels - -- **get_label** - Get a specific label from a repository. - - `name`: Label name. (string, required) - - `owner`: Repository owner (username or organization name) (string, required) - - `repo`: Repository name (string, required) - -- **label_write** - Write operations on repository labels. - - `color`: Label color as 6-character hex code without '#' prefix (e.g., 'f29513'). Required for 'create', optional for 'update'. (string, optional) - - `description`: Label description text. Optional for 'create' and 'update'. (string, optional) - - `method`: Operation to perform: 'create', 'update', or 'delete' (string, required) - - `name`: Label name - required for all operations (string, required) - - `new_name`: New name for the label (used only with 'update' method to rename) (string, optional) - - `owner`: Repository owner (username or organization name) (string, required) - - `repo`: Repository name (string, required) - -- **list_label** - List labels from a repository - - `owner`: Repository owner (username or organization name) - required for all operations (string, required) - - `repo`: Repository name - required for all operations (string, required) - -
- -
- Notifications - **dismiss_notification** - Dismiss notification diff --git a/docs/remote-server.md b/docs/remote-server.md index 53fe36127..493a49255 100644 --- a/docs/remote-server.md +++ b/docs/remote-server.md @@ -27,7 +27,6 @@ Below is a table of available toolsets for the remote GitHub MCP Server. Each to | Gists | GitHub Gist related tools | https://api.githubcopilot.com/mcp/x/gists | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-gists&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgists%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/gists/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-gists&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgists%2Freadonly%22%7D) | | Git | GitHub Git API related tools for low-level Git operations | https://api.githubcopilot.com/mcp/x/git | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-git&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgit%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/git/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-git&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgit%2Freadonly%22%7D) | | Issues | GitHub Issues related tools | https://api.githubcopilot.com/mcp/x/issues | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-issues&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fissues%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/issues/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-issues&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fissues%2Freadonly%22%7D) | -| Labels | GitHub Labels related tools | https://api.githubcopilot.com/mcp/x/labels | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-labels&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Flabels%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/labels/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-labels&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Flabels%2Freadonly%22%7D) | | Notifications | GitHub Notifications related tools | https://api.githubcopilot.com/mcp/x/notifications | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-notifications&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fnotifications%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/notifications/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-notifications&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fnotifications%2Freadonly%22%7D) | | Organizations | GitHub Organization related tools | https://api.githubcopilot.com/mcp/x/orgs | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-orgs&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Forgs%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/orgs/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-orgs&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Forgs%2Freadonly%22%7D) | | Projects | GitHub Projects related tools | https://api.githubcopilot.com/mcp/x/projects | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-projects&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fprojects%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/projects/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-projects&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fprojects%2Freadonly%22%7D) | diff --git a/pkg/github/labels.go b/pkg/github/labels.go index 90bf8066b..d6f0f92b9 100644 --- a/pkg/github/labels.go +++ b/pkg/github/labels.go @@ -18,7 +18,7 @@ import ( // GetLabel retrieves a specific label by name from a GitHub repository func GetLabel(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( - ToolsetLabels, + ToolsetMetadataIssues, mcp.Tool{ Name: "get_label", Description: t("TOOL_GET_LABEL_DESCRIPTION", "Get a specific label from a repository."), @@ -113,7 +113,7 @@ func GetLabel(t translations.TranslationHelperFunc) registry.ServerTool { // ListLabels lists labels from a repository func ListLabels(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( - ToolsetLabels, + ToolsetMetadataIssues, mcp.Tool{ Name: "list_label", Description: t("TOOL_LIST_LABEL_DESCRIPTION", "List labels from a repository"), @@ -205,7 +205,7 @@ func ListLabels(t translations.TranslationHelperFunc) registry.ServerTool { // LabelWrite handles create, update, and delete operations for GitHub labels func LabelWrite(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( - ToolsetLabels, + ToolsetMetadataIssues, mcp.Tool{ Name: "label_write", Description: t("TOOL_LABEL_WRITE_DESCRIPTION", "Perform write operations on repository labels. To set labels on issues, use the 'update_issue' tool."), diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 714397b5d..1b82169a2 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -106,10 +106,6 @@ var ( ID: "dynamic", Description: "Discover GitHub MCP tools that can help achieve tasks by enabling additional sets of tools, you can control the enablement of any toolset to access its tools when this toolset is enabled.", } - ToolsetLabels = registry.ToolsetMetadata{ - ID: "labels", - Description: "GitHub Labels related tools", - } ) // AllTools returns all tools with their embedded toolset metadata. From 855a489240a7e2d7f08a13b89a7a010344567f14 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Mon, 15 Dec 2025 17:04:34 +0100 Subject: [PATCH 15/27] Restore labels toolset with get_label in both issues and labels This restores conformance with the original behavior where: - get_label is in issues toolset (read-only label access for issue workflows) - get_label, list_label, label_write are in labels toolset (full management) The duplicate get_label registration is intentional - it was in both toolsets in the original implementation. Added test exception to allow this case. --- README.md | 38 +++++++++++++++++++---------- docs/remote-server.md | 1 + pkg/github/labels.go | 12 +++++++-- pkg/github/tools.go | 5 ++++ pkg/github/tools_validation_test.go | 12 +++++++-- 5 files changed, 51 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 9d9d7dc13..e8737cd25 100644 --- a/README.md +++ b/README.md @@ -463,6 +463,7 @@ The following sets of tools are available: | `gists` | GitHub Gist related tools | | `git` | GitHub Git API related tools for low-level Git operations | | `issues` | GitHub Issues related tools | +| `labels` | GitHub Labels related tools | | `notifications` | GitHub Notifications related tools | | `orgs` | GitHub Organization related tools | | `projects` | GitHub Projects related tools | @@ -755,15 +756,6 @@ The following sets of tools are available: - `title`: Issue title (string, optional) - `type`: Type of this issue. Only use if the repository has issue types configured. Use list_issue_types tool to get valid type values for the organization. If the repository doesn't support issue types, omit this parameter. (string, optional) -- **label_write** - Write operations on repository labels. - - `color`: Label color as 6-character hex code without '#' prefix (e.g., 'f29513'). Required for 'create', optional for 'update'. (string, optional) - - `description`: Label description text. Optional for 'create' and 'update'. (string, optional) - - `method`: Operation to perform: 'create', 'update', or 'delete' (string, required) - - `name`: Label name - required for all operations (string, required) - - `new_name`: New name for the label (used only with 'update' method to rename) (string, optional) - - `owner`: Repository owner (username or organization name) (string, required) - - `repo`: Repository name (string, required) - - **list_issue_types** - List available issue types - `owner`: The organization owner of the repository (string, required) @@ -778,10 +770,6 @@ The following sets of tools are available: - `since`: Filter by date (ISO 8601 timestamp) (string, optional) - `state`: Filter by state, by default both open and closed issues are returned when not provided (string, optional) -- **list_label** - List labels from a repository - - `owner`: Repository owner (username or organization name) - required for all operations (string, required) - - `repo`: Repository name - required for all operations (string, required) - - **search_issues** - Search issues - `order`: Sort order (string, optional) - `owner`: Optional repository owner. If provided with repo, only issues for this repository are listed. (string, optional) @@ -810,6 +798,30 @@ The following sets of tools are available:
+Labels + +- **get_label** - Get a specific label from a repository. + - `name`: Label name. (string, required) + - `owner`: Repository owner (username or organization name) (string, required) + - `repo`: Repository name (string, required) + +- **label_write** - Write operations on repository labels. + - `color`: Label color as 6-character hex code without '#' prefix (e.g., 'f29513'). Required for 'create', optional for 'update'. (string, optional) + - `description`: Label description text. Optional for 'create' and 'update'. (string, optional) + - `method`: Operation to perform: 'create', 'update', or 'delete' (string, required) + - `name`: Label name - required for all operations (string, required) + - `new_name`: New name for the label (used only with 'update' method to rename) (string, optional) + - `owner`: Repository owner (username or organization name) (string, required) + - `repo`: Repository name (string, required) + +- **list_label** - List labels from a repository + - `owner`: Repository owner (username or organization name) - required for all operations (string, required) + - `repo`: Repository name - required for all operations (string, required) + +
+ +
+ Notifications - **dismiss_notification** - Dismiss notification diff --git a/docs/remote-server.md b/docs/remote-server.md index 493a49255..53fe36127 100644 --- a/docs/remote-server.md +++ b/docs/remote-server.md @@ -27,6 +27,7 @@ Below is a table of available toolsets for the remote GitHub MCP Server. Each to | Gists | GitHub Gist related tools | https://api.githubcopilot.com/mcp/x/gists | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-gists&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgists%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/gists/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-gists&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgists%2Freadonly%22%7D) | | Git | GitHub Git API related tools for low-level Git operations | https://api.githubcopilot.com/mcp/x/git | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-git&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgit%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/git/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-git&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgit%2Freadonly%22%7D) | | Issues | GitHub Issues related tools | https://api.githubcopilot.com/mcp/x/issues | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-issues&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fissues%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/issues/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-issues&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fissues%2Freadonly%22%7D) | +| Labels | GitHub Labels related tools | https://api.githubcopilot.com/mcp/x/labels | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-labels&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Flabels%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/labels/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-labels&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Flabels%2Freadonly%22%7D) | | Notifications | GitHub Notifications related tools | https://api.githubcopilot.com/mcp/x/notifications | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-notifications&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fnotifications%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/notifications/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-notifications&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fnotifications%2Freadonly%22%7D) | | Organizations | GitHub Organization related tools | https://api.githubcopilot.com/mcp/x/orgs | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-orgs&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Forgs%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/orgs/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-orgs&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Forgs%2Freadonly%22%7D) | | Projects | GitHub Projects related tools | https://api.githubcopilot.com/mcp/x/projects | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-projects&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fprojects%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/projects/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-projects&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fprojects%2Freadonly%22%7D) | diff --git a/pkg/github/labels.go b/pkg/github/labels.go index d6f0f92b9..6088aaa8e 100644 --- a/pkg/github/labels.go +++ b/pkg/github/labels.go @@ -110,10 +110,18 @@ func GetLabel(t translations.TranslationHelperFunc) registry.ServerTool { ) } +// GetLabelForLabelsToolset returns the same GetLabel tool but registered in the labels toolset. +// This provides conformance with the original behavior where get_label was in both toolsets. +func GetLabelForLabelsToolset(t translations.TranslationHelperFunc) registry.ServerTool { + tool := GetLabel(t) + tool.Toolset = ToolsetLabels + return tool +} + // ListLabels lists labels from a repository func ListLabels(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( - ToolsetMetadataIssues, + ToolsetLabels, mcp.Tool{ Name: "list_label", Description: t("TOOL_LIST_LABEL_DESCRIPTION", "List labels from a repository"), @@ -205,7 +213,7 @@ func ListLabels(t translations.TranslationHelperFunc) registry.ServerTool { // LabelWrite handles create, update, and delete operations for GitHub labels func LabelWrite(t translations.TranslationHelperFunc) registry.ServerTool { return NewTool( - ToolsetMetadataIssues, + ToolsetLabels, mcp.Tool{ Name: "label_write", Description: t("TOOL_LABEL_WRITE_DESCRIPTION", "Perform write operations on repository labels. To set labels on issues, use the 'update_issue' tool."), diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 1b82169a2..0d540f14e 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -106,6 +106,10 @@ var ( ID: "dynamic", Description: "Discover GitHub MCP tools that can help achieve tasks by enabling additional sets of tools, you can control the enablement of any toolset to access its tools when this toolset is enabled.", } + ToolsetLabels = registry.ToolsetMetadata{ + ID: "labels", + Description: "GitHub Labels related tools", + } ) // AllTools returns all tools with their embedded toolset metadata. @@ -237,6 +241,7 @@ func AllTools(t translations.TranslationHelperFunc) []registry.ServerTool { // Label tools GetLabel(t), + GetLabelForLabelsToolset(t), ListLabels(t), LabelWrite(t), } diff --git a/pkg/github/tools_validation_test.go b/pkg/github/tools_validation_test.go index 1e74c2518..aa809dfa6 100644 --- a/pkg/github/tools_validation_test.go +++ b/pkg/github/tools_validation_test.go @@ -102,10 +102,18 @@ func TestNoDuplicateToolNames(t *testing.T) { tools := AllTools(stubTranslation) seen := make(map[string]bool) + // get_label is intentionally in both issues and labels toolsets for conformance + // with original behavior where it was registered in both + allowedDuplicates := map[string]bool{ + "get_label": true, + } + for _, tool := range tools { name := tool.Tool.Name - assert.False(t, seen[name], - "Duplicate tool name found: %q", name) + if !allowedDuplicates[name] { + assert.False(t, seen[name], + "Duplicate tool name found: %q", name) + } seen[name] = true } } From 6364d8652edb721024fdf9e66ebc9774a73f89b0 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Mon, 15 Dec 2025 22:05:25 +0100 Subject: [PATCH 16/27] Fix instruction generation and capability advertisement - Expand nil toolsets to default IDs before GenerateInstructions (nil means 'use defaults' in registry but instructions need actual names) - Remove unconditional HasTools/HasResources/HasPrompts=true in NewServer (let SDK determine capabilities based on registered items, matching main) --- internal/ghmcp/server.go | 9 +- pkg/github/server.go | 7 - pkg/github/tools.go | 12 ++ script/conformance-test | 313 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 333 insertions(+), 8 deletions(-) create mode 100755 script/conformance-test diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 32684ca11..973698743 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -160,9 +160,16 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { enabledToolsets := resolveEnabledToolsets(cfg) + // For instruction generation, we need actual toolset names (not nil). + // nil means "use defaults" in registry, so expand it for instructions. + instructionToolsets := enabledToolsets + if instructionToolsets == nil { + instructionToolsets = github.GetDefaultToolsetIDs() + } + // Create the MCP server ghServer := github.NewServer(cfg.Version, &mcp.ServerOptions{ - Instructions: github.GenerateInstructions(enabledToolsets), + Instructions: github.GenerateInstructions(instructionToolsets), Logger: cfg.Logger, CompletionHandler: github.CompletionsHandler(func(_ context.Context) (*gogithub.Client, error) { return clients.rest, nil diff --git a/pkg/github/server.go b/pkg/github/server.go index 7432466d1..a9c9305a2 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -21,13 +21,6 @@ func NewServer(version string, opts *mcp.ServerOptions) *mcp.Server { opts = &mcp.ServerOptions{} } - // Always advertise capabilities so clients know we support list_changed notifications. - // This is important for dynamic toolsets mode where we start with few tools - // and add more at runtime. - opts.HasTools = true - opts.HasResources = true - opts.HasPrompts = true - // Create a new MCP server s := mcp.NewServer(&mcp.Implementation{ Name: "github-mcp-server", diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 0d540f14e..bc06f793f 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -382,3 +382,15 @@ func CleanTools(toolNames []string) []string { return result } + +// GetDefaultToolsetIDs returns the IDs of toolsets marked as Default. +// This is a convenience function that builds a registry to determine defaults. +func GetDefaultToolsetIDs() []string { + r := NewRegistry(stubTranslator).Build() + ids := r.DefaultToolsetIDs() + result := make([]string, len(ids)) + for i, id := range ids { + result[i] = string(id) + } + return result +} diff --git a/script/conformance-test b/script/conformance-test new file mode 100755 index 000000000..fec4cc573 --- /dev/null +++ b/script/conformance-test @@ -0,0 +1,313 @@ +#!/bin/bash +set -e + +# Conformance test script for comparing MCP server behavior between branches +# Builds both main and current branch, runs various flag combinations, +# and produces a conformance report with timing and diffs. + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PROJECT_DIR="$(dirname "$SCRIPT_DIR")" +REPORT_DIR="$PROJECT_DIR/conformance-report" +CURRENT_BRANCH=$(git rev-parse --abbrev-ref HEAD) + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +echo -e "${BLUE}=== MCP Server Conformance Test ===${NC}" +echo "Current branch: $CURRENT_BRANCH" +echo "Report directory: $REPORT_DIR" + +# Find the common ancestor +MERGE_BASE=$(git merge-base HEAD origin/main) +echo "Comparing against merge-base: $MERGE_BASE" +echo "" + +# Create report directory +rm -rf "$REPORT_DIR" +mkdir -p "$REPORT_DIR"/{main,branch,diffs} + +# Build binaries +echo -e "${YELLOW}Building binaries...${NC}" + +echo "Building current branch ($CURRENT_BRANCH)..." +go build -o "$REPORT_DIR/branch/github-mcp-server" ./cmd/github-mcp-server +BRANCH_BUILD_OK=$? + +echo "Building main branch (using temp worktree at merge-base)..." +TEMP_WORKTREE=$(mktemp -d) +git worktree add --quiet "$TEMP_WORKTREE" "$MERGE_BASE" +(cd "$TEMP_WORKTREE" && go build -o "$REPORT_DIR/main/github-mcp-server" ./cmd/github-mcp-server) +MAIN_BUILD_OK=$? +git worktree remove --force "$TEMP_WORKTREE" + +if [ $BRANCH_BUILD_OK -ne 0 ] || [ $MAIN_BUILD_OK -ne 0 ]; then + echo -e "${RED}Build failed!${NC}" + exit 1 +fi + +echo -e "${GREEN}Both binaries built successfully${NC}" +echo "" + +# MCP JSON-RPC messages +INIT_MSG='{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"conformance-test","version":"1.0.0"}}}' +INITIALIZED_MSG='{"jsonrpc":"2.0","method":"notifications/initialized","params":{}}' +LIST_TOOLS_MSG='{"jsonrpc":"2.0","id":2,"method":"tools/list","params":{}}' +LIST_RESOURCES_MSG='{"jsonrpc":"2.0","id":3,"method":"resources/listTemplates","params":{}}' +LIST_PROMPTS_MSG='{"jsonrpc":"2.0","id":4,"method":"prompts/list","params":{}}' + +# Function to normalize JSON for comparison +# Sorts all arrays (including nested ones) and formats consistently +normalize_json() { + local file="$1" + if [ -s "$file" ]; then + # Deep sort: sort all arrays recursively, then sort keys + jq -S 'walk(if type == "array" then sort_by(tostring) else . end)' "$file" 2>/dev/null > "${file}.tmp" && mv "${file}.tmp" "$file" + fi +} + +# Function to run MCP server and capture output with timing +run_mcp_test() { + local binary="$1" + local name="$2" + local flags="$3" + local output_prefix="$4" + + local start_time end_time duration + start_time=$(date +%s.%N) + + # Run the server with all list commands - each response is on its own line + output=$( + ( + echo "$INIT_MSG" + echo "$INITIALIZED_MSG" + echo "$LIST_TOOLS_MSG" + echo "$LIST_RESOURCES_MSG" + echo "$LIST_PROMPTS_MSG" + sleep 0.5 + ) | GITHUB_PERSONAL_ACCESS_TOKEN=1 $binary stdio $flags 2>/dev/null + ) + + end_time=$(date +%s.%N) + duration=$(echo "$end_time - $start_time" | bc) + + # Parse and save each response by matching JSON-RPC id + # Each line is a separate JSON response + echo "$output" | while IFS= read -r line; do + id=$(echo "$line" | jq -r '.id // empty' 2>/dev/null) + case "$id" in + 1) echo "$line" | jq -S '.' > "${output_prefix}_initialize.json" 2>/dev/null ;; + 2) echo "$line" | jq -S '.' > "${output_prefix}_tools.json" 2>/dev/null ;; + 3) echo "$line" | jq -S '.' > "${output_prefix}_resources.json" 2>/dev/null ;; + 4) echo "$line" | jq -S '.' > "${output_prefix}_prompts.json" 2>/dev/null ;; + esac + done + + # Create empty files if not created (in case of errors or missing responses) + touch "${output_prefix}_initialize.json" "${output_prefix}_tools.json" \ + "${output_prefix}_resources.json" "${output_prefix}_prompts.json" + + # Normalize all JSON files for consistent comparison (sorts arrays, keys) + for endpoint in initialize tools resources prompts; do + normalize_json "${output_prefix}_${endpoint}.json" + done + + echo "$duration" +} + +# Test configurations - array of "name|flags" +declare -a TEST_CONFIGS=( + "default|" + "read-only|--read-only" + "dynamic-toolsets|--dynamic-toolsets" + "read-only+dynamic|--read-only --dynamic-toolsets" + "toolsets-repos|--toolsets=repos" + "toolsets-issues|--toolsets=issues" + "toolsets-pull_requests|--toolsets=pull_requests" + "toolsets-repos,issues|--toolsets=repos,issues" + "toolsets-all|--toolsets=all" + "tools-get_me|--tools=get_me" + "tools-get_me,list_issues|--tools=get_me,list_issues" + "toolsets-repos+read-only|--toolsets=repos --read-only" + "toolsets-all+dynamic|--toolsets=all --dynamic-toolsets" + "toolsets-repos+dynamic|--toolsets=repos --dynamic-toolsets" + "toolsets-repos,issues+dynamic|--toolsets=repos,issues --dynamic-toolsets" +) + +# Summary arrays +declare -a TEST_NAMES +declare -a MAIN_TIMES +declare -a BRANCH_TIMES +declare -a DIFF_STATUS + +echo -e "${YELLOW}Running conformance tests...${NC}" +echo "" + +for config in "${TEST_CONFIGS[@]}"; do + IFS='|' read -r test_name flags <<< "$config" + + echo -e "${BLUE}Test: ${test_name}${NC}" + echo " Flags: ${flags:-}" + + # Create output directories + mkdir -p "$REPORT_DIR/main/$test_name" + mkdir -p "$REPORT_DIR/branch/$test_name" + mkdir -p "$REPORT_DIR/diffs/$test_name" + + # Run main version + main_time=$(run_mcp_test "$REPORT_DIR/main/github-mcp-server" "main" "$flags" "$REPORT_DIR/main/$test_name/output") + echo " Main: ${main_time}s" + + # Run branch version + branch_time=$(run_mcp_test "$REPORT_DIR/branch/github-mcp-server" "branch" "$flags" "$REPORT_DIR/branch/$test_name/output") + echo " Branch: ${branch_time}s" + + # Calculate time difference + time_diff=$(echo "$branch_time - $main_time" | bc) + if (( $(echo "$time_diff > 0" | bc -l) )); then + echo -e " Δ Time: ${RED}+${time_diff}s (slower)${NC}" + else + echo -e " Δ Time: ${GREEN}${time_diff}s (faster)${NC}" + fi + + # Generate diffs for each endpoint + has_diff=false + for endpoint in initialize tools resources prompts; do + main_file="$REPORT_DIR/main/$test_name/output_${endpoint}.json" + branch_file="$REPORT_DIR/branch/$test_name/output_${endpoint}.json" + diff_file="$REPORT_DIR/diffs/$test_name/${endpoint}.diff" + + if ! diff -u "$main_file" "$branch_file" > "$diff_file" 2>/dev/null; then + has_diff=true + lines=$(wc -l < "$diff_file" | tr -d ' ') + echo -e " ${YELLOW}${endpoint}: DIFF (${lines} lines)${NC}" + else + rm -f "$diff_file" # No diff, remove empty file + echo -e " ${GREEN}${endpoint}: OK${NC}" + fi + done + + # Store results + TEST_NAMES+=("$test_name") + MAIN_TIMES+=("$main_time") + BRANCH_TIMES+=("$branch_time") + if [ "$has_diff" = true ]; then + DIFF_STATUS+=("DIFF") + else + DIFF_STATUS+=("OK") + fi + + echo "" +done + +# Generate summary report +REPORT_FILE="$REPORT_DIR/CONFORMANCE_REPORT.md" + +cat > "$REPORT_FILE" << EOF +# MCP Server Conformance Report + +Generated: $(date) +Current Branch: $CURRENT_BRANCH +Compared Against: merge-base ($MERGE_BASE) + +## Summary + +| Test | Main Time | Branch Time | Δ Time | Status | +|------|-----------|-------------|--------|--------| +EOF + +total_main=0 +total_branch=0 +diff_count=0 +ok_count=0 + +for i in "${!TEST_NAMES[@]}"; do + name="${TEST_NAMES[$i]}" + main_t="${MAIN_TIMES[$i]}" + branch_t="${BRANCH_TIMES[$i]}" + status="${DIFF_STATUS[$i]}" + + delta=$(echo "$branch_t - $main_t" | bc) + if (( $(echo "$delta > 0" | bc -l) )); then + delta_str="+${delta}s" + else + delta_str="${delta}s" + fi + + if [ "$status" = "DIFF" ]; then + status_str="⚠️ DIFF" + ((diff_count++)) || true + else + status_str="✅ OK" + ((ok_count++)) || true + fi + + total_main=$(echo "$total_main + $main_t" | bc) + total_branch=$(echo "$total_branch + $branch_t" | bc) + + echo "| $name | ${main_t}s | ${branch_t}s | $delta_str | $status_str |" >> "$REPORT_FILE" +done + +total_delta=$(echo "$total_branch - $total_main" | bc) +if (( $(echo "$total_delta > 0" | bc -l) )); then + total_delta_str="+${total_delta}s" +else + total_delta_str="${total_delta}s" +fi + +cat >> "$REPORT_FILE" << EOF +| **TOTAL** | **${total_main}s** | **${total_branch}s** | **$total_delta_str** | **$ok_count OK / $diff_count DIFF** | + +## Statistics + +- **Tests Passed (no diff):** $ok_count +- **Tests with Differences:** $diff_count +- **Total Main Time:** ${total_main}s +- **Total Branch Time:** ${total_branch}s +- **Overall Time Delta:** $total_delta_str + +## Detailed Diffs + +EOF + +# Add diff details to report +for i in "${!TEST_NAMES[@]}"; do + name="${TEST_NAMES[$i]}" + status="${DIFF_STATUS[$i]}" + + if [ "$status" = "DIFF" ]; then + echo "### $name" >> "$REPORT_FILE" + echo "" >> "$REPORT_FILE" + + for endpoint in initialize tools resources prompts; do + diff_file="$REPORT_DIR/diffs/$name/${endpoint}.diff" + if [ -f "$diff_file" ] && [ -s "$diff_file" ]; then + echo "#### ${endpoint}" >> "$REPORT_FILE" + echo '```diff' >> "$REPORT_FILE" + cat "$diff_file" >> "$REPORT_FILE" + echo '```' >> "$REPORT_FILE" + echo "" >> "$REPORT_FILE" + fi + done + fi +done + +echo -e "${BLUE}=== Conformance Test Complete ===${NC}" +echo "" +echo -e "Report: ${GREEN}$REPORT_FILE${NC}" +echo "" +echo "Summary:" +echo " Tests passed: $ok_count" +echo " Tests with diffs: $diff_count" +echo " Total main time: ${total_main}s" +echo " Total branch time: ${total_branch}s" +echo " Time delta: $total_delta_str" + +if [ $diff_count -gt 0 ]; then + echo "" + echo -e "${YELLOW}⚠️ Some tests have differences. Review the diffs in:${NC}" + echo " $REPORT_DIR/diffs/" +fi From 1bbac360d7f835e1ba7d25068cb5ccb0f6e87223 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Mon, 15 Dec 2025 22:14:30 +0100 Subject: [PATCH 17/27] Add tests for dynamic toolset management tools Tests cover: - list_available_toolsets: verifies toolsets are listed with enabled status - get_toolset_tools: verifies tools can be retrieved for a toolset - enable_toolset: verifies toolset can be enabled and marked as enabled - enable_toolset invalid: verifies proper error for non-existent toolset - toolsets enum: verifies tools have proper enum values in schema --- pkg/github/dynamic_tools_test.go | 231 +++++++++++++++++++++++++++++++ 1 file changed, 231 insertions(+) create mode 100644 pkg/github/dynamic_tools_test.go diff --git a/pkg/github/dynamic_tools_test.go b/pkg/github/dynamic_tools_test.go new file mode 100644 index 000000000..4558204dc --- /dev/null +++ b/pkg/github/dynamic_tools_test.go @@ -0,0 +1,231 @@ +package github + +import ( + "context" + "encoding/json" + "testing" + + "github.com/github/github-mcp-server/pkg/registry" + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// createDynamicRequest creates an MCP request with the given arguments for dynamic tools. +func createDynamicRequest(args map[string]any) *mcp.CallToolRequest { + argsJSON, _ := json.Marshal(args) + return &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Arguments: json.RawMessage(argsJSON), + }, + } +} + +func TestDynamicTools_ListAvailableToolsets(t *testing.T) { + // Build a registry with no toolsets enabled (dynamic mode) + reg := NewRegistry(translations.NullTranslationHelper). + WithToolsets([]string{}). + Build() + + // Create a mock server + server := mcp.NewServer(&mcp.Implementation{Name: "test"}, nil) + + // Create dynamic tool dependencies + deps := DynamicToolDependencies{ + Server: server, + Registry: reg, + ToolDeps: nil, + T: translations.NullTranslationHelper, + } + + // Get the list_available_toolsets tool + tool := ListAvailableToolsets() + handler := tool.Handler(deps) + + // Call the handler + result, err := handler(context.Background(), createDynamicRequest(map[string]any{})) + require.NoError(t, err) + require.NotNil(t, result) + require.Len(t, result.Content, 1) + + // Parse the result + var toolsets []map[string]string + textContent := result.Content[0].(*mcp.TextContent) + err = json.Unmarshal([]byte(textContent.Text), &toolsets) + require.NoError(t, err) + + // Verify we got toolsets + assert.NotEmpty(t, toolsets, "should have available toolsets") + + // Find the repos toolset and verify it's not enabled + var reposToolset map[string]string + for _, ts := range toolsets { + if ts["name"] == "repos" { + reposToolset = ts + break + } + } + require.NotNil(t, reposToolset, "repos toolset should exist") + assert.Equal(t, "false", reposToolset["currently_enabled"], "repos should not be enabled initially") +} + +func TestDynamicTools_GetToolsetTools(t *testing.T) { + // Build a registry with no toolsets enabled (dynamic mode) + reg := NewRegistry(translations.NullTranslationHelper). + WithToolsets([]string{}). + Build() + + // Create a mock server + server := mcp.NewServer(&mcp.Implementation{Name: "test"}, nil) + + // Create dynamic tool dependencies + deps := DynamicToolDependencies{ + Server: server, + Registry: reg, + ToolDeps: nil, + T: translations.NullTranslationHelper, + } + + // Get the get_toolset_tools tool + tool := GetToolsetsTools(reg) + handler := tool.Handler(deps) + + // Call the handler for repos toolset + result, err := handler(context.Background(), createDynamicRequest(map[string]any{ + "toolset": "repos", + })) + require.NoError(t, err) + require.NotNil(t, result) + require.Len(t, result.Content, 1) + + // Parse the result + var tools []map[string]string + textContent := result.Content[0].(*mcp.TextContent) + err = json.Unmarshal([]byte(textContent.Text), &tools) + require.NoError(t, err) + + // Verify we got tools for the repos toolset + assert.NotEmpty(t, tools, "repos toolset should have tools") + + // Verify at least get_commit is there (a repos toolset tool) + var foundGetCommit bool + for _, tool := range tools { + if tool["name"] == "get_commit" { + foundGetCommit = true + break + } + } + assert.True(t, foundGetCommit, "get_commit should be in repos toolset") +} + +func TestDynamicTools_EnableToolset(t *testing.T) { + // Build a registry with no toolsets enabled (dynamic mode) + reg := NewRegistry(translations.NullTranslationHelper). + WithToolsets([]string{}). + Build() + + // Create a mock server + server := mcp.NewServer(&mcp.Implementation{Name: "test"}, nil) + + // Create dynamic tool dependencies + deps := DynamicToolDependencies{ + Server: server, + Registry: reg, + ToolDeps: NewBaseDeps(nil, nil, nil, nil, translations.NullTranslationHelper, FeatureFlags{}, 0), + T: translations.NullTranslationHelper, + } + + // Verify repos is not enabled initially + assert.False(t, reg.IsToolsetEnabled(registry.ToolsetID("repos"))) + + // Get the enable_toolset tool + tool := EnableToolset(reg) + handler := tool.Handler(deps) + + // Enable the repos toolset + result, err := handler(context.Background(), createDynamicRequest(map[string]any{ + "toolset": "repos", + })) + require.NoError(t, err) + require.NotNil(t, result) + require.Len(t, result.Content, 1) + + // Verify the toolset is now enabled + assert.True(t, reg.IsToolsetEnabled(registry.ToolsetID("repos")), "repos should be enabled after enable_toolset") + + // Verify the success message + textContent := result.Content[0].(*mcp.TextContent) + assert.Contains(t, textContent.Text, "enabled") + + // Try enabling again - should say already enabled + result2, err := handler(context.Background(), createDynamicRequest(map[string]any{ + "toolset": "repos", + })) + require.NoError(t, err) + textContent2 := result2.Content[0].(*mcp.TextContent) + assert.Contains(t, textContent2.Text, "already enabled") +} + +func TestDynamicTools_EnableToolset_InvalidToolset(t *testing.T) { + // Build a registry with no toolsets enabled (dynamic mode) + reg := NewRegistry(translations.NullTranslationHelper). + WithToolsets([]string{}). + Build() + + // Create a mock server + server := mcp.NewServer(&mcp.Implementation{Name: "test"}, nil) + + // Create dynamic tool dependencies + deps := DynamicToolDependencies{ + Server: server, + Registry: reg, + ToolDeps: nil, + T: translations.NullTranslationHelper, + } + + // Get the enable_toolset tool + tool := EnableToolset(reg) + handler := tool.Handler(deps) + + // Try to enable a non-existent toolset + result, err := handler(context.Background(), createDynamicRequest(map[string]any{ + "toolset": "nonexistent", + })) + require.NoError(t, err) + require.NotNil(t, result) + + // Should be an error result + textContent := result.Content[0].(*mcp.TextContent) + assert.Contains(t, textContent.Text, "not found") +} + +func TestDynamicTools_ToolsetsEnum(t *testing.T) { + // Build a registry + reg := NewRegistry(translations.NullTranslationHelper).Build() + + // Get tools to verify they have proper enum values + tools := DynamicTools(reg) + + // Find enable_toolset and get_toolset_tools + for _, tool := range tools { + if tool.Tool.Name == "enable_toolset" || tool.Tool.Name == "get_toolset_tools" { + // Verify the toolset property has an enum + schema := tool.Tool.InputSchema.(*jsonschema.Schema) + toolsetProp := schema.Properties["toolset"] + require.NotNil(t, toolsetProp, "toolset property should exist") + assert.NotEmpty(t, toolsetProp.Enum, "toolset property should have enum values") + + // Verify repos is in the enum + var foundRepos bool + for _, v := range toolsetProp.Enum { + if v == registry.ToolsetID("repos") { + foundRepos = true + break + } + } + assert.True(t, foundRepos, "repos should be in toolset enum for %s", tool.Tool.Name) + } + } +} From 7eab16997297d84201277c6c31b2f58dcb53611e Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Mon, 15 Dec 2025 22:19:10 +0100 Subject: [PATCH 18/27] Advertise all capabilities in dynamic toolsets mode In dynamic mode, explicitly set HasTools/HasResources/HasPrompts=true since toolsets with those capabilities can be enabled at runtime. This ensures clients know the server supports these features even when no tools/resources/prompts are initially registered. --- internal/ghmcp/server.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 973698743..e98637067 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -168,13 +168,23 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { } // Create the MCP server - ghServer := github.NewServer(cfg.Version, &mcp.ServerOptions{ + serverOpts := &mcp.ServerOptions{ Instructions: github.GenerateInstructions(instructionToolsets), Logger: cfg.Logger, CompletionHandler: github.CompletionsHandler(func(_ context.Context) (*gogithub.Client, error) { return clients.rest, nil }), - }) + } + + // In dynamic mode, explicitly advertise capabilities since tools/resources/prompts + // may be enabled at runtime even if none are registered initially. + if cfg.DynamicToolsets { + serverOpts.HasTools = true + serverOpts.HasResources = true + serverOpts.HasPrompts = true + } + + ghServer := github.NewServer(cfg.Version, serverOpts) // Add middlewares ghServer.AddReceivingMiddleware(addGitHubAPIErrorToContext) From 3914a92ea43baf2d42db8301116e878e35100927 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Mon, 15 Dec 2025 22:26:39 +0100 Subject: [PATCH 19/27] Improve conformance test with dynamic tool calls and JSON normalization - Add dynamic tool call testing (list_available_toolsets, get_toolset_tools, enable_toolset) - Parse and sort embedded JSON in text fields for proper comparison - Separate progress output (stderr) from summary (stdout) for CI - Add test type field to distinguish standard vs dynamic tests --- script/conformance-test | 241 ++++++++++++++++++++++++++++++---------- 1 file changed, 180 insertions(+), 61 deletions(-) diff --git a/script/conformance-test b/script/conformance-test index fec4cc573..3ff0a55c2 100755 --- a/script/conformance-test +++ b/script/conformance-test @@ -4,40 +4,49 @@ set -e # Conformance test script for comparing MCP server behavior between branches # Builds both main and current branch, runs various flag combinations, # and produces a conformance report with timing and diffs. +# +# Output: +# - Progress/status messages go to stderr (for visibility in CI) +# - Final report summary goes to stdout (for piping/capture) SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" PROJECT_DIR="$(dirname "$SCRIPT_DIR")" REPORT_DIR="$PROJECT_DIR/conformance-report" CURRENT_BRANCH=$(git rev-parse --abbrev-ref HEAD) -# Colors for output +# Colors for output (only used on stderr) RED='\033[0;31m' GREEN='\033[0;32m' YELLOW='\033[1;33m' BLUE='\033[0;34m' NC='\033[0m' # No Color -echo -e "${BLUE}=== MCP Server Conformance Test ===${NC}" -echo "Current branch: $CURRENT_BRANCH" -echo "Report directory: $REPORT_DIR" +# Helper to print to stderr +log() { + echo -e "$@" >&2 +} + +log "${BLUE}=== MCP Server Conformance Test ===${NC}" +log "Current branch: $CURRENT_BRANCH" +log "Report directory: $REPORT_DIR" # Find the common ancestor MERGE_BASE=$(git merge-base HEAD origin/main) -echo "Comparing against merge-base: $MERGE_BASE" -echo "" +log "Comparing against merge-base: $MERGE_BASE" +log "" # Create report directory rm -rf "$REPORT_DIR" mkdir -p "$REPORT_DIR"/{main,branch,diffs} # Build binaries -echo -e "${YELLOW}Building binaries...${NC}" +log "${YELLOW}Building binaries...${NC}" -echo "Building current branch ($CURRENT_BRANCH)..." +log "Building current branch ($CURRENT_BRANCH)..." go build -o "$REPORT_DIR/branch/github-mcp-server" ./cmd/github-mcp-server BRANCH_BUILD_OK=$? -echo "Building main branch (using temp worktree at merge-base)..." +log "Building main branch (using temp worktree at merge-base)..." TEMP_WORKTREE=$(mktemp -d) git worktree add --quiet "$TEMP_WORKTREE" "$MERGE_BASE" (cd "$TEMP_WORKTREE" && go build -o "$REPORT_DIR/main/github-mcp-server" ./cmd/github-mcp-server) @@ -45,12 +54,12 @@ MAIN_BUILD_OK=$? git worktree remove --force "$TEMP_WORKTREE" if [ $BRANCH_BUILD_OK -ne 0 ] || [ $MAIN_BUILD_OK -ne 0 ]; then - echo -e "${RED}Build failed!${NC}" + log "${RED}Build failed!${NC}" exit 1 fi -echo -e "${GREEN}Both binaries built successfully${NC}" -echo "" +log "${GREEN}Both binaries built successfully${NC}" +log "" # MCP JSON-RPC messages INIT_MSG='{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"conformance-test","version":"1.0.0"}}}' @@ -59,13 +68,40 @@ LIST_TOOLS_MSG='{"jsonrpc":"2.0","id":2,"method":"tools/list","params":{}}' LIST_RESOURCES_MSG='{"jsonrpc":"2.0","id":3,"method":"resources/listTemplates","params":{}}' LIST_PROMPTS_MSG='{"jsonrpc":"2.0","id":4,"method":"prompts/list","params":{}}' +# Dynamic toolset management tool calls (for dynamic mode testing) +LIST_TOOLSETS_MSG='{"jsonrpc":"2.0","id":10,"method":"tools/call","params":{"name":"list_available_toolsets","arguments":{}}}' +GET_TOOLSET_TOOLS_MSG='{"jsonrpc":"2.0","id":11,"method":"tools/call","params":{"name":"get_toolset_tools","arguments":{"toolset":"repos"}}}' +ENABLE_TOOLSET_MSG='{"jsonrpc":"2.0","id":12,"method":"tools/call","params":{"name":"enable_toolset","arguments":{"toolset":"repos"}}}' +LIST_TOOLSETS_AFTER_MSG='{"jsonrpc":"2.0","id":13,"method":"tools/call","params":{"name":"list_available_toolsets","arguments":{}}}' + # Function to normalize JSON for comparison # Sorts all arrays (including nested ones) and formats consistently +# Also handles embedded JSON strings in "text" fields (from tool call responses) normalize_json() { local file="$1" if [ -s "$file" ]; then - # Deep sort: sort all arrays recursively, then sort keys - jq -S 'walk(if type == "array" then sort_by(tostring) else . end)' "$file" 2>/dev/null > "${file}.tmp" && mv "${file}.tmp" "$file" + # First, try to parse and re-serialize any JSON embedded in text fields + # This handles tool call responses where the result is JSON-in-a-string + jq -S ' + # Function to sort arrays recursively + def deep_sort: + if type == "array" then + [.[] | deep_sort] | sort_by(tostring) + elif type == "object" then + to_entries | map(.value |= deep_sort) | from_entries + else + . + end; + + # Walk the structure, and for any "text" field that looks like JSON array/object, parse and sort it + walk( + if type == "object" and .text and (.text | type == "string") and ((.text | startswith("[")) or (.text | startswith("{"))) then + .text = ((.text | fromjson | deep_sort) | tojson) + else + . + end + ) | deep_sort + ' "$file" 2>/dev/null > "${file}.tmp" && mv "${file}.tmp" "$file" fi } @@ -118,23 +154,84 @@ run_mcp_test() { echo "$duration" } -# Test configurations - array of "name|flags" +# Function to run MCP server with dynamic tool calls (for dynamic mode testing) +run_mcp_dynamic_test() { + local binary="$1" + local name="$2" + local flags="$3" + local output_prefix="$4" + + local start_time end_time duration + start_time=$(date +%s.%N) + + # Run the server with dynamic tool calls in sequence: + # 1. Initialize + # 2. List available toolsets (before enable) + # 3. Get tools for repos toolset + # 4. Enable repos toolset + # 5. List available toolsets (after enable - should show repos as enabled) + output=$( + ( + echo "$INIT_MSG" + echo "$INITIALIZED_MSG" + echo "$LIST_TOOLSETS_MSG" + sleep 0.1 + echo "$GET_TOOLSET_TOOLS_MSG" + sleep 0.1 + echo "$ENABLE_TOOLSET_MSG" + sleep 0.1 + echo "$LIST_TOOLSETS_AFTER_MSG" + sleep 0.3 + ) | GITHUB_PERSONAL_ACCESS_TOKEN=1 $binary stdio $flags 2>/dev/null + ) + + end_time=$(date +%s.%N) + duration=$(echo "$end_time - $start_time" | bc) + + # Parse and save each response by matching JSON-RPC id + echo "$output" | while IFS= read -r line; do + id=$(echo "$line" | jq -r '.id // empty' 2>/dev/null) + case "$id" in + 1) echo "$line" | jq -S '.' > "${output_prefix}_initialize.json" 2>/dev/null ;; + 10) echo "$line" | jq -S '.' > "${output_prefix}_list_toolsets_before.json" 2>/dev/null ;; + 11) echo "$line" | jq -S '.' > "${output_prefix}_get_toolset_tools.json" 2>/dev/null ;; + 12) echo "$line" | jq -S '.' > "${output_prefix}_enable_toolset.json" 2>/dev/null ;; + 13) echo "$line" | jq -S '.' > "${output_prefix}_list_toolsets_after.json" 2>/dev/null ;; + esac + done + + # Create empty files if not created + touch "${output_prefix}_initialize.json" "${output_prefix}_list_toolsets_before.json" \ + "${output_prefix}_get_toolset_tools.json" "${output_prefix}_enable_toolset.json" \ + "${output_prefix}_list_toolsets_after.json" + + # Normalize all JSON files + for endpoint in initialize list_toolsets_before get_toolset_tools enable_toolset list_toolsets_after; do + normalize_json "${output_prefix}_${endpoint}.json" + done + + echo "$duration" +} + +# Test configurations - array of "name|flags|type" +# type can be "standard" or "dynamic" (for dynamic tool call testing) declare -a TEST_CONFIGS=( - "default|" - "read-only|--read-only" - "dynamic-toolsets|--dynamic-toolsets" - "read-only+dynamic|--read-only --dynamic-toolsets" - "toolsets-repos|--toolsets=repos" - "toolsets-issues|--toolsets=issues" - "toolsets-pull_requests|--toolsets=pull_requests" - "toolsets-repos,issues|--toolsets=repos,issues" - "toolsets-all|--toolsets=all" - "tools-get_me|--tools=get_me" - "tools-get_me,list_issues|--tools=get_me,list_issues" - "toolsets-repos+read-only|--toolsets=repos --read-only" - "toolsets-all+dynamic|--toolsets=all --dynamic-toolsets" - "toolsets-repos+dynamic|--toolsets=repos --dynamic-toolsets" - "toolsets-repos,issues+dynamic|--toolsets=repos,issues --dynamic-toolsets" + "default||standard" + "read-only|--read-only|standard" + "dynamic-toolsets|--dynamic-toolsets|standard" + "read-only+dynamic|--read-only --dynamic-toolsets|standard" + "toolsets-repos|--toolsets=repos|standard" + "toolsets-issues|--toolsets=issues|standard" + "toolsets-pull_requests|--toolsets=pull_requests|standard" + "toolsets-repos,issues|--toolsets=repos,issues|standard" + "toolsets-all|--toolsets=all|standard" + "tools-get_me|--tools=get_me|standard" + "tools-get_me,list_issues|--tools=get_me,list_issues|standard" + "toolsets-repos+read-only|--toolsets=repos --read-only|standard" + "toolsets-all+dynamic|--toolsets=all --dynamic-toolsets|standard" + "toolsets-repos+dynamic|--toolsets=repos --dynamic-toolsets|standard" + "toolsets-repos,issues+dynamic|--toolsets=repos,issues --dynamic-toolsets|standard" + "dynamic-tool-calls|--dynamic-toolsets|dynamic" ) # Summary arrays @@ -143,39 +240,52 @@ declare -a MAIN_TIMES declare -a BRANCH_TIMES declare -a DIFF_STATUS -echo -e "${YELLOW}Running conformance tests...${NC}" -echo "" +log "${YELLOW}Running conformance tests...${NC}" +log "" for config in "${TEST_CONFIGS[@]}"; do - IFS='|' read -r test_name flags <<< "$config" + IFS='|' read -r test_name flags test_type <<< "$config" - echo -e "${BLUE}Test: ${test_name}${NC}" - echo " Flags: ${flags:-}" + log "${BLUE}Test: ${test_name}${NC}" + log " Flags: ${flags:-}" + log " Type: ${test_type}" # Create output directories mkdir -p "$REPORT_DIR/main/$test_name" mkdir -p "$REPORT_DIR/branch/$test_name" mkdir -p "$REPORT_DIR/diffs/$test_name" - # Run main version - main_time=$(run_mcp_test "$REPORT_DIR/main/github-mcp-server" "main" "$flags" "$REPORT_DIR/main/$test_name/output") - echo " Main: ${main_time}s" - - # Run branch version - branch_time=$(run_mcp_test "$REPORT_DIR/branch/github-mcp-server" "branch" "$flags" "$REPORT_DIR/branch/$test_name/output") - echo " Branch: ${branch_time}s" + if [ "$test_type" = "dynamic" ]; then + # Run dynamic tool call test + main_time=$(run_mcp_dynamic_test "$REPORT_DIR/main/github-mcp-server" "main" "$flags" "$REPORT_DIR/main/$test_name/output") + log " Main: ${main_time}s" + + branch_time=$(run_mcp_dynamic_test "$REPORT_DIR/branch/github-mcp-server" "branch" "$flags" "$REPORT_DIR/branch/$test_name/output") + log " Branch: ${branch_time}s" + + endpoints="initialize list_toolsets_before get_toolset_tools enable_toolset list_toolsets_after" + else + # Run standard test + main_time=$(run_mcp_test "$REPORT_DIR/main/github-mcp-server" "main" "$flags" "$REPORT_DIR/main/$test_name/output") + log " Main: ${main_time}s" + + branch_time=$(run_mcp_test "$REPORT_DIR/branch/github-mcp-server" "branch" "$flags" "$REPORT_DIR/branch/$test_name/output") + log " Branch: ${branch_time}s" + + endpoints="initialize tools resources prompts" + fi # Calculate time difference time_diff=$(echo "$branch_time - $main_time" | bc) if (( $(echo "$time_diff > 0" | bc -l) )); then - echo -e " Δ Time: ${RED}+${time_diff}s (slower)${NC}" + log " Δ Time: ${RED}+${time_diff}s (slower)${NC}" else - echo -e " Δ Time: ${GREEN}${time_diff}s (faster)${NC}" + log " Δ Time: ${GREEN}${time_diff}s (faster)${NC}" fi # Generate diffs for each endpoint has_diff=false - for endpoint in initialize tools resources prompts; do + for endpoint in $endpoints; do main_file="$REPORT_DIR/main/$test_name/output_${endpoint}.json" branch_file="$REPORT_DIR/branch/$test_name/output_${endpoint}.json" diff_file="$REPORT_DIR/diffs/$test_name/${endpoint}.diff" @@ -183,10 +293,10 @@ for config in "${TEST_CONFIGS[@]}"; do if ! diff -u "$main_file" "$branch_file" > "$diff_file" 2>/dev/null; then has_diff=true lines=$(wc -l < "$diff_file" | tr -d ' ') - echo -e " ${YELLOW}${endpoint}: DIFF (${lines} lines)${NC}" + log " ${YELLOW}${endpoint}: DIFF (${lines} lines)${NC}" else rm -f "$diff_file" # No diff, remove empty file - echo -e " ${GREEN}${endpoint}: OK${NC}" + log " ${GREEN}${endpoint}: OK${NC}" fi done @@ -200,7 +310,7 @@ for config in "${TEST_CONFIGS[@]}"; do DIFF_STATUS+=("OK") fi - echo "" + log "" done # Generate summary report @@ -282,7 +392,8 @@ for i in "${!TEST_NAMES[@]}"; do echo "### $name" >> "$REPORT_FILE" echo "" >> "$REPORT_FILE" - for endpoint in initialize tools resources prompts; do + # Check all possible endpoints + for endpoint in initialize tools resources prompts list_toolsets_before get_toolset_tools enable_toolset list_toolsets_after; do diff_file="$REPORT_DIR/diffs/$name/${endpoint}.diff" if [ -f "$diff_file" ] && [ -s "$diff_file" ]; then echo "#### ${endpoint}" >> "$REPORT_FILE" @@ -295,19 +406,27 @@ for i in "${!TEST_NAMES[@]}"; do fi done -echo -e "${BLUE}=== Conformance Test Complete ===${NC}" -echo "" -echo -e "Report: ${GREEN}$REPORT_FILE${NC}" -echo "" -echo "Summary:" -echo " Tests passed: $ok_count" -echo " Tests with diffs: $diff_count" -echo " Total main time: ${total_main}s" -echo " Total branch time: ${total_branch}s" -echo " Time delta: $total_delta_str" +log "${BLUE}=== Conformance Test Complete ===${NC}" +log "" +log "Report: ${GREEN}$REPORT_FILE${NC}" +log "" + +# Output summary to stdout (for CI capture) +echo "=== Conformance Test Summary ===" +echo "Tests passed: $ok_count" +echo "Tests with diffs: $diff_count" +echo "Total main time: ${total_main}s" +echo "Total branch time: ${total_branch}s" +echo "Time delta: $total_delta_str" if [ $diff_count -gt 0 ]; then + log "" + log "${YELLOW}⚠️ Some tests have differences. Review the diffs in:${NC}" + log " $REPORT_DIR/diffs/" + echo "" + echo "RESULT: DIFFERENCES FOUND" + # Don't exit with error - diffs may be intentional improvements +else echo "" - echo -e "${YELLOW}⚠️ Some tests have differences. Review the diffs in:${NC}" - echo " $REPORT_DIR/diffs/" + echo "RESULT: ALL TESTS PASSED" fi From eeb29847f09a1dc8f7681c55958e2042c5bda6a0 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Mon, 15 Dec 2025 22:28:24 +0100 Subject: [PATCH 20/27] Add conformance-report to .gitignore --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index b018fafac..88525b387 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,4 @@ bin/ # binary github-mcp-server -.history \ No newline at end of file +.historyconformance-report/ From 4665d2301f8c4bdb56af23be5170d40e48f9479f Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Mon, 15 Dec 2025 22:30:14 +0100 Subject: [PATCH 21/27] Add conformance test CI workflow - Runs on pull requests to main - Compares PR branch against merge-base with origin/main - Outputs full conformance report to GitHub Actions Job Summary - Uploads detailed report as artifact for deeper investigation - Does not fail the build on differences (may be intentional) --- .github/workflows/conformance.yml | 69 +++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 .github/workflows/conformance.yml diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml new file mode 100644 index 000000000..92524ea17 --- /dev/null +++ b/.github/workflows/conformance.yml @@ -0,0 +1,69 @@ +name: Conformance Test + +on: + pull_request: + +permissions: + contents: read + +jobs: + conformance: + runs-on: ubuntu-latest + + steps: + - name: Check out code + uses: actions/checkout@v6 + with: + # Fetch full history to access merge-base + fetch-depth: 0 + + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version-file: "go.mod" + + - name: Download dependencies + run: go mod download + + - name: Run conformance test + id: conformance + run: | + # Run conformance test, capture stdout for summary + script/conformance-test > conformance-summary.txt 2>&1 || true + + # Output the summary + cat conformance-summary.txt + + # Check result + if grep -q "RESULT: ALL TESTS PASSED" conformance-summary.txt; then + echo "status=passed" >> $GITHUB_OUTPUT + else + echo "status=differences" >> $GITHUB_OUTPUT + fi + + - name: Generate Job Summary + run: | + # Add the full markdown report to the job summary + echo "# MCP Server Conformance Report" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "Comparing PR branch against merge-base with \`origin/main\`" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + + # Extract and append the report content (skip the header since we added our own) + tail -n +5 conformance-report/CONFORMANCE_REPORT.md >> $GITHUB_STEP_SUMMARY + + echo "" >> $GITHUB_STEP_SUMMARY + echo "---" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + + # Add interpretation note + if [ "${{ steps.conformance.outputs.status }}" = "passed" ]; then + echo "✅ **All conformance tests passed** - No behavioral differences detected." >> $GITHUB_STEP_SUMMARY + else + echo "⚠️ **Differences detected** - Review the diffs above to ensure changes are intentional." >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "Common expected differences:" >> $GITHUB_STEP_SUMMARY + echo "- New tools/toolsets added" >> $GITHUB_STEP_SUMMARY + echo "- Tool descriptions updated" >> $GITHUB_STEP_SUMMARY + echo "- Capability changes (intentional improvements)" >> $GITHUB_STEP_SUMMARY + fi From c913d584aa51bd8b5eacb333fd2c4c077e4e6a7e Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Mon, 15 Dec 2025 22:44:34 +0100 Subject: [PATCH 22/27] Add map indexes for O(1) lookups in Registry Address review feedback to use maps for collections. Added lookup maps (toolsByName, resourcesByURI, promptsByName) while keeping slices for ordered iteration. This provides O(1) lookup for: - FindToolByName - filterToolsByName (used by ForMCPRequest) - filterResourcesByURI - filterPromptsByName Maps are built once during Build() and shared in ForMCPRequest copies. --- .gitignore | 3 +- pkg/registry/builder.go | 91 ++++++++++++++++++++--------------- pkg/registry/filters.go | 5 +- pkg/registry/registry.go | 100 +++++++++++++-------------------------- 4 files changed, 91 insertions(+), 108 deletions(-) diff --git a/.gitignore b/.gitignore index 88525b387..5684108b0 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,5 @@ bin/ # binary github-mcp-server -.historyconformance-report/ +.history +conformance-report/ diff --git a/pkg/registry/builder.go b/pkg/registry/builder.go index a8720a4f4..384128044 100644 --- a/pkg/registry/builder.go +++ b/pkg/registry/builder.go @@ -1,6 +1,7 @@ package registry import ( + "sort" "strings" ) @@ -124,8 +125,10 @@ func (b *Builder) Build() *Registry { featureChecker: b.featureChecker, } - // Process toolsets - r.enabledToolsets, r.unrecognizedToolsets = b.processToolsets() + // Note: toolsByName map is lazy-initialized on first use via getToolsByName() + + // Process toolsets and pre-compute metadata in a single pass + r.enabledToolsets, r.unrecognizedToolsets, r.toolsetIDs, r.defaultToolsetIDs, r.toolsetDescriptions = b.processToolsets() // Process additional tools (resolve aliases) if len(b.additionalTools) > 0 { @@ -146,25 +149,65 @@ func (b *Builder) Build() *Registry { // processToolsets processes the toolsetIDs configuration and returns: // - enabledToolsets map (nil means all enabled) // - unrecognizedToolsets list for warnings -func (b *Builder) processToolsets() (map[ToolsetID]bool, []string) { - // Build a set of valid toolset IDs for validation +// - allToolsetIDs sorted list of all toolset IDs +// - defaultToolsetIDs sorted list of default toolset IDs +// - toolsetDescriptions map of toolset ID to description +func (b *Builder) processToolsets() (map[ToolsetID]bool, []string, []ToolsetID, []ToolsetID, map[ToolsetID]string) { + // Single pass: collect all toolset metadata together validIDs := make(map[ToolsetID]bool) - for _, t := range b.tools { + defaultIDs := make(map[ToolsetID]bool) + descriptions := make(map[ToolsetID]string) + + for i := range b.tools { + t := &b.tools[i] validIDs[t.Toolset.ID] = true + if t.Toolset.Default { + defaultIDs[t.Toolset.ID] = true + } + if t.Toolset.Description != "" { + descriptions[t.Toolset.ID] = t.Toolset.Description + } } - for _, r := range b.resourceTemplates { + for i := range b.resourceTemplates { + r := &b.resourceTemplates[i] validIDs[r.Toolset.ID] = true + if r.Toolset.Default { + defaultIDs[r.Toolset.ID] = true + } + if r.Toolset.Description != "" { + descriptions[r.Toolset.ID] = r.Toolset.Description + } } - for _, p := range b.prompts { + for i := range b.prompts { + p := &b.prompts[i] validIDs[p.Toolset.ID] = true + if p.Toolset.Default { + defaultIDs[p.Toolset.ID] = true + } + if p.Toolset.Description != "" { + descriptions[p.Toolset.ID] = p.Toolset.Description + } + } + + // Build sorted slices from the collected maps + allToolsetIDs := make([]ToolsetID, 0, len(validIDs)) + for id := range validIDs { + allToolsetIDs = append(allToolsetIDs, id) + } + sort.Slice(allToolsetIDs, func(i, j int) bool { return allToolsetIDs[i] < allToolsetIDs[j] }) + + defaultToolsetIDList := make([]ToolsetID, 0, len(defaultIDs)) + for id := range defaultIDs { + defaultToolsetIDList = append(defaultToolsetIDList, id) } + sort.Slice(defaultToolsetIDList, func(i, j int) bool { return defaultToolsetIDList[i] < defaultToolsetIDList[j] }) toolsetIDs := b.toolsetIDs // Check for "all" keyword - enables all toolsets for _, id := range toolsetIDs { if strings.TrimSpace(id) == "all" { - return nil, nil // nil means all enabled + return nil, nil, allToolsetIDs, defaultToolsetIDList, descriptions // nil means all enabled } } @@ -184,7 +227,7 @@ func (b *Builder) processToolsets() (map[ToolsetID]bool, []string) { continue } if trimmed == "default" { - for _, defaultID := range b.defaultToolsetIDs() { + for _, defaultID := range defaultToolsetIDList { if !seen[defaultID] { seen[defaultID] = true expanded = append(expanded, defaultID) @@ -204,38 +247,12 @@ func (b *Builder) processToolsets() (map[ToolsetID]bool, []string) { } if len(expanded) == 0 { - return make(map[ToolsetID]bool), unrecognized + return make(map[ToolsetID]bool), unrecognized, allToolsetIDs, defaultToolsetIDList, descriptions } enabledToolsets := make(map[ToolsetID]bool, len(expanded)) for _, id := range expanded { enabledToolsets[id] = true } - return enabledToolsets, unrecognized -} - -// defaultToolsetIDs returns toolset IDs marked as Default in their metadata. -func (b *Builder) defaultToolsetIDs() []ToolsetID { - seen := make(map[ToolsetID]bool) - for i := range b.tools { - if b.tools[i].Toolset.Default { - seen[b.tools[i].Toolset.ID] = true - } - } - for i := range b.resourceTemplates { - if b.resourceTemplates[i].Toolset.Default { - seen[b.resourceTemplates[i].Toolset.ID] = true - } - } - for i := range b.prompts { - if b.prompts[i].Toolset.Default { - seen[b.prompts[i].Toolset.ID] = true - } - } - - ids := make([]ToolsetID, 0, len(seen)) - for id := range seen { - ids = append(ids, id) - } - return ids + return enabledToolsets, unrecognized, allToolsetIDs, defaultToolsetIDList, descriptions } diff --git a/pkg/registry/filters.go b/pkg/registry/filters.go index 2d46aab98..3b251e439 100644 --- a/pkg/registry/filters.go +++ b/pkg/registry/filters.go @@ -149,7 +149,7 @@ func (r *Registry) AvailablePrompts(ctx context.Context) []ServerPrompt { } // filterToolsByName returns tools matching the given name, checking deprecated aliases. -// Returns from the current tools slice (respects existing filter chain). +// Uses linear scan - optimized for single-lookup per-request scenarios (ForMCPRequest). func (r *Registry) filterToolsByName(name string) []ServerTool { // First check for exact match for i := range r.tools { @@ -169,9 +169,9 @@ func (r *Registry) filterToolsByName(name string) []ServerTool { } // filterResourcesByURI returns resource templates matching the given URI pattern. +// Uses linear scan - optimized for single-lookup per-request scenarios (ForMCPRequest). func (r *Registry) filterResourcesByURI(uri string) []ServerResourceTemplate { for i := range r.resourceTemplates { - // Check if URI matches the template pattern (exact match on URITemplate string) if r.resourceTemplates[i].Template.URITemplate == uri { return []ServerResourceTemplate{r.resourceTemplates[i]} } @@ -180,6 +180,7 @@ func (r *Registry) filterResourcesByURI(uri string) []ServerResourceTemplate { } // filterPromptsByName returns prompts matching the given name. +// Uses linear scan - optimized for single-lookup per-request scenarios (ForMCPRequest). func (r *Registry) filterPromptsByName(name string) []ServerPrompt { for i := range r.prompts { if r.prompts[i].Prompt.Name == name { diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go index 7a73a6e5b..6888c4cb4 100644 --- a/pkg/registry/registry.go +++ b/pkg/registry/registry.go @@ -6,6 +6,7 @@ import ( "os" "slices" "sort" + "sync" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -25,15 +26,24 @@ import ( // - Lazy dependency injection during registration via RegisterAll() // - Runtime toolset enabling for dynamic toolsets mode type Registry struct { - // tools holds all tools in this group + // tools holds all tools in this group (ordered for iteration) tools []ServerTool - // resourceTemplates holds all resource templates in this group + // toolsByName provides O(1) lookup by tool name (lazy-initialized) + // Used by FindToolByName for repeated lookups in long-lived servers + toolsByName map[string]*ServerTool + toolsByNameOnce sync.Once + // resourceTemplates holds all resource templates in this group (ordered for iteration) resourceTemplates []ServerResourceTemplate - // prompts holds all prompts in this group + // prompts holds all prompts in this group (ordered for iteration) prompts []ServerPrompt // deprecatedAliases maps old tool names to new canonical names deprecatedAliases map[string]string + // Pre-computed toolset metadata (set during Build) + toolsetIDs []ToolsetID // sorted list of all toolset IDs + defaultToolsetIDs []ToolsetID // sorted list of default toolset IDs + toolsetDescriptions map[ToolsetID]string // toolset ID -> description + // Filters - these control what's returned by Available* methods // readOnly when true filters out write tools readOnly bool @@ -57,6 +67,18 @@ func (r *Registry) UnrecognizedToolsets() []string { return r.unrecognizedToolsets } +// getToolsByName returns the toolsByName map, initializing it lazily on first call. +// Used by FindToolByName for O(1) lookups in long-lived servers with repeated lookups. +func (r *Registry) getToolsByName() map[string]*ServerTool { + r.toolsByNameOnce.Do(func() { + r.toolsByName = make(map[string]*ServerTool, len(r.tools)) + for i := range r.tools { + r.toolsByName[r.tools[i].Tool.Name] = &r.tools[i] + } + }) + return r.toolsByName +} + // MCP method constants for use with ForMCPRequest. const ( MCPMethodInitialize = "initialize" @@ -90,6 +112,8 @@ const ( // All existing filters (read-only, toolsets, etc.) still apply to the returned items. func (r *Registry) ForMCPRequest(method string, itemName string) *Registry { // Create a shallow copy with shared filter settings + // Note: lazy-init maps (toolsByName, etc.) are NOT copied - the new Registry + // will initialize its own maps on first use if needed result := &Registry{ tools: r.tools, resourceTemplates: r.resourceTemplates, @@ -142,75 +166,18 @@ func (r *Registry) ForMCPRequest(method string, itemName string) *Registry { // ToolsetIDs returns a sorted list of unique toolset IDs from all tools in this group. func (r *Registry) ToolsetIDs() []ToolsetID { - seen := make(map[ToolsetID]bool) - for i := range r.tools { - seen[r.tools[i].Toolset.ID] = true - } - for i := range r.resourceTemplates { - seen[r.resourceTemplates[i].Toolset.ID] = true - } - for i := range r.prompts { - seen[r.prompts[i].Toolset.ID] = true - } - - ids := make([]ToolsetID, 0, len(seen)) - for id := range seen { - ids = append(ids, id) - } - sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) - return ids + return r.toolsetIDs } // DefaultToolsetIDs returns the IDs of toolsets marked as Default in their metadata. // The IDs are returned in sorted order for deterministic output. func (r *Registry) DefaultToolsetIDs() []ToolsetID { - seen := make(map[ToolsetID]bool) - for i := range r.tools { - if r.tools[i].Toolset.Default { - seen[r.tools[i].Toolset.ID] = true - } - } - for i := range r.resourceTemplates { - if r.resourceTemplates[i].Toolset.Default { - seen[r.resourceTemplates[i].Toolset.ID] = true - } - } - for i := range r.prompts { - if r.prompts[i].Toolset.Default { - seen[r.prompts[i].Toolset.ID] = true - } - } - - ids := make([]ToolsetID, 0, len(seen)) - for id := range seen { - ids = append(ids, id) - } - sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) - return ids + return r.defaultToolsetIDs } // ToolsetDescriptions returns a map of toolset ID to description for all toolsets. func (r *Registry) ToolsetDescriptions() map[ToolsetID]string { - descriptions := make(map[ToolsetID]string) - for i := range r.tools { - t := &r.tools[i] - if t.Toolset.Description != "" { - descriptions[t.Toolset.ID] = t.Toolset.Description - } - } - for i := range r.resourceTemplates { - res := &r.resourceTemplates[i] - if res.Toolset.Description != "" { - descriptions[res.Toolset.ID] = res.Toolset.Description - } - } - for i := range r.prompts { - p := &r.prompts[i] - if p.Toolset.Description != "" { - descriptions[p.Toolset.ID] = p.Toolset.Description - } - } - return descriptions + return r.toolsetDescriptions } // RegisterTools registers all available tools with the server using the provided dependencies. @@ -269,11 +236,8 @@ func (r *Registry) ResolveToolAliases(toolNames []string) (resolved []string, al // Returns the tool, its toolset ID, and an error if not found. // This searches ALL tools regardless of filters. func (r *Registry) FindToolByName(toolName string) (*ServerTool, ToolsetID, error) { - for i := range r.tools { - tool := &r.tools[i] - if tool.Tool.Name == toolName { - return tool, tool.Toolset.ID, nil - } + if tool, ok := r.getToolsByName()[toolName]; ok { + return tool, tool.Toolset.ID, nil } return nil, "", NewToolDoesNotExistError(toolName) } From 7864d8059d18435a758f1944e87a1002d236654c Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Mon, 15 Dec 2025 23:05:21 +0100 Subject: [PATCH 23/27] perf(registry): O(1) HasToolset lookup via pre-computed set Add toolsetIDSet (map[ToolsetID]bool) to Registry for O(1) HasToolset lookups. Previously HasToolset iterated through all tools, resourceTemplates, and prompts to check if any belonged to the given toolset. Now it's a simple map lookup. The set is populated during the single-pass processToolsets() call, which already collected all valid toolset IDs. This adds zero new iteration - just returns the existing validIDs map. processToolsets now returns 6 values: - enabledToolsets, unrecognized, toolsetIDs, toolsetIDSet, defaultToolsetIDs, descriptions --- pkg/registry/builder.go | 11 ++++++----- pkg/registry/registry.go | 18 ++---------------- 2 files changed, 8 insertions(+), 21 deletions(-) diff --git a/pkg/registry/builder.go b/pkg/registry/builder.go index 384128044..33c12d697 100644 --- a/pkg/registry/builder.go +++ b/pkg/registry/builder.go @@ -128,7 +128,7 @@ func (b *Builder) Build() *Registry { // Note: toolsByName map is lazy-initialized on first use via getToolsByName() // Process toolsets and pre-compute metadata in a single pass - r.enabledToolsets, r.unrecognizedToolsets, r.toolsetIDs, r.defaultToolsetIDs, r.toolsetDescriptions = b.processToolsets() + r.enabledToolsets, r.unrecognizedToolsets, r.toolsetIDs, r.toolsetIDSet, r.defaultToolsetIDs, r.toolsetDescriptions = b.processToolsets() // Process additional tools (resolve aliases) if len(b.additionalTools) > 0 { @@ -150,9 +150,10 @@ func (b *Builder) Build() *Registry { // - enabledToolsets map (nil means all enabled) // - unrecognizedToolsets list for warnings // - allToolsetIDs sorted list of all toolset IDs +// - toolsetIDSet map for O(1) HasToolset lookup // - defaultToolsetIDs sorted list of default toolset IDs // - toolsetDescriptions map of toolset ID to description -func (b *Builder) processToolsets() (map[ToolsetID]bool, []string, []ToolsetID, []ToolsetID, map[ToolsetID]string) { +func (b *Builder) processToolsets() (map[ToolsetID]bool, []string, []ToolsetID, map[ToolsetID]bool, []ToolsetID, map[ToolsetID]string) { // Single pass: collect all toolset metadata together validIDs := make(map[ToolsetID]bool) defaultIDs := make(map[ToolsetID]bool) @@ -207,7 +208,7 @@ func (b *Builder) processToolsets() (map[ToolsetID]bool, []string, []ToolsetID, // Check for "all" keyword - enables all toolsets for _, id := range toolsetIDs { if strings.TrimSpace(id) == "all" { - return nil, nil, allToolsetIDs, defaultToolsetIDList, descriptions // nil means all enabled + return nil, nil, allToolsetIDs, validIDs, defaultToolsetIDList, descriptions // nil means all enabled } } @@ -247,12 +248,12 @@ func (b *Builder) processToolsets() (map[ToolsetID]bool, []string, []ToolsetID, } if len(expanded) == 0 { - return make(map[ToolsetID]bool), unrecognized, allToolsetIDs, defaultToolsetIDList, descriptions + return make(map[ToolsetID]bool), unrecognized, allToolsetIDs, validIDs, defaultToolsetIDList, descriptions } enabledToolsets := make(map[ToolsetID]bool, len(expanded)) for _, id := range expanded { enabledToolsets[id] = true } - return enabledToolsets, unrecognized, allToolsetIDs, defaultToolsetIDList, descriptions + return enabledToolsets, unrecognized, allToolsetIDs, validIDs, defaultToolsetIDList, descriptions } diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go index 6888c4cb4..3ac3022c9 100644 --- a/pkg/registry/registry.go +++ b/pkg/registry/registry.go @@ -41,6 +41,7 @@ type Registry struct { // Pre-computed toolset metadata (set during Build) toolsetIDs []ToolsetID // sorted list of all toolset IDs + toolsetIDSet map[ToolsetID]bool // set for O(1) HasToolset lookup defaultToolsetIDs []ToolsetID // sorted list of default toolset IDs toolsetDescriptions map[ToolsetID]string // toolset ID -> description @@ -244,22 +245,7 @@ func (r *Registry) FindToolByName(toolName string) (*ServerTool, ToolsetID, erro // HasToolset checks if any tool/resource/prompt belongs to the given toolset. func (r *Registry) HasToolset(toolsetID ToolsetID) bool { - for i := range r.tools { - if r.tools[i].Toolset.ID == toolsetID { - return true - } - } - for i := range r.resourceTemplates { - if r.resourceTemplates[i].Toolset.ID == toolsetID { - return true - } - } - for i := range r.prompts { - if r.prompts[i].Toolset.ID == toolsetID { - return true - } - } - return false + return r.toolsetIDSet[toolsetID] } // AllTools returns all tools without any filtering, sorted deterministically. From 9aeded421a082f7ade5903a1447696870fa97e53 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Mon, 15 Dec 2025 23:12:11 +0100 Subject: [PATCH 24/27] simplify: remove lazy toolsByName map - not needed for actual use cases FindToolByName() is only called once per request at most (to find toolset ID for dynamic enablement). The SDK handles tool dispatch after registration. A simple linear scan over ~90 tools is trivially fast and avoids: - sync.Once complexity - Map allocation - Premature optimization for non-existent 'repeated lookups' The pre-computed maps we keep (toolsetIDSet, etc.) are justified because they're used for filtering logic that runs on every request. --- pkg/registry/builder.go | 2 -- pkg/registry/registry.go | 23 ++++------------------- 2 files changed, 4 insertions(+), 21 deletions(-) diff --git a/pkg/registry/builder.go b/pkg/registry/builder.go index 33c12d697..8680f5c3d 100644 --- a/pkg/registry/builder.go +++ b/pkg/registry/builder.go @@ -125,8 +125,6 @@ func (b *Builder) Build() *Registry { featureChecker: b.featureChecker, } - // Note: toolsByName map is lazy-initialized on first use via getToolsByName() - // Process toolsets and pre-compute metadata in a single pass r.enabledToolsets, r.unrecognizedToolsets, r.toolsetIDs, r.toolsetIDSet, r.defaultToolsetIDs, r.toolsetDescriptions = b.processToolsets() diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go index 3ac3022c9..56b2fdee5 100644 --- a/pkg/registry/registry.go +++ b/pkg/registry/registry.go @@ -6,7 +6,6 @@ import ( "os" "slices" "sort" - "sync" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -28,10 +27,6 @@ import ( type Registry struct { // tools holds all tools in this group (ordered for iteration) tools []ServerTool - // toolsByName provides O(1) lookup by tool name (lazy-initialized) - // Used by FindToolByName for repeated lookups in long-lived servers - toolsByName map[string]*ServerTool - toolsByNameOnce sync.Once // resourceTemplates holds all resource templates in this group (ordered for iteration) resourceTemplates []ServerResourceTemplate // prompts holds all prompts in this group (ordered for iteration) @@ -68,18 +63,6 @@ func (r *Registry) UnrecognizedToolsets() []string { return r.unrecognizedToolsets } -// getToolsByName returns the toolsByName map, initializing it lazily on first call. -// Used by FindToolByName for O(1) lookups in long-lived servers with repeated lookups. -func (r *Registry) getToolsByName() map[string]*ServerTool { - r.toolsByNameOnce.Do(func() { - r.toolsByName = make(map[string]*ServerTool, len(r.tools)) - for i := range r.tools { - r.toolsByName[r.tools[i].Tool.Name] = &r.tools[i] - } - }) - return r.toolsByName -} - // MCP method constants for use with ForMCPRequest. const ( MCPMethodInitialize = "initialize" @@ -237,8 +220,10 @@ func (r *Registry) ResolveToolAliases(toolNames []string) (resolved []string, al // Returns the tool, its toolset ID, and an error if not found. // This searches ALL tools regardless of filters. func (r *Registry) FindToolByName(toolName string) (*ServerTool, ToolsetID, error) { - if tool, ok := r.getToolsByName()[toolName]; ok { - return tool, tool.Toolset.ID, nil + for i := range r.tools { + if r.tools[i].Tool.Name == toolName { + return &r.tools[i], r.tools[i].Toolset.ID, nil + } } return nil, "", NewToolDoesNotExistError(toolName) } From 7070d03478ef7d75e5f8a5df31833794b5a98e43 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 16 Dec 2025 10:14:31 +0000 Subject: [PATCH 25/27] Add generic tool filtering mechanisms to registry package - Add Enabled field to ServerTool for self-filtering based on context - Add ToolFilter type and WithFilter method to Builder for cross-cutting filters - Update isToolEnabled to check Enabled function and builder filters in order: 1. Tool's Enabled function 2. Feature flags (FeatureFlagEnable/FeatureFlagDisable) 3. Read-only filter 4. Builder filters 5. Toolset/additional tools check - Add FilteredTools method to Registry as alias for AvailableTools - Add comprehensive tests for all new functionality - All tests pass and linter is clean Closes #1618 Co-authored-by: SamMorrowDrums <4811358+SamMorrowDrums@users.noreply.github.com> --- pkg/registry/builder.go | 17 ++ pkg/registry/filters.go | 35 ++- pkg/registry/registry.go | 4 + pkg/registry/registry_test.go | 469 ++++++++++++++++++++++++++++++++++ pkg/registry/server_tool.go | 7 + 5 files changed, 529 insertions(+), 3 deletions(-) diff --git a/pkg/registry/builder.go b/pkg/registry/builder.go index 8680f5c3d..813712064 100644 --- a/pkg/registry/builder.go +++ b/pkg/registry/builder.go @@ -1,10 +1,15 @@ package registry import ( + "context" "sort" "strings" ) +// ToolFilter is a function that determines if a tool should be included. +// Returns true if the tool should be included, false to exclude it. +type ToolFilter func(ctx context.Context, tool *ServerTool) (bool, error) + // Builder builds a Registry with the specified configuration. // Use NewBuilder to create a builder, chain configuration methods, // then call Build() to create the final Registry. @@ -19,6 +24,7 @@ import ( // WithReadOnly(true). // WithToolsets([]string{"repos", "issues"}). // WithFeatureChecker(checker). +// WithFilter(myFilter). // Build() type Builder struct { tools []ServerTool @@ -32,6 +38,7 @@ type Builder struct { toolsetIDsIsNil bool // tracks if nil was passed (nil = defaults) additionalTools []string // raw input, processed at Build() featureChecker FeatureFlagChecker + filters []ToolFilter // filters to apply to all tools } // NewBuilder creates a new Builder. @@ -111,6 +118,15 @@ func (b *Builder) WithFeatureChecker(checker FeatureFlagChecker) *Builder { return b } +// WithFilter adds a filter function that will be applied to all tools. +// Multiple filters can be added and are evaluated in order. +// If any filter returns false or an error, the tool is excluded. +// Returns self for chaining. +func (b *Builder) WithFilter(filter ToolFilter) *Builder { + b.filters = append(b.filters, filter) + return b +} + // Build creates the final Registry with all configuration applied. // This processes toolset filtering, tool name resolution, and sets up // the registry for use. The returned Registry is ready for use with @@ -123,6 +139,7 @@ func (b *Builder) Build() *Registry { deprecatedAliases: b.deprecatedAliases, readOnly: b.readOnly, featureChecker: b.featureChecker, + filters: b.filters, } // Process toolsets and pre-compute metadata in a single pass diff --git a/pkg/registry/filters.go b/pkg/registry/filters.go index 3b251e439..d90956e1c 100644 --- a/pkg/registry/filters.go +++ b/pkg/registry/filters.go @@ -52,14 +52,36 @@ func (r *Registry) isFeatureFlagAllowed(ctx context.Context, enableFlag, disable // isToolEnabled checks if a specific tool is enabled based on current filters. func (r *Registry) isToolEnabled(ctx context.Context, tool *ServerTool) bool { - // Check read-only filter first (applies to all tools) - if r.readOnly && !tool.IsReadOnly() { - return false + // Check tool's own Enabled function first + if tool.Enabled != nil { + enabled, err := tool.Enabled(ctx) + if err != nil { + fmt.Fprintf(os.Stderr, "Tool.Enabled check error for %q: %v\n", tool.Tool.Name, err) + return false + } + if !enabled { + return false + } } // Check feature flags if !r.isFeatureFlagAllowed(ctx, tool.FeatureFlagEnable, tool.FeatureFlagDisable) { return false } + // Check read-only filter (applies to all tools) + if r.readOnly && !tool.IsReadOnly() { + return false + } + // Apply builder filters + for _, filter := range r.filters { + allowed, err := filter(ctx, tool) + if err != nil { + fmt.Fprintf(os.Stderr, "Builder filter error for tool %q: %v\n", tool.Tool.Name, err) + return false + } + if !allowed { + return false + } + } // Check if tool is in additionalTools (bypasses toolset filter) if r.additionalTools != nil && r.additionalTools[tool.Tool.Name] { return true @@ -245,3 +267,10 @@ func (r *Registry) EnabledToolsetIDs() []ToolsetID { sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) return ids } + +// FilteredTools returns tools filtered by the Enabled function and builder filters. +// This is an alias for AvailableTools for clarity when focusing on filtering behavior. +// The context is used for Enabled function evaluation and builder filter checks. +func (r *Registry) FilteredTools(ctx context.Context) ([]ServerTool, error) { + return r.AvailableTools(ctx), nil +} diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go index 56b2fdee5..7376a4c9c 100644 --- a/pkg/registry/registry.go +++ b/pkg/registry/registry.go @@ -53,6 +53,9 @@ type Registry struct { // Takes context and flag name, returns (enabled, error). If error, log and treat as false. // If checker is nil, all flag checks return false. featureChecker FeatureFlagChecker + // filters are functions that will be applied to all tools during filtering. + // If any filter returns false or an error, the tool is excluded. + filters []ToolFilter // unrecognizedToolsets holds toolset IDs that were requested but don't match any registered toolsets unrecognizedToolsets []string } @@ -107,6 +110,7 @@ func (r *Registry) ForMCPRequest(method string, itemName string) *Registry { enabledToolsets: r.enabledToolsets, // shared, not modified additionalTools: r.additionalTools, // shared, not modified featureChecker: r.featureChecker, + filters: r.filters, // shared, not modified unrecognizedToolsets: r.unrecognizedToolsets, } diff --git a/pkg/registry/registry_test.go b/pkg/registry/registry_test.go index 4518ed8bf..44ed3d773 100644 --- a/pkg/registry/registry_test.go +++ b/pkg/registry/registry_test.go @@ -1174,3 +1174,472 @@ func TestServerToolHandlerPanicOnNil(t *testing.T) { tool.Handler(nil) } + +// Tests for Enabled function on ServerTool +func TestServerToolEnabled(t *testing.T) { + tests := []struct { + name string + enabledFunc func(ctx context.Context) (bool, error) + expectedCount int + expectInResult bool + }{ + { + name: "nil Enabled function - tool included", + enabledFunc: nil, + expectedCount: 1, + expectInResult: true, + }, + { + name: "Enabled returns true - tool included", + enabledFunc: func(_ context.Context) (bool, error) { + return true, nil + }, + expectedCount: 1, + expectInResult: true, + }, + { + name: "Enabled returns false - tool excluded", + enabledFunc: func(_ context.Context) (bool, error) { + return false, nil + }, + expectedCount: 0, + expectInResult: false, + }, + { + name: "Enabled returns error - tool excluded", + enabledFunc: func(_ context.Context) (bool, error) { + return false, fmt.Errorf("simulated error") + }, + expectedCount: 0, + expectInResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tool := mockTool("test_tool", "toolset1", true) + tool.Enabled = tt.enabledFunc + + reg := NewBuilder().SetTools([]ServerTool{tool}).WithToolsets([]string{"all"}).Build() + available := reg.AvailableTools(context.Background()) + + if len(available) != tt.expectedCount { + t.Errorf("Expected %d tools, got %d", tt.expectedCount, len(available)) + } + + found := false + for _, t := range available { + if t.Tool.Name == "test_tool" { + found = true + break + } + } + if found != tt.expectInResult { + t.Errorf("Expected tool in result: %v, got: %v", tt.expectInResult, found) + } + }) + } +} + +func TestServerToolEnabledWithContext(t *testing.T) { + type contextKey string + const userKey contextKey = "user" + + // Tool that checks context for user + tool := mockTool("context_aware_tool", "toolset1", true) + tool.Enabled = func(ctx context.Context) (bool, error) { + user := ctx.Value(userKey) + return user != nil && user.(string) == "authorized", nil + } + + reg := NewBuilder().SetTools([]ServerTool{tool}).WithToolsets([]string{"all"}).Build() + + // Without user in context - tool should be excluded + available := reg.AvailableTools(context.Background()) + if len(available) != 0 { + t.Errorf("Expected 0 tools without user, got %d", len(available)) + } + + // With authorized user - tool should be included + ctxWithUser := context.WithValue(context.Background(), userKey, "authorized") + availableWithUser := reg.AvailableTools(ctxWithUser) + if len(availableWithUser) != 1 { + t.Errorf("Expected 1 tool with authorized user, got %d", len(availableWithUser)) + } + + // With unauthorized user - tool should be excluded + ctxWithBadUser := context.WithValue(context.Background(), userKey, "unauthorized") + availableWithBadUser := reg.AvailableTools(ctxWithBadUser) + if len(availableWithBadUser) != 0 { + t.Errorf("Expected 0 tools with unauthorized user, got %d", len(availableWithBadUser)) + } +} + +// Tests for WithFilter builder method +func TestBuilderWithFilter(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", true), + mockTool("tool3", "toolset1", true), + } + + // Filter that excludes tool2 + filter := func(_ context.Context, tool *ServerTool) (bool, error) { + return tool.Tool.Name != "tool2", nil + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 2 { + t.Fatalf("Expected 2 tools after filter, got %d", len(available)) + } + + for _, tool := range available { + if tool.Tool.Name == "tool2" { + t.Error("tool2 should have been filtered out") + } + } +} + +func TestBuilderWithMultipleFilters(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", true), + mockTool("tool3", "toolset1", true), + mockTool("tool4", "toolset1", true), + } + + // First filter excludes tool2 + filter1 := func(_ context.Context, tool *ServerTool) (bool, error) { + return tool.Tool.Name != "tool2", nil + } + + // Second filter excludes tool3 + filter2 := func(_ context.Context, tool *ServerTool) (bool, error) { + return tool.Tool.Name != "tool3", nil + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter1). + WithFilter(filter2). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 2 { + t.Fatalf("Expected 2 tools after multiple filters, got %d", len(available)) + } + + toolNames := make(map[string]bool) + for _, tool := range available { + toolNames[tool.Tool.Name] = true + } + + if !toolNames["tool1"] || !toolNames["tool4"] { + t.Error("Expected tool1 and tool4 to be available") + } + if toolNames["tool2"] || toolNames["tool3"] { + t.Error("tool2 and tool3 should have been filtered out") + } +} + +func TestBuilderFilterError(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + } + + // Filter that returns an error + filter := func(_ context.Context, _ *ServerTool) (bool, error) { + return false, fmt.Errorf("filter error") + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 0 { + t.Errorf("Expected 0 tools when filter returns error, got %d", len(available)) + } +} + +func TestBuilderFilterWithContext(t *testing.T) { + type contextKey string + const scopeKey contextKey = "scope" + + tools := []ServerTool{ + mockTool("public_tool", "toolset1", true), + mockTool("private_tool", "toolset1", true), + } + + // Filter that checks context for scope + filter := func(ctx context.Context, tool *ServerTool) (bool, error) { + scope := ctx.Value(scopeKey) + if scope == "public" && tool.Tool.Name == "private_tool" { + return false, nil + } + return true, nil + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + // With public scope - private_tool should be excluded + ctxPublic := context.WithValue(context.Background(), scopeKey, "public") + availablePublic := reg.AvailableTools(ctxPublic) + if len(availablePublic) != 1 { + t.Fatalf("Expected 1 tool with public scope, got %d", len(availablePublic)) + } + if availablePublic[0].Tool.Name != "public_tool" { + t.Error("Expected only public_tool to be available") + } + + // With private scope - both tools should be available + ctxPrivate := context.WithValue(context.Background(), scopeKey, "private") + availablePrivate := reg.AvailableTools(ctxPrivate) + if len(availablePrivate) != 2 { + t.Errorf("Expected 2 tools with private scope, got %d", len(availablePrivate)) + } +} + +// Tests for interaction between Enabled, feature flags, and filters +func TestEnabledAndFeatureFlagInteraction(t *testing.T) { + // Tool with both Enabled function and feature flag + tool := mockToolWithFlags("complex_tool", "toolset1", true, "my_feature", "") + tool.Enabled = func(_ context.Context) (bool, error) { + return true, nil + } + + // Feature flag not enabled - tool should be excluded despite Enabled returning true + reg1 := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + Build() + available1 := reg1.AvailableTools(context.Background()) + if len(available1) != 0 { + t.Error("Tool should be excluded when feature flag is not enabled") + } + + // Feature flag enabled - tool should be included + checker := func(_ context.Context, flag string) (bool, error) { + return flag == "my_feature", nil + } + reg2 := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFeatureChecker(checker). + Build() + available2 := reg2.AvailableTools(context.Background()) + if len(available2) != 1 { + t.Error("Tool should be included when both Enabled and feature flag pass") + } + + // Enabled returns false - tool should be excluded despite feature flag + tool.Enabled = func(_ context.Context) (bool, error) { + return false, nil + } + reg3 := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFeatureChecker(checker). + Build() + available3 := reg3.AvailableTools(context.Background()) + if len(available3) != 0 { + t.Error("Tool should be excluded when Enabled returns false") + } +} + +func TestEnabledAndBuilderFilterInteraction(t *testing.T) { + tool := mockTool("test_tool", "toolset1", true) + tool.Enabled = func(_ context.Context) (bool, error) { + return true, nil + } + + // Filter that excludes the tool + filter := func(_ context.Context, _ *ServerTool) (bool, error) { + return false, nil + } + + reg := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 0 { + t.Error("Tool should be excluded when filter returns false, despite Enabled returning true") + } +} + +func TestAllFiltersInteraction(t *testing.T) { + // Tool with Enabled, feature flag, and subject to builder filter + tool := mockToolWithFlags("complex_tool", "toolset1", true, "my_feature", "") + tool.Enabled = func(_ context.Context) (bool, error) { + return true, nil + } + + filter := func(_ context.Context, _ *ServerTool) (bool, error) { + return true, nil + } + + checker := func(_ context.Context, flag string) (bool, error) { + return flag == "my_feature", nil + } + + // All conditions pass - tool should be included + reg := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFeatureChecker(checker). + WithFilter(filter). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 1 { + t.Error("Tool should be included when all filters pass") + } + + // Change filter to return false - tool should be excluded + filterFalse := func(_ context.Context, _ *ServerTool) (bool, error) { + return false, nil + } + + reg2 := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFeatureChecker(checker). + WithFilter(filterFalse). + Build() + + available2 := reg2.AvailableTools(context.Background()) + if len(available2) != 0 { + t.Error("Tool should be excluded when any filter fails") + } +} + +// Test FilteredTools method +func TestFilteredTools(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", true), + } + + filter := func(_ context.Context, tool *ServerTool) (bool, error) { + return tool.Tool.Name == "tool1", nil + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + filtered, err := reg.FilteredTools(context.Background()) + if err != nil { + t.Fatalf("FilteredTools returned error: %v", err) + } + + if len(filtered) != 1 { + t.Fatalf("Expected 1 filtered tool, got %d", len(filtered)) + } + + if filtered[0].Tool.Name != "tool1" { + t.Errorf("Expected tool1, got %s", filtered[0].Tool.Name) + } +} + +func TestFilteredToolsMatchesAvailableTools(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", false), + mockTool("tool3", "toolset2", true), + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"toolset1"}). + WithReadOnly(true). + Build() + + ctx := context.Background() + filtered, err := reg.FilteredTools(ctx) + if err != nil { + t.Fatalf("FilteredTools returned error: %v", err) + } + + available := reg.AvailableTools(ctx) + + // Both methods should return the same results + if len(filtered) != len(available) { + t.Errorf("FilteredTools and AvailableTools returned different counts: %d vs %d", + len(filtered), len(available)) + } + + for i := range filtered { + if filtered[i].Tool.Name != available[i].Tool.Name { + t.Errorf("Tool at index %d differs: FilteredTools=%s, AvailableTools=%s", + i, filtered[i].Tool.Name, available[i].Tool.Name) + } + } +} + +func TestFilteringOrder(t *testing.T) { + // Test that filters are applied in the correct order: + // 1. Tool.Enabled + // 2. Feature flags + // 3. Read-only + // 4. Builder filters + // 5. Toolset/additional tools + + callOrder := []string{} + + tool := mockToolWithFlags("test_tool", "toolset1", false, "my_feature", "") + tool.Enabled = func(_ context.Context) (bool, error) { + callOrder = append(callOrder, "Enabled") + return true, nil + } + + filter := func(_ context.Context, _ *ServerTool) (bool, error) { + callOrder = append(callOrder, "Filter") + return true, nil + } + + checker := func(_ context.Context, _ string) (bool, error) { + callOrder = append(callOrder, "FeatureFlag") + return true, nil + } + + reg := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithReadOnly(true). // This will exclude the tool (it's not read-only) + WithFeatureChecker(checker). + WithFilter(filter). + Build() + + _ = reg.AvailableTools(context.Background()) + + // Expected order: Enabled, FeatureFlag, ReadOnly (stops here because it's write tool) + expectedOrder := []string{"Enabled", "FeatureFlag"} + if len(callOrder) != len(expectedOrder) { + t.Errorf("Expected %d checks, got %d: %v", len(expectedOrder), len(callOrder), callOrder) + } + + for i, expected := range expectedOrder { + if i >= len(callOrder) || callOrder[i] != expected { + t.Errorf("At position %d: expected %s, got %v", i, expected, callOrder) + } + } +} diff --git a/pkg/registry/server_tool.go b/pkg/registry/server_tool.go index af00a2bed..3145b693d 100644 --- a/pkg/registry/server_tool.go +++ b/pkg/registry/server_tool.go @@ -52,6 +52,13 @@ type ServerTool struct { // FeatureFlagDisable specifies a feature flag that, when enabled, causes this tool // to be omitted. Used to disable tools when a feature flag is on. FeatureFlagDisable string + + // Enabled is an optional function called at build/filter time to determine + // if this tool should be available. If nil, the tool is considered enabled + // (subject to FeatureFlagEnable/FeatureFlagDisable checks). + // The context carries request-scoped information for the consumer to use. + // Returns (enabled, error). On error, the tool should be treated as disabled. + Enabled func(ctx context.Context) (bool, error) } // IsReadOnly returns true if this tool is marked as read-only via annotations. From d76ca286dbed736835d5a43b1dbc85f6d7882432 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Tue, 16 Dec 2025 19:30:58 +0100 Subject: [PATCH 26/27] docs: improve filter evaluation order and FilteredTools documentation - Add numbered filter evaluation order to isToolEnabled function doc - Number inline comments for each filter step (1-5) - Clarify FilteredTools error return is for future extensibility - Document that library consumers may need to surface recoverable errors Addresses review feedback on PR #1620 --- pkg/registry/filters.go | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/pkg/registry/filters.go b/pkg/registry/filters.go index d90956e1c..5bce63570 100644 --- a/pkg/registry/filters.go +++ b/pkg/registry/filters.go @@ -51,8 +51,14 @@ func (r *Registry) isFeatureFlagAllowed(ctx context.Context, enableFlag, disable } // isToolEnabled checks if a specific tool is enabled based on current filters. +// Filter evaluation order: +// 1. Tool.Enabled (tool self-filtering) +// 2. FeatureFlagEnable/FeatureFlagDisable +// 3. Read-only filter +// 4. Builder filters (via WithFilter) +// 5. Toolset/additional tools func (r *Registry) isToolEnabled(ctx context.Context, tool *ServerTool) bool { - // Check tool's own Enabled function first + // 1. Check tool's own Enabled function first if tool.Enabled != nil { enabled, err := tool.Enabled(ctx) if err != nil { @@ -63,15 +69,15 @@ func (r *Registry) isToolEnabled(ctx context.Context, tool *ServerTool) bool { return false } } - // Check feature flags + // 2. Check feature flags if !r.isFeatureFlagAllowed(ctx, tool.FeatureFlagEnable, tool.FeatureFlagDisable) { return false } - // Check read-only filter (applies to all tools) + // 3. Check read-only filter (applies to all tools) if r.readOnly && !tool.IsReadOnly() { return false } - // Apply builder filters + // 4. Apply builder filters for _, filter := range r.filters { allowed, err := filter(ctx, tool) if err != nil { @@ -82,11 +88,11 @@ func (r *Registry) isToolEnabled(ctx context.Context, tool *ServerTool) bool { return false } } - // Check if tool is in additionalTools (bypasses toolset filter) + // 5. Check if tool is in additionalTools (bypasses toolset filter) if r.additionalTools != nil && r.additionalTools[tool.Tool.Name] { return true } - // Check toolset filter + // 5. Check toolset filter if !r.isToolsetEnabled(tool.Toolset.ID) { return false } @@ -269,7 +275,14 @@ func (r *Registry) EnabledToolsetIDs() []ToolsetID { } // FilteredTools returns tools filtered by the Enabled function and builder filters. -// This is an alias for AvailableTools for clarity when focusing on filtering behavior. +// This provides an explicit API for accessing filtered tools, currently implemented +// as an alias for AvailableTools. +// +// The error return is currently always nil but is included for future extensibility. +// Library consumers (e.g., remote server implementations) may need to surface +// recoverable filter errors rather than silently logging them. Having the error +// return in the API now avoids breaking changes later. +// // The context is used for Enabled function evaluation and builder filter checks. func (r *Registry) FilteredTools(ctx context.Context) ([]ServerTool, error) { return r.AvailableTools(ctx), nil From c274d78bbfe331673fca89557043ee4119ff1482 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 15 Dec 2025 22:42:09 +0000 Subject: [PATCH 27/27] Refactor GenerateToolsetsHelp() to use strings.Builder pattern Co-authored-by: SamMorrowDrums <4811358+SamMorrowDrums@users.noreply.github.com> --- pkg/github/tools.go | 57 ++++++++++++++++++++++++---------------- pkg/github/tools_test.go | 31 ++++++++++++++++++++++ 2 files changed, 65 insertions(+), 23 deletions(-) diff --git a/pkg/github/tools.go b/pkg/github/tools.go index bc06f793f..1fde6e6bb 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -2,7 +2,6 @@ package github import ( "context" - "fmt" "strings" "github.com/github/github-mcp-server/pkg/registry" @@ -266,17 +265,19 @@ func GenerateToolsetsHelp() string { // Get toolset group to derive defaults and available toolsets r := NewRegistry(stubTranslator).Build() - // Format default tools from metadata + // Format default tools from metadata using strings.Builder + var defaultBuf strings.Builder defaultIDs := r.DefaultToolsetIDs() - defaultStrings := make([]string, len(defaultIDs)) for i, id := range defaultIDs { - defaultStrings[i] = string(id) + if i > 0 { + defaultBuf.WriteString(", ") + } + defaultBuf.WriteString(string(id)) } - defaultTools := strings.Join(defaultStrings, ", ") // Get all available toolsets (excludes context and dynamic for display) allToolsets := r.AvailableToolsets("context", "dynamic") - var availableToolsLines []string + var availableBuf strings.Builder const maxLineLength = 70 currentLine := "" @@ -288,27 +289,37 @@ func GenerateToolsetsHelp() string { case len(currentLine)+len(id)+2 <= maxLineLength: currentLine += ", " + id default: - availableToolsLines = append(availableToolsLines, currentLine) + if availableBuf.Len() > 0 { + availableBuf.WriteString(",\n\t ") + } + availableBuf.WriteString(currentLine) currentLine = id } } if currentLine != "" { - availableToolsLines = append(availableToolsLines, currentLine) - } - - availableTools := strings.Join(availableToolsLines, ",\n\t ") - - toolsetsHelp := fmt.Sprintf("Comma-separated list of tool groups to enable (no spaces).\n"+ - "Available: %s\n", availableTools) + - "Special toolset keywords:\n" + - " - all: Enables all available toolsets\n" + - fmt.Sprintf(" - default: Enables the default toolset configuration of:\n\t %s\n", defaultTools) + - "Examples:\n" + - " - --toolsets=actions,gists,notifications\n" + - " - Default + additional: --toolsets=default,actions,gists\n" + - " - All tools: --toolsets=all" - - return toolsetsHelp + if availableBuf.Len() > 0 { + availableBuf.WriteString(",\n\t ") + } + availableBuf.WriteString(currentLine) + } + + // Build the complete help text using strings.Builder + var buf strings.Builder + buf.WriteString("Comma-separated list of tool groups to enable (no spaces).\n") + buf.WriteString("Available: ") + buf.WriteString(availableBuf.String()) + buf.WriteString("\n") + buf.WriteString("Special toolset keywords:\n") + buf.WriteString(" - all: Enables all available toolsets\n") + buf.WriteString(" - default: Enables the default toolset configuration of:\n\t ") + buf.WriteString(defaultBuf.String()) + buf.WriteString("\n") + buf.WriteString("Examples:\n") + buf.WriteString(" - --toolsets=actions,gists,notifications\n") + buf.WriteString(" - Default + additional: --toolsets=default,actions,gists\n") + buf.WriteString(" - All tools: --toolsets=all") + + return buf.String() } // stubTranslator is a passthrough translator for cases where we need a Registry diff --git a/pkg/github/tools_test.go b/pkg/github/tools_test.go index 4e6d91980..80270d2bc 100644 --- a/pkg/github/tools_test.go +++ b/pkg/github/tools_test.go @@ -151,3 +151,34 @@ func TestContainsToolset(t *testing.T) { }) } } + +func TestGenerateToolsetsHelp(t *testing.T) { + // Generate the help text + helpText := GenerateToolsetsHelp() + + // Verify help text is not empty + require.NotEmpty(t, helpText) + + // Verify it contains expected sections + assert.Contains(t, helpText, "Comma-separated list of tool groups to enable") + assert.Contains(t, helpText, "Available:") + assert.Contains(t, helpText, "Special toolset keywords:") + assert.Contains(t, helpText, "all: Enables all available toolsets") + assert.Contains(t, helpText, "default: Enables the default toolset configuration") + assert.Contains(t, helpText, "Examples:") + assert.Contains(t, helpText, "--toolsets=actions,gists,notifications") + assert.Contains(t, helpText, "--toolsets=default,actions,gists") + assert.Contains(t, helpText, "--toolsets=all") + + // Verify it contains some expected default toolsets + assert.Contains(t, helpText, "context") + assert.Contains(t, helpText, "repos") + assert.Contains(t, helpText, "issues") + assert.Contains(t, helpText, "pull_requests") + assert.Contains(t, helpText, "users") + + // Verify it contains some expected available toolsets + assert.Contains(t, helpText, "actions") + assert.Contains(t, helpText, "gists") + assert.Contains(t, helpText, "notifications") +}