🌐 AI搜索 & 代理 主页
Skip to content

Commit d29e73b

Browse files
refactor(git): migrate GetRepositoryTree to NewTool pattern
1 parent 3b7fa6d commit d29e73b

File tree

4 files changed

+126
-301
lines changed

4 files changed

+126
-301
lines changed

pkg/github/git.go

Lines changed: 114 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"strings"
88

99
ghErrors "github.com/github/github-mcp-server/pkg/errors"
10+
"github.com/github/github-mcp-server/pkg/toolsets"
1011
"github.com/github/github-mcp-server/pkg/translations"
1112
"github.com/github/github-mcp-server/pkg/utils"
1213
"github.com/google/go-github/v79/github"
@@ -37,140 +38,139 @@ type TreeResponse struct {
3738
}
3839

3940
// GetRepositoryTree creates a tool to get the tree structure of a GitHub repository.
40-
func GetRepositoryTree(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) {
41-
tool := mcp.Tool{
42-
Name: "get_repository_tree",
43-
Description: t("TOOL_GET_REPOSITORY_TREE_DESCRIPTION", "Get the tree structure (files and directories) of a GitHub repository at a specific ref or SHA"),
44-
Annotations: &mcp.ToolAnnotations{
45-
Title: t("TOOL_GET_REPOSITORY_TREE_USER_TITLE", "Get repository tree"),
46-
ReadOnlyHint: true,
47-
},
48-
InputSchema: &jsonschema.Schema{
49-
Type: "object",
50-
Properties: map[string]*jsonschema.Schema{
51-
"owner": {
52-
Type: "string",
53-
Description: "Repository owner (username or organization)",
54-
},
55-
"repo": {
56-
Type: "string",
57-
Description: "Repository name",
58-
},
59-
"tree_sha": {
60-
Type: "string",
61-
Description: "The SHA1 value or ref (branch or tag) name of the tree. Defaults to the repository's default branch",
62-
},
63-
"recursive": {
64-
Type: "boolean",
65-
Description: "Setting this parameter to true returns the objects or subtrees referenced by the tree. Default is false",
66-
Default: json.RawMessage(`false`),
67-
},
68-
"path_filter": {
69-
Type: "string",
70-
Description: "Optional path prefix to filter the tree results (e.g., 'src/' to only show files in the src directory)",
41+
func GetRepositoryTree(t translations.TranslationHelperFunc) toolsets.ServerTool {
42+
return NewTool(
43+
mcp.Tool{
44+
Name: "get_repository_tree",
45+
Description: t("TOOL_GET_REPOSITORY_TREE_DESCRIPTION", "Get the tree structure (files and directories) of a GitHub repository at a specific ref or SHA"),
46+
Annotations: &mcp.ToolAnnotations{
47+
Title: t("TOOL_GET_REPOSITORY_TREE_USER_TITLE", "Get repository tree"),
48+
ReadOnlyHint: true,
49+
},
50+
InputSchema: &jsonschema.Schema{
51+
Type: "object",
52+
Properties: map[string]*jsonschema.Schema{
53+
"owner": {
54+
Type: "string",
55+
Description: "Repository owner (username or organization)",
56+
},
57+
"repo": {
58+
Type: "string",
59+
Description: "Repository name",
60+
},
61+
"tree_sha": {
62+
Type: "string",
63+
Description: "The SHA1 value or ref (branch or tag) name of the tree. Defaults to the repository's default branch",
64+
},
65+
"recursive": {
66+
Type: "boolean",
67+
Description: "Setting this parameter to true returns the objects or subtrees referenced by the tree. Default is false",
68+
Default: json.RawMessage(`false`),
69+
},
70+
"path_filter": {
71+
Type: "string",
72+
Description: "Optional path prefix to filter the tree results (e.g., 'src/' to only show files in the src directory)",
73+
},
7174
},
75+
Required: []string{"owner", "repo"},
7276
},
73-
Required: []string{"owner", "repo"},
7477
},
75-
}
78+
func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] {
79+
return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) {
80+
owner, err := RequiredParam[string](args, "owner")
81+
if err != nil {
82+
return utils.NewToolResultError(err.Error()), nil, nil
83+
}
84+
repo, err := RequiredParam[string](args, "repo")
85+
if err != nil {
86+
return utils.NewToolResultError(err.Error()), nil, nil
87+
}
88+
treeSHA, err := OptionalParam[string](args, "tree_sha")
89+
if err != nil {
90+
return utils.NewToolResultError(err.Error()), nil, nil
91+
}
92+
recursive, err := OptionalBoolParamWithDefault(args, "recursive", false)
93+
if err != nil {
94+
return utils.NewToolResultError(err.Error()), nil, nil
95+
}
96+
pathFilter, err := OptionalParam[string](args, "path_filter")
97+
if err != nil {
98+
return utils.NewToolResultError(err.Error()), nil, nil
99+
}
76100

77-
handler := mcp.ToolHandlerFor[map[string]any, any](
78-
func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) {
79-
owner, err := RequiredParam[string](args, "owner")
80-
if err != nil {
81-
return utils.NewToolResultError(err.Error()), nil, nil
82-
}
83-
repo, err := RequiredParam[string](args, "repo")
84-
if err != nil {
85-
return utils.NewToolResultError(err.Error()), nil, nil
86-
}
87-
treeSHA, err := OptionalParam[string](args, "tree_sha")
88-
if err != nil {
89-
return utils.NewToolResultError(err.Error()), nil, nil
90-
}
91-
recursive, err := OptionalBoolParamWithDefault(args, "recursive", false)
92-
if err != nil {
93-
return utils.NewToolResultError(err.Error()), nil, nil
94-
}
95-
pathFilter, err := OptionalParam[string](args, "path_filter")
96-
if err != nil {
97-
return utils.NewToolResultError(err.Error()), nil, nil
98-
}
101+
client, err := deps.GetClient(ctx)
102+
if err != nil {
103+
return utils.NewToolResultError("failed to get GitHub client"), nil, nil
104+
}
99105

100-
client, err := getClient(ctx)
101-
if err != nil {
102-
return utils.NewToolResultError("failed to get GitHub client"), nil, nil
103-
}
106+
// If no tree_sha is provided, use the repository's default branch
107+
if treeSHA == "" {
108+
repoInfo, repoResp, err := client.Repositories.Get(ctx, owner, repo)
109+
if err != nil {
110+
return ghErrors.NewGitHubAPIErrorResponse(ctx,
111+
"failed to get repository info",
112+
repoResp,
113+
err,
114+
), nil, nil
115+
}
116+
treeSHA = *repoInfo.DefaultBranch
117+
}
104118

105-
// If no tree_sha is provided, use the repository's default branch
106-
if treeSHA == "" {
107-
repoInfo, repoResp, err := client.Repositories.Get(ctx, owner, repo)
119+
// Get the tree using the GitHub Git Tree API
120+
tree, resp, err := client.Git.GetTree(ctx, owner, repo, treeSHA, recursive)
108121
if err != nil {
109122
return ghErrors.NewGitHubAPIErrorResponse(ctx,
110-
"failed to get repository info",
111-
repoResp,
123+
"failed to get repository tree",
124+
resp,
112125
err,
113126
), nil, nil
114127
}
115-
treeSHA = *repoInfo.DefaultBranch
116-
}
128+
defer func() { _ = resp.Body.Close() }()
117129

118-
// Get the tree using the GitHub Git Tree API
119-
tree, resp, err := client.Git.GetTree(ctx, owner, repo, treeSHA, recursive)
120-
if err != nil {
121-
return ghErrors.NewGitHubAPIErrorResponse(ctx,
122-
"failed to get repository tree",
123-
resp,
124-
err,
125-
), nil, nil
126-
}
127-
defer func() { _ = resp.Body.Close() }()
128-
129-
// Filter tree entries if path_filter is provided
130-
var filteredEntries []*github.TreeEntry
131-
if pathFilter != "" {
132-
for _, entry := range tree.Entries {
133-
if strings.HasPrefix(entry.GetPath(), pathFilter) {
134-
filteredEntries = append(filteredEntries, entry)
130+
// Filter tree entries if path_filter is provided
131+
var filteredEntries []*github.TreeEntry
132+
if pathFilter != "" {
133+
for _, entry := range tree.Entries {
134+
if strings.HasPrefix(entry.GetPath(), pathFilter) {
135+
filteredEntries = append(filteredEntries, entry)
136+
}
135137
}
138+
} else {
139+
filteredEntries = tree.Entries
136140
}
137-
} else {
138-
filteredEntries = tree.Entries
139-
}
140141

141-
treeEntries := make([]TreeEntryResponse, len(filteredEntries))
142-
for i, entry := range filteredEntries {
143-
treeEntries[i] = TreeEntryResponse{
144-
Path: entry.GetPath(),
145-
Type: entry.GetType(),
146-
Mode: entry.GetMode(),
147-
SHA: entry.GetSHA(),
148-
URL: entry.GetURL(),
142+
treeEntries := make([]TreeEntryResponse, len(filteredEntries))
143+
for i, entry := range filteredEntries {
144+
treeEntries[i] = TreeEntryResponse{
145+
Path: entry.GetPath(),
146+
Type: entry.GetType(),
147+
Mode: entry.GetMode(),
148+
SHA: entry.GetSHA(),
149+
URL: entry.GetURL(),
150+
}
151+
if entry.Size != nil {
152+
treeEntries[i].Size = entry.Size
153+
}
149154
}
150-
if entry.Size != nil {
151-
treeEntries[i].Size = entry.Size
155+
156+
response := TreeResponse{
157+
SHA: *tree.SHA,
158+
Truncated: *tree.Truncated,
159+
Tree: treeEntries,
160+
TreeSHA: treeSHA,
161+
Owner: owner,
162+
Repo: repo,
163+
Recursive: recursive,
164+
Count: len(filteredEntries),
152165
}
153-
}
154166

155-
response := TreeResponse{
156-
SHA: *tree.SHA,
157-
Truncated: *tree.Truncated,
158-
Tree: treeEntries,
159-
TreeSHA: treeSHA,
160-
Owner: owner,
161-
Repo: repo,
162-
Recursive: recursive,
163-
Count: len(filteredEntries),
164-
}
167+
r, err := json.Marshal(response)
168+
if err != nil {
169+
return nil, nil, fmt.Errorf("failed to marshal response: %w", err)
170+
}
165171

166-
r, err := json.Marshal(response)
167-
if err != nil {
168-
return nil, nil, fmt.Errorf("failed to marshal response: %w", err)
172+
return utils.NewToolResultText(string(r)), nil, nil
169173
}
170-
171-
return utils.NewToolResultText(string(r)), nil, nil
172174
},
173175
)
174-
175-
return tool, handler
176176
}

pkg/github/git_test.go

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,14 @@ import (
1818

1919
func Test_GetRepositoryTree(t *testing.T) {
2020
// Verify tool definition once
21-
mockClient := github.NewClient(nil)
22-
tool, _ := GetRepositoryTree(stubGetClientFn(mockClient), translations.NullTranslationHelper)
23-
require.NoError(t, toolsnaps.Test(tool.Name, tool))
21+
toolDef := GetRepositoryTree(translations.NullTranslationHelper)
22+
require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool))
2423

25-
assert.Equal(t, "get_repository_tree", tool.Name)
26-
assert.NotEmpty(t, tool.Description)
24+
assert.Equal(t, "get_repository_tree", toolDef.Tool.Name)
25+
assert.NotEmpty(t, toolDef.Tool.Description)
2726

2827
// Type assert the InputSchema to access its properties
29-
inputSchema, ok := tool.InputSchema.(*jsonschema.Schema)
28+
inputSchema, ok := toolDef.Tool.InputSchema.(*jsonschema.Schema)
3029
require.True(t, ok, "expected InputSchema to be *jsonschema.Schema")
3130
assert.Contains(t, inputSchema.Properties, "owner")
3231
assert.Contains(t, inputSchema.Properties, "repo")
@@ -148,12 +147,16 @@ func Test_GetRepositoryTree(t *testing.T) {
148147

149148
for _, tc := range tests {
150149
t.Run(tc.name, func(t *testing.T) {
151-
_, handler := GetRepositoryTree(stubGetClientFromHTTPFn(tc.mockedClient), translations.NullTranslationHelper)
150+
client := github.NewClient(tc.mockedClient)
151+
deps := ToolDependencies{
152+
GetClient: stubGetClientFn(client),
153+
}
154+
handler := toolDef.Handler(deps)
152155

153156
// Create the tool request
154157
request := createMCPRequest(tc.requestArgs)
155158

156-
result, _, err := handler(context.Background(), &request, tc.requestArgs)
159+
result, err := handler(context.Background(), &request)
157160

158161
if tc.expectError {
159162
require.NoError(t, err)

0 commit comments

Comments
 (0)