From 699aa26b4191ac1743a1819c5c724bfdb16e6bee Mon Sep 17 00:00:00 2001 From: Ashwin Bhat Date: Wed, 4 Jun 2025 11:56:56 -0700 Subject: [PATCH] fix: only load GitHub MCP server when its tools are allowed (#124) * fix: only load GitHub MCP server when its tools are allowed - Add allowedTools parameter to prepareMcpConfig - Check for mcp__github__ and mcp__github_file_ops__ tool prefixes - Only include MCP servers when their tools are in allowed_tools - Maintain backward compatibility when allowed_tools is not specified - Update tests to reflect the new conditional loading behavior This optimizes resource usage by not loading unnecessary MCP servers when their tools are not allowed in the configuration. Co-authored-by: ashwin-ant * fix: always load github_file_ops server regardless of allowed_tools - Only apply conditional loading to the github MCP server - Always load github_file_ops server as it contains essential tools - Update tests to reflect this behavior Co-authored-by: ashwin-ant * refactor: move allowedTools/disallowedTools parsing to parseGitHubContext - Change allowedTools and disallowedTools from string to string[] in ParsedGitHubContext type - Parse comma-separated environment variables into arrays in parseGitHubContext function - Update create-prompt and install-mcp-server to use pre-parsed arrays - Update all affected test files to use array syntax - Eliminate duplicate parsing logic across the codebase * style: apply prettier formatting --------- Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com> Co-authored-by: ashwin-ant --- src/create-prompt/index.ts | 35 ++++++------ src/entrypoints/prepare.ts | 1 + src/github/context.ts | 14 +++-- src/mcp/install-mcp-server.ts | 41 ++++++++------ test/create-prompt.test.ts | 14 ++--- test/install-mcp-server.test.ts | 96 ++++++++++++++++++++++++++++----- test/mockContext.ts | 4 +- test/permissions.test.ts | 4 +- test/prepare-context.test.ts | 2 +- test/trigger-validation.test.ts | 20 +++---- 10 files changed, 159 insertions(+), 72 deletions(-) diff --git a/src/create-prompt/index.ts b/src/create-prompt/index.ts index e292f34..5c20928 100644 --- a/src/create-prompt/index.ts +++ b/src/create-prompt/index.ts @@ -35,38 +35,35 @@ const BASE_ALLOWED_TOOLS = [ ]; const DISALLOWED_TOOLS = ["WebSearch", "WebFetch"]; -export function buildAllowedToolsString(customAllowedTools?: string): string { +export function buildAllowedToolsString(customAllowedTools?: string[]): string { let baseTools = [...BASE_ALLOWED_TOOLS]; let allAllowedTools = baseTools.join(","); - if (customAllowedTools) { - allAllowedTools = `${allAllowedTools},${customAllowedTools}`; + if (customAllowedTools && customAllowedTools.length > 0) { + allAllowedTools = `${allAllowedTools},${customAllowedTools.join(",")}`; } return allAllowedTools; } export function buildDisallowedToolsString( - customDisallowedTools?: string, - allowedTools?: string, + customDisallowedTools?: string[], + allowedTools?: string[], ): string { let disallowedTools = [...DISALLOWED_TOOLS]; // If user has explicitly allowed some hardcoded disallowed tools, remove them from disallowed list - if (allowedTools) { - const allowedToolsArray = allowedTools - .split(",") - .map((tool) => tool.trim()); + if (allowedTools && allowedTools.length > 0) { disallowedTools = disallowedTools.filter( - (tool) => !allowedToolsArray.includes(tool), + (tool) => !allowedTools.includes(tool), ); } let allDisallowedTools = disallowedTools.join(","); - if (customDisallowedTools) { + if (customDisallowedTools && customDisallowedTools.length > 0) { if (allDisallowedTools) { - allDisallowedTools = `${allDisallowedTools},${customDisallowedTools}`; + allDisallowedTools = `${allDisallowedTools},${customDisallowedTools.join(",")}`; } else { - allDisallowedTools = customDisallowedTools; + allDisallowedTools = customDisallowedTools.join(","); } } return allDisallowedTools; @@ -120,8 +117,10 @@ export function prepareContext( triggerPhrase, ...(triggerUsername && { triggerUsername }), ...(customInstructions && { customInstructions }), - ...(allowedTools && { allowedTools }), - ...(disallowedTools && { disallowedTools }), + ...(allowedTools.length > 0 && { allowedTools: allowedTools.join(",") }), + ...(disallowedTools.length > 0 && { + disallowedTools: disallowedTools.join(","), + }), ...(directPrompt && { directPrompt }), ...(claudeBranch && { claudeBranch }), }; @@ -636,11 +635,11 @@ export async function createPrompt( // Set allowed tools const allAllowedTools = buildAllowedToolsString( - preparedContext.allowedTools, + context.inputs.allowedTools, ); const allDisallowedTools = buildDisallowedToolsString( - preparedContext.disallowedTools, - preparedContext.allowedTools, + context.inputs.disallowedTools, + context.inputs.allowedTools, ); core.exportVariable("ALLOWED_TOOLS", allAllowedTools); diff --git a/src/entrypoints/prepare.ts b/src/entrypoints/prepare.ts index 5736268..6b240d8 100644 --- a/src/entrypoints/prepare.ts +++ b/src/entrypoints/prepare.ts @@ -92,6 +92,7 @@ async function run() { branch: branchInfo.currentBranch, additionalMcpConfig, claudeCommentId: commentId.toString(), + allowedTools: context.inputs.allowedTools, }); core.setOutput("mcp_config", mcpConfig); } catch (error) { diff --git a/src/github/context.ts b/src/github/context.ts index 0fb7f65..1e19303 100644 --- a/src/github/context.ts +++ b/src/github/context.ts @@ -28,8 +28,8 @@ export type ParsedGitHubContext = { inputs: { triggerPhrase: string; assigneeTrigger: string; - allowedTools: string; - disallowedTools: string; + allowedTools: string[]; + disallowedTools: string[]; customInstructions: string; directPrompt: string; baseBranch?: string; @@ -52,8 +52,14 @@ export function parseGitHubContext(): ParsedGitHubContext { inputs: { triggerPhrase: process.env.TRIGGER_PHRASE ?? "@claude", assigneeTrigger: process.env.ASSIGNEE_TRIGGER ?? "", - allowedTools: process.env.ALLOWED_TOOLS ?? "", - disallowedTools: process.env.DISALLOWED_TOOLS ?? "", + allowedTools: (process.env.ALLOWED_TOOLS ?? "") + .split(",") + .map((tool) => tool.trim()) + .filter((tool) => tool.length > 0), + disallowedTools: (process.env.DISALLOWED_TOOLS ?? "") + .split(",") + .map((tool) => tool.trim()) + .filter((tool) => tool.length > 0), customInstructions: process.env.CUSTOM_INSTRUCTIONS ?? "", directPrompt: process.env.DIRECT_PROMPT ?? "", baseBranch: process.env.BASE_BRANCH, diff --git a/src/mcp/install-mcp-server.ts b/src/mcp/install-mcp-server.ts index e820097..0eba6af 100644 --- a/src/mcp/install-mcp-server.ts +++ b/src/mcp/install-mcp-server.ts @@ -7,6 +7,7 @@ type PrepareConfigParams = { branch: string; additionalMcpConfig?: string; claudeCommentId?: string; + allowedTools: string[]; }; export async function prepareMcpConfig( @@ -19,24 +20,17 @@ export async function prepareMcpConfig( branch, additionalMcpConfig, claudeCommentId, + allowedTools, } = params; try { - const baseMcpConfig = { + const allowedToolsList = allowedTools || []; + + const hasGitHubMcpTools = allowedToolsList.some((tool) => + tool.startsWith("mcp__github__"), + ); + + const baseMcpConfig: { mcpServers: Record } = { mcpServers: { - github: { - command: "docker", - args: [ - "run", - "-i", - "--rm", - "-e", - "GITHUB_PERSONAL_ACCESS_TOKEN", - "ghcr.io/github/github-mcp-server:sha-e9f748f", // https://github.com/github/github-mcp-server/releases/tag/v0.4.0 - ], - env: { - GITHUB_PERSONAL_ACCESS_TOKEN: githubToken, - }, - }, github_file_ops: { command: "bun", args: [ @@ -57,6 +51,23 @@ export async function prepareMcpConfig( }, }; + if (hasGitHubMcpTools) { + baseMcpConfig.mcpServers.github = { + command: "docker", + args: [ + "run", + "-i", + "--rm", + "-e", + "GITHUB_PERSONAL_ACCESS_TOKEN", + "ghcr.io/github/github-mcp-server:sha-e9f748f", // https://github.com/github/github-mcp-server/releases/tag/v0.4.0 + ], + env: { + GITHUB_PERSONAL_ACCESS_TOKEN: githubToken, + }, + }; + } + // Merge with additional MCP config if provided if (additionalMcpConfig && additionalMcpConfig.trim()) { try { diff --git a/test/create-prompt.test.ts b/test/create-prompt.test.ts index 617f0ac..65c5625 100644 --- a/test/create-prompt.test.ts +++ b/test/create-prompt.test.ts @@ -652,7 +652,7 @@ describe("buildAllowedToolsString", () => { }); test("should append custom tools when provided", () => { - const customTools = "Tool1,Tool2,Tool3"; + const customTools = ["Tool1", "Tool2", "Tool3"]; const result = buildAllowedToolsString(customTools); // Base tools should be present @@ -683,7 +683,7 @@ describe("buildDisallowedToolsString", () => { }); test("should append custom disallowed tools when provided", () => { - const customDisallowedTools = "BadTool1,BadTool2"; + const customDisallowedTools = ["BadTool1", "BadTool2"]; const result = buildDisallowedToolsString(customDisallowedTools); // Base disallowed tools should be present @@ -701,8 +701,8 @@ describe("buildDisallowedToolsString", () => { }); test("should remove hardcoded disallowed tools if they are in allowed tools", () => { - const customDisallowedTools = "BadTool1,BadTool2"; - const allowedTools = "WebSearch,SomeOtherTool"; + const customDisallowedTools = ["BadTool1", "BadTool2"]; + const allowedTools = ["WebSearch", "SomeOtherTool"]; const result = buildDisallowedToolsString( customDisallowedTools, allowedTools, @@ -720,7 +720,7 @@ describe("buildDisallowedToolsString", () => { }); test("should remove all hardcoded disallowed tools if they are all in allowed tools", () => { - const allowedTools = "WebSearch,WebFetch,SomeOtherTool"; + const allowedTools = ["WebSearch", "WebFetch", "SomeOtherTool"]; const result = buildDisallowedToolsString(undefined, allowedTools); // Both hardcoded disallowed tools should be removed @@ -732,8 +732,8 @@ describe("buildDisallowedToolsString", () => { }); test("should handle custom disallowed tools when all hardcoded tools are overridden", () => { - const customDisallowedTools = "BadTool1,BadTool2"; - const allowedTools = "WebSearch,WebFetch"; + const customDisallowedTools = ["BadTool1", "BadTool2"]; + const allowedTools = ["WebSearch", "WebFetch"]; const result = buildDisallowedToolsString( customDisallowedTools, allowedTools, diff --git a/test/install-mcp-server.test.ts b/test/install-mcp-server.test.ts index 3d2f02e..4dbb32d 100644 --- a/test/install-mcp-server.test.ts +++ b/test/install-mcp-server.test.ts @@ -24,21 +24,19 @@ describe("prepareMcpConfig", () => { processExitSpy.mockRestore(); }); - test("should return base config when no additional config is provided", async () => { + test("should return base config when no additional config is provided and no allowed_tools", async () => { const result = await prepareMcpConfig({ githubToken: "test-token", owner: "test-owner", repo: "test-repo", branch: "test-branch", + allowedTools: [], }); const parsed = JSON.parse(result); expect(parsed.mcpServers).toBeDefined(); - expect(parsed.mcpServers.github).toBeDefined(); + expect(parsed.mcpServers.github).not.toBeDefined(); expect(parsed.mcpServers.github_file_ops).toBeDefined(); - expect(parsed.mcpServers.github.env.GITHUB_PERSONAL_ACCESS_TOKEN).toBe( - "test-token", - ); expect(parsed.mcpServers.github_file_ops.env.GITHUB_TOKEN).toBe( "test-token", ); @@ -49,6 +47,60 @@ describe("prepareMcpConfig", () => { ); }); + test("should include github MCP server when mcp__github__ tools are allowed", async () => { + const result = await prepareMcpConfig({ + githubToken: "test-token", + owner: "test-owner", + repo: "test-repo", + branch: "test-branch", + allowedTools: [ + "mcp__github__create_issue", + "mcp__github_file_ops__commit_files", + ], + }); + + const parsed = JSON.parse(result); + expect(parsed.mcpServers).toBeDefined(); + expect(parsed.mcpServers.github).toBeDefined(); + expect(parsed.mcpServers.github_file_ops).toBeDefined(); + expect(parsed.mcpServers.github.env.GITHUB_PERSONAL_ACCESS_TOKEN).toBe( + "test-token", + ); + }); + + test("should not include github MCP server when only file_ops tools are allowed", async () => { + const result = await prepareMcpConfig({ + githubToken: "test-token", + owner: "test-owner", + repo: "test-repo", + branch: "test-branch", + allowedTools: [ + "mcp__github_file_ops__commit_files", + "mcp__github_file_ops__update_claude_comment", + ], + }); + + const parsed = JSON.parse(result); + expect(parsed.mcpServers).toBeDefined(); + expect(parsed.mcpServers.github).not.toBeDefined(); + expect(parsed.mcpServers.github_file_ops).toBeDefined(); + }); + + test("should include file_ops server even when no GitHub tools are allowed", async () => { + const result = await prepareMcpConfig({ + githubToken: "test-token", + owner: "test-owner", + repo: "test-repo", + branch: "test-branch", + allowedTools: ["Edit", "Read", "Write"], + }); + + const parsed = JSON.parse(result); + expect(parsed.mcpServers).toBeDefined(); + expect(parsed.mcpServers.github).not.toBeDefined(); + expect(parsed.mcpServers.github_file_ops).toBeDefined(); + }); + test("should return base config when additional config is empty string", async () => { const result = await prepareMcpConfig({ githubToken: "test-token", @@ -56,11 +108,12 @@ describe("prepareMcpConfig", () => { repo: "test-repo", branch: "test-branch", additionalMcpConfig: "", + allowedTools: [], }); const parsed = JSON.parse(result); expect(parsed.mcpServers).toBeDefined(); - expect(parsed.mcpServers.github).toBeDefined(); + expect(parsed.mcpServers.github).not.toBeDefined(); expect(parsed.mcpServers.github_file_ops).toBeDefined(); expect(consoleWarningSpy).not.toHaveBeenCalled(); }); @@ -72,11 +125,12 @@ describe("prepareMcpConfig", () => { repo: "test-repo", branch: "test-branch", additionalMcpConfig: " \n\t ", + allowedTools: [], }); const parsed = JSON.parse(result); expect(parsed.mcpServers).toBeDefined(); - expect(parsed.mcpServers.github).toBeDefined(); + expect(parsed.mcpServers.github).not.toBeDefined(); expect(parsed.mcpServers.github_file_ops).toBeDefined(); expect(consoleWarningSpy).not.toHaveBeenCalled(); }); @@ -100,6 +154,10 @@ describe("prepareMcpConfig", () => { repo: "test-repo", branch: "test-branch", additionalMcpConfig: additionalConfig, + allowedTools: [ + "mcp__github__create_issue", + "mcp__github_file_ops__commit_files", + ], }); const parsed = JSON.parse(result); @@ -133,6 +191,10 @@ describe("prepareMcpConfig", () => { repo: "test-repo", branch: "test-branch", additionalMcpConfig: additionalConfig, + allowedTools: [ + "mcp__github__create_issue", + "mcp__github_file_ops__commit_files", + ], }); const parsed = JSON.parse(result); @@ -169,12 +231,13 @@ describe("prepareMcpConfig", () => { repo: "test-repo", branch: "test-branch", additionalMcpConfig: additionalConfig, + allowedTools: [], }); const parsed = JSON.parse(result); expect(parsed.customProperty).toBe("custom-value"); expect(parsed.anotherProperty).toEqual({ nested: "value" }); - expect(parsed.mcpServers.github).toBeDefined(); + expect(parsed.mcpServers.github).not.toBeDefined(); expect(parsed.mcpServers.custom_server).toBeDefined(); }); @@ -187,13 +250,14 @@ describe("prepareMcpConfig", () => { repo: "test-repo", branch: "test-branch", additionalMcpConfig: invalidJson, + allowedTools: [], }); const parsed = JSON.parse(result); expect(consoleWarningSpy).toHaveBeenCalledWith( expect.stringContaining("Failed to parse additional MCP config:"), ); - expect(parsed.mcpServers.github).toBeDefined(); + expect(parsed.mcpServers.github).not.toBeDefined(); expect(parsed.mcpServers.github_file_ops).toBeDefined(); }); @@ -206,6 +270,7 @@ describe("prepareMcpConfig", () => { repo: "test-repo", branch: "test-branch", additionalMcpConfig: nonObjectJson, + allowedTools: [], }); const parsed = JSON.parse(result); @@ -215,7 +280,7 @@ describe("prepareMcpConfig", () => { expect(consoleWarningSpy).toHaveBeenCalledWith( expect.stringContaining("MCP config must be a valid JSON object"), ); - expect(parsed.mcpServers.github).toBeDefined(); + expect(parsed.mcpServers.github).not.toBeDefined(); expect(parsed.mcpServers.github_file_ops).toBeDefined(); }); @@ -228,6 +293,7 @@ describe("prepareMcpConfig", () => { repo: "test-repo", branch: "test-branch", additionalMcpConfig: nullJson, + allowedTools: [], }); const parsed = JSON.parse(result); @@ -237,7 +303,7 @@ describe("prepareMcpConfig", () => { expect(consoleWarningSpy).toHaveBeenCalledWith( expect.stringContaining("MCP config must be a valid JSON object"), ); - expect(parsed.mcpServers.github).toBeDefined(); + expect(parsed.mcpServers.github).not.toBeDefined(); expect(parsed.mcpServers.github_file_ops).toBeDefined(); }); @@ -250,6 +316,7 @@ describe("prepareMcpConfig", () => { repo: "test-repo", branch: "test-branch", additionalMcpConfig: arrayJson, + allowedTools: [], }); const parsed = JSON.parse(result); @@ -258,7 +325,7 @@ describe("prepareMcpConfig", () => { expect(consoleInfoSpy).toHaveBeenCalledWith( "Merging additional MCP server configuration with built-in servers", ); - expect(parsed.mcpServers.github).toBeDefined(); + expect(parsed.mcpServers.github).not.toBeDefined(); expect(parsed.mcpServers.github_file_ops).toBeDefined(); // The array will be spread into the config (0: 1, 1: 2, 2: 3) expect(parsed[0]).toBe(1); @@ -295,12 +362,13 @@ describe("prepareMcpConfig", () => { repo: "test-repo", branch: "test-branch", additionalMcpConfig: additionalConfig, + allowedTools: [], }); const parsed = JSON.parse(result); expect(parsed.mcpServers.server1).toBeDefined(); expect(parsed.mcpServers.server2).toBeDefined(); - expect(parsed.mcpServers.github).toBeDefined(); + expect(parsed.mcpServers.github).not.toBeDefined(); expect(parsed.mcpServers.github_file_ops.command).toBe("overridden"); expect(parsed.mcpServers.github_file_ops.env.CUSTOM).toBe("value"); expect(parsed.otherConfig.nested.deeply).toBe("value"); @@ -315,6 +383,7 @@ describe("prepareMcpConfig", () => { owner: "test-owner", repo: "test-repo", branch: "test-branch", + allowedTools: [], }); const parsed = JSON.parse(result); @@ -334,6 +403,7 @@ describe("prepareMcpConfig", () => { owner: "test-owner", repo: "test-repo", branch: "test-branch", + allowedTools: [], }); const parsed = JSON.parse(result); diff --git a/test/mockContext.ts b/test/mockContext.ts index c93acc1..692137c 100644 --- a/test/mockContext.ts +++ b/test/mockContext.ts @@ -11,8 +11,8 @@ const defaultInputs = { triggerPhrase: "/claude", assigneeTrigger: "", anthropicModel: "claude-3-7-sonnet-20250219", - allowedTools: "", - disallowedTools: "", + allowedTools: [] as string[], + disallowedTools: [] as string[], customInstructions: "", directPrompt: "", useBedrock: false, diff --git a/test/permissions.test.ts b/test/permissions.test.ts index 931d873..61e2ca9 100644 --- a/test/permissions.test.ts +++ b/test/permissions.test.ts @@ -62,8 +62,8 @@ describe("checkWritePermissions", () => { inputs: { triggerPhrase: "@claude", assigneeTrigger: "", - allowedTools: "", - disallowedTools: "", + allowedTools: [], + disallowedTools: [], customInstructions: "", directPrompt: "", }, diff --git a/test/prepare-context.test.ts b/test/prepare-context.test.ts index 5be89f0..7811c5b 100644 --- a/test/prepare-context.test.ts +++ b/test/prepare-context.test.ts @@ -242,7 +242,7 @@ describe("parseEnvVarsWithContext", () => { ...mockPullRequestCommentContext, inputs: { ...mockPullRequestCommentContext.inputs, - allowedTools: "Tool1,Tool2", + allowedTools: ["Tool1", "Tool2"], }, }); const result = prepareContext(contextWithAllowedTools, "12345"); diff --git a/test/trigger-validation.test.ts b/test/trigger-validation.test.ts index ae5d6d3..bbe40bd 100644 --- a/test/trigger-validation.test.ts +++ b/test/trigger-validation.test.ts @@ -30,8 +30,8 @@ describe("checkContainsTrigger", () => { triggerPhrase: "/claude", assigneeTrigger: "", directPrompt: "Fix the bug in the login form", - allowedTools: "", - disallowedTools: "", + allowedTools: [], + disallowedTools: [], customInstructions: "", }, }); @@ -56,8 +56,8 @@ describe("checkContainsTrigger", () => { triggerPhrase: "/claude", assigneeTrigger: "", directPrompt: "", - allowedTools: "", - disallowedTools: "", + allowedTools: [], + disallowedTools: [], customInstructions: "", }, }); @@ -228,8 +228,8 @@ describe("checkContainsTrigger", () => { triggerPhrase: "@claude", assigneeTrigger: "", directPrompt: "", - allowedTools: "", - disallowedTools: "", + allowedTools: [], + disallowedTools: [], customInstructions: "", }, }); @@ -255,8 +255,8 @@ describe("checkContainsTrigger", () => { triggerPhrase: "@claude", assigneeTrigger: "", directPrompt: "", - allowedTools: "", - disallowedTools: "", + allowedTools: [], + disallowedTools: [], customInstructions: "", }, }); @@ -282,8 +282,8 @@ describe("checkContainsTrigger", () => { triggerPhrase: "@claude", assigneeTrigger: "", directPrompt: "", - allowedTools: "", - disallowedTools: "", + allowedTools: [], + disallowedTools: [], customInstructions: "", }, });