|
7 | 7 | "strings" |
8 | 8 |
|
9 | 9 | ghErrors "github.com/github/github-mcp-server/pkg/errors" |
| 10 | + "github.com/github/github-mcp-server/pkg/toolsets" |
10 | 11 | "github.com/github/github-mcp-server/pkg/translations" |
11 | 12 | "github.com/github/github-mcp-server/pkg/utils" |
12 | 13 | "github.com/google/go-github/v79/github" |
@@ -37,140 +38,139 @@ type TreeResponse struct { |
37 | 38 | } |
38 | 39 |
|
39 | 40 | // GetRepositoryTree creates a tool to get the tree structure of a GitHub repository. |
40 | | -func GetRepositoryTree(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { |
41 | | - tool := mcp.Tool{ |
42 | | - Name: "get_repository_tree", |
43 | | - Description: t("TOOL_GET_REPOSITORY_TREE_DESCRIPTION", "Get the tree structure (files and directories) of a GitHub repository at a specific ref or SHA"), |
44 | | - Annotations: &mcp.ToolAnnotations{ |
45 | | - Title: t("TOOL_GET_REPOSITORY_TREE_USER_TITLE", "Get repository tree"), |
46 | | - ReadOnlyHint: true, |
47 | | - }, |
48 | | - InputSchema: &jsonschema.Schema{ |
49 | | - Type: "object", |
50 | | - Properties: map[string]*jsonschema.Schema{ |
51 | | - "owner": { |
52 | | - Type: "string", |
53 | | - Description: "Repository owner (username or organization)", |
54 | | - }, |
55 | | - "repo": { |
56 | | - Type: "string", |
57 | | - Description: "Repository name", |
58 | | - }, |
59 | | - "tree_sha": { |
60 | | - Type: "string", |
61 | | - Description: "The SHA1 value or ref (branch or tag) name of the tree. Defaults to the repository's default branch", |
62 | | - }, |
63 | | - "recursive": { |
64 | | - Type: "boolean", |
65 | | - Description: "Setting this parameter to true returns the objects or subtrees referenced by the tree. Default is false", |
66 | | - Default: json.RawMessage(`false`), |
67 | | - }, |
68 | | - "path_filter": { |
69 | | - Type: "string", |
70 | | - Description: "Optional path prefix to filter the tree results (e.g., 'src/' to only show files in the src directory)", |
| 41 | +func GetRepositoryTree(t translations.TranslationHelperFunc) toolsets.ServerTool { |
| 42 | + return NewTool( |
| 43 | + mcp.Tool{ |
| 44 | + Name: "get_repository_tree", |
| 45 | + Description: t("TOOL_GET_REPOSITORY_TREE_DESCRIPTION", "Get the tree structure (files and directories) of a GitHub repository at a specific ref or SHA"), |
| 46 | + Annotations: &mcp.ToolAnnotations{ |
| 47 | + Title: t("TOOL_GET_REPOSITORY_TREE_USER_TITLE", "Get repository tree"), |
| 48 | + ReadOnlyHint: true, |
| 49 | + }, |
| 50 | + InputSchema: &jsonschema.Schema{ |
| 51 | + Type: "object", |
| 52 | + Properties: map[string]*jsonschema.Schema{ |
| 53 | + "owner": { |
| 54 | + Type: "string", |
| 55 | + Description: "Repository owner (username or organization)", |
| 56 | + }, |
| 57 | + "repo": { |
| 58 | + Type: "string", |
| 59 | + Description: "Repository name", |
| 60 | + }, |
| 61 | + "tree_sha": { |
| 62 | + Type: "string", |
| 63 | + Description: "The SHA1 value or ref (branch or tag) name of the tree. Defaults to the repository's default branch", |
| 64 | + }, |
| 65 | + "recursive": { |
| 66 | + Type: "boolean", |
| 67 | + Description: "Setting this parameter to true returns the objects or subtrees referenced by the tree. Default is false", |
| 68 | + Default: json.RawMessage(`false`), |
| 69 | + }, |
| 70 | + "path_filter": { |
| 71 | + Type: "string", |
| 72 | + Description: "Optional path prefix to filter the tree results (e.g., 'src/' to only show files in the src directory)", |
| 73 | + }, |
71 | 74 | }, |
| 75 | + Required: []string{"owner", "repo"}, |
72 | 76 | }, |
73 | | - Required: []string{"owner", "repo"}, |
74 | 77 | }, |
75 | | - } |
| 78 | + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { |
| 79 | + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { |
| 80 | + owner, err := RequiredParam[string](args, "owner") |
| 81 | + if err != nil { |
| 82 | + return utils.NewToolResultError(err.Error()), nil, nil |
| 83 | + } |
| 84 | + repo, err := RequiredParam[string](args, "repo") |
| 85 | + if err != nil { |
| 86 | + return utils.NewToolResultError(err.Error()), nil, nil |
| 87 | + } |
| 88 | + treeSHA, err := OptionalParam[string](args, "tree_sha") |
| 89 | + if err != nil { |
| 90 | + return utils.NewToolResultError(err.Error()), nil, nil |
| 91 | + } |
| 92 | + recursive, err := OptionalBoolParamWithDefault(args, "recursive", false) |
| 93 | + if err != nil { |
| 94 | + return utils.NewToolResultError(err.Error()), nil, nil |
| 95 | + } |
| 96 | + pathFilter, err := OptionalParam[string](args, "path_filter") |
| 97 | + if err != nil { |
| 98 | + return utils.NewToolResultError(err.Error()), nil, nil |
| 99 | + } |
76 | 100 |
|
77 | | - handler := mcp.ToolHandlerFor[map[string]any, any]( |
78 | | - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { |
79 | | - owner, err := RequiredParam[string](args, "owner") |
80 | | - if err != nil { |
81 | | - return utils.NewToolResultError(err.Error()), nil, nil |
82 | | - } |
83 | | - repo, err := RequiredParam[string](args, "repo") |
84 | | - if err != nil { |
85 | | - return utils.NewToolResultError(err.Error()), nil, nil |
86 | | - } |
87 | | - treeSHA, err := OptionalParam[string](args, "tree_sha") |
88 | | - if err != nil { |
89 | | - return utils.NewToolResultError(err.Error()), nil, nil |
90 | | - } |
91 | | - recursive, err := OptionalBoolParamWithDefault(args, "recursive", false) |
92 | | - if err != nil { |
93 | | - return utils.NewToolResultError(err.Error()), nil, nil |
94 | | - } |
95 | | - pathFilter, err := OptionalParam[string](args, "path_filter") |
96 | | - if err != nil { |
97 | | - return utils.NewToolResultError(err.Error()), nil, nil |
98 | | - } |
| 101 | + client, err := deps.GetClient(ctx) |
| 102 | + if err != nil { |
| 103 | + return utils.NewToolResultError("failed to get GitHub client"), nil, nil |
| 104 | + } |
99 | 105 |
|
100 | | - client, err := getClient(ctx) |
101 | | - if err != nil { |
102 | | - return utils.NewToolResultError("failed to get GitHub client"), nil, nil |
103 | | - } |
| 106 | + // If no tree_sha is provided, use the repository's default branch |
| 107 | + if treeSHA == "" { |
| 108 | + repoInfo, repoResp, err := client.Repositories.Get(ctx, owner, repo) |
| 109 | + if err != nil { |
| 110 | + return ghErrors.NewGitHubAPIErrorResponse(ctx, |
| 111 | + "failed to get repository info", |
| 112 | + repoResp, |
| 113 | + err, |
| 114 | + ), nil, nil |
| 115 | + } |
| 116 | + treeSHA = *repoInfo.DefaultBranch |
| 117 | + } |
104 | 118 |
|
105 | | - // If no tree_sha is provided, use the repository's default branch |
106 | | - if treeSHA == "" { |
107 | | - repoInfo, repoResp, err := client.Repositories.Get(ctx, owner, repo) |
| 119 | + // Get the tree using the GitHub Git Tree API |
| 120 | + tree, resp, err := client.Git.GetTree(ctx, owner, repo, treeSHA, recursive) |
108 | 121 | if err != nil { |
109 | 122 | return ghErrors.NewGitHubAPIErrorResponse(ctx, |
110 | | - "failed to get repository info", |
111 | | - repoResp, |
| 123 | + "failed to get repository tree", |
| 124 | + resp, |
112 | 125 | err, |
113 | 126 | ), nil, nil |
114 | 127 | } |
115 | | - treeSHA = *repoInfo.DefaultBranch |
116 | | - } |
| 128 | + defer func() { _ = resp.Body.Close() }() |
117 | 129 |
|
118 | | - // Get the tree using the GitHub Git Tree API |
119 | | - tree, resp, err := client.Git.GetTree(ctx, owner, repo, treeSHA, recursive) |
120 | | - if err != nil { |
121 | | - return ghErrors.NewGitHubAPIErrorResponse(ctx, |
122 | | - "failed to get repository tree", |
123 | | - resp, |
124 | | - err, |
125 | | - ), nil, nil |
126 | | - } |
127 | | - defer func() { _ = resp.Body.Close() }() |
128 | | - |
129 | | - // Filter tree entries if path_filter is provided |
130 | | - var filteredEntries []*github.TreeEntry |
131 | | - if pathFilter != "" { |
132 | | - for _, entry := range tree.Entries { |
133 | | - if strings.HasPrefix(entry.GetPath(), pathFilter) { |
134 | | - filteredEntries = append(filteredEntries, entry) |
| 130 | + // Filter tree entries if path_filter is provided |
| 131 | + var filteredEntries []*github.TreeEntry |
| 132 | + if pathFilter != "" { |
| 133 | + for _, entry := range tree.Entries { |
| 134 | + if strings.HasPrefix(entry.GetPath(), pathFilter) { |
| 135 | + filteredEntries = append(filteredEntries, entry) |
| 136 | + } |
135 | 137 | } |
| 138 | + } else { |
| 139 | + filteredEntries = tree.Entries |
136 | 140 | } |
137 | | - } else { |
138 | | - filteredEntries = tree.Entries |
139 | | - } |
140 | 141 |
|
141 | | - treeEntries := make([]TreeEntryResponse, len(filteredEntries)) |
142 | | - for i, entry := range filteredEntries { |
143 | | - treeEntries[i] = TreeEntryResponse{ |
144 | | - Path: entry.GetPath(), |
145 | | - Type: entry.GetType(), |
146 | | - Mode: entry.GetMode(), |
147 | | - SHA: entry.GetSHA(), |
148 | | - URL: entry.GetURL(), |
| 142 | + treeEntries := make([]TreeEntryResponse, len(filteredEntries)) |
| 143 | + for i, entry := range filteredEntries { |
| 144 | + treeEntries[i] = TreeEntryResponse{ |
| 145 | + Path: entry.GetPath(), |
| 146 | + Type: entry.GetType(), |
| 147 | + Mode: entry.GetMode(), |
| 148 | + SHA: entry.GetSHA(), |
| 149 | + URL: entry.GetURL(), |
| 150 | + } |
| 151 | + if entry.Size != nil { |
| 152 | + treeEntries[i].Size = entry.Size |
| 153 | + } |
149 | 154 | } |
150 | | - if entry.Size != nil { |
151 | | - treeEntries[i].Size = entry.Size |
| 155 | + |
| 156 | + response := TreeResponse{ |
| 157 | + SHA: *tree.SHA, |
| 158 | + Truncated: *tree.Truncated, |
| 159 | + Tree: treeEntries, |
| 160 | + TreeSHA: treeSHA, |
| 161 | + Owner: owner, |
| 162 | + Repo: repo, |
| 163 | + Recursive: recursive, |
| 164 | + Count: len(filteredEntries), |
152 | 165 | } |
153 | | - } |
154 | 166 |
|
155 | | - response := TreeResponse{ |
156 | | - SHA: *tree.SHA, |
157 | | - Truncated: *tree.Truncated, |
158 | | - Tree: treeEntries, |
159 | | - TreeSHA: treeSHA, |
160 | | - Owner: owner, |
161 | | - Repo: repo, |
162 | | - Recursive: recursive, |
163 | | - Count: len(filteredEntries), |
164 | | - } |
| 167 | + r, err := json.Marshal(response) |
| 168 | + if err != nil { |
| 169 | + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) |
| 170 | + } |
165 | 171 |
|
166 | | - r, err := json.Marshal(response) |
167 | | - if err != nil { |
168 | | - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) |
| 172 | + return utils.NewToolResultText(string(r)), nil, nil |
169 | 173 | } |
170 | | - |
171 | | - return utils.NewToolResultText(string(r)), nil, nil |
172 | 174 | }, |
173 | 175 | ) |
174 | | - |
175 | | - return tool, handler |
176 | 176 | } |
0 commit comments