Skip to content

Commit ea3c8ab

Browse files
authored
Refactor WASM guard detection placement and inline config validation log wrappers (#5774)
Semantic clustering surfaced two low-risk refactors: WASM guard auto-detection lived in `cmd/proxy.go` instead of the WASM cache module, and `config/validation.go` used one-line private logging wrappers that obscured validation flow. This PR tightens cohesion in `internal/cmd` and simplifies config validation control flow without behavior changes. - **WASM setup cohesion (`internal/cmd`)** - Moved `containerGuardWasmPath` and `detectGuardWasm()` from `internal/cmd/proxy.go` to `internal/cmd/wasm_cache.go`. - Kept `newProxyCmd()` behavior unchanged (`defaultGuard := detectGuardWasm()`), but colocated guard-path detection with other WASM cache/bootstrap logic. - **Validation flow simplification (`internal/config`)** - Removed private wrappers: - `logValidateServerStart` - `logValidateServerPassed` - `logValidateServerFailed` - Inlined equivalent `logValidation.Printf(...)` calls at usage sites in server/auth/containerization validation paths to make control flow directly readable. - **Illustrative change** ```go // before logValidateServerFailed(name, server.Type, "HTTP server missing url field") // after logValidation.Printf("Validation failed: %s, name=%s, type=%s", "HTTP server missing url field", name, server.Type) ```
2 parents c9fd415 + 7b4ed07 commit ea3c8ab

3 files changed

Lines changed: 27 additions & 42 deletions

File tree

internal/cmd/proxy.go

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -55,21 +55,6 @@ func init() {
5555
rootCmd.AddCommand(newProxyCmd())
5656
}
5757

58-
// containerGuardWasmPath is the baked-in guard path in the container image.
59-
const containerGuardWasmPath = "/guards/github/00-github-guard.wasm"
60-
61-
// detectGuardWasm returns the baked-in container guard path if it exists,
62-
// or empty string if not found (requiring the user to specify --guard-wasm).
63-
func detectGuardWasm() string {
64-
logProxyCmd.Printf("Checking for baked-in guard at %s", containerGuardWasmPath)
65-
if _, err := os.Stat(containerGuardWasmPath); err == nil {
66-
logProxyCmd.Printf("Auto-detected baked-in guard: %s", containerGuardWasmPath)
67-
return containerGuardWasmPath
68-
}
69-
logProxyCmd.Print("Baked-in guard not found, --guard-wasm flag required")
70-
return ""
71-
}
72-
7358
func newProxyCmd() *cobra.Command {
7459
defaultGuard := detectGuardWasm()
7560
defaultProxyLogDir := envutil.GetEnvString("MCP_GATEWAY_LOG_DIR", config.DefaultLogDir)

internal/cmd/wasm_cache.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,21 @@ import (
1212
"github.com/github/gh-aw-mcpg/internal/guard"
1313
)
1414

15+
// containerGuardWasmPath is the baked-in guard path in the container image.
16+
const containerGuardWasmPath = "/guards/github/00-github-guard.wasm"
17+
18+
// detectGuardWasm returns the baked-in container guard path if it exists,
19+
// or empty string if not found (requiring the user to specify --guard-wasm).
20+
func detectGuardWasm() string {
21+
debugLog.Printf("Checking for baked-in guard at %s", containerGuardWasmPath)
22+
if _, err := os.Stat(containerGuardWasmPath); err == nil {
23+
debugLog.Printf("Auto-detected baked-in guard: %s", containerGuardWasmPath)
24+
return containerGuardWasmPath
25+
}
26+
debugLog.Print("Baked-in guard not found, --guard-wasm flag required")
27+
return ""
28+
}
29+
1530
func defaultWasmCacheDir(logDir string) string {
1631
if logDir == "" {
1732
return config.DefaultWasmCacheDirName

internal/config/validation.go

Lines changed: 12 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,6 @@ type ValidationError = rules.ValidationError
1818

1919
var logValidation = logger.New("config:validation")
2020

21-
// logValidateServerStart logs the beginning of server config validation.
22-
func logValidateServerStart(name, serverType string) {
23-
logValidation.Printf("Validating server config: name=%s, type=%s", name, serverType)
24-
}
25-
26-
// logValidateServerPassed logs a successful server config validation.
27-
func logValidateServerPassed(name, serverType string) {
28-
logValidation.Printf("Server config validation passed: name=%s, type=%s", name, serverType)
29-
}
30-
31-
// logValidateServerFailed logs a failed server config validation with the given reason.
32-
func logValidateServerFailed(name, serverType, reason string) {
33-
logValidation.Printf("Validation failed: %s, name=%s, type=%s", reason, name, serverType)
34-
}
35-
3621
// validateMounts validates mount specifications using centralized rules
3722
func validateMounts(mounts []string, jsonPath string) error {
3823
for i, mount := range mounts {
@@ -45,7 +30,7 @@ func validateMounts(mounts []string, jsonPath string) error {
4530

4631
// validateServerConfigWithCustomSchemas validates a server configuration with custom schema support
4732
func validateServerConfigWithCustomSchemas(name string, server *StdinServerConfig, customSchemas map[string]interface{}) error {
48-
logValidateServerStart(name, server.Type)
33+
logValidation.Printf("Validating server config: name=%s, type=%s", name, server.Type)
4934
jsonPath := fmt.Sprintf("mcpServers.%s", name)
5035

5136
// Validate type (empty defaults to stdio)
@@ -74,7 +59,7 @@ func validateStandardServerConfig(name string, server *StdinServerConfig, jsonPa
7459
// For stdio servers, container is required
7560
if server.Type == "stdio" || server.Type == "local" {
7661
if server.Container == "" {
77-
logValidateServerFailed(name, server.Type, "stdio server missing container field")
62+
logValidation.Printf("Validation failed: %s, name=%s, type=%s", "stdio server missing container field", name, server.Type)
7863
return rules.MissingRequired("container", "stdio", jsonPath, "Add a 'container' field (e.g., \"ghcr.io/owner/image:tag\")")
7964
}
8065

@@ -95,11 +80,11 @@ func validateStandardServerConfig(name string, server *StdinServerConfig, jsonPa
9580
// For HTTP servers, url is required and mounts are not allowed
9681
if server.Type == "http" {
9782
if server.URL == "" {
98-
logValidateServerFailed(name, server.Type, "HTTP server missing url field")
83+
logValidation.Printf("Validation failed: %s, name=%s, type=%s", "HTTP server missing url field", name, server.Type)
9984
return rules.MissingRequired("url", "HTTP", jsonPath, "Add a 'url' field (e.g., \"https://example.com/mcp\")")
10085
}
10186
if len(server.Mounts) > 0 {
102-
logValidateServerFailed(name, server.Type, "HTTP server has mounts field")
87+
logValidation.Printf("Validation failed: %s, name=%s, type=%s", "HTTP server has mounts field", name, server.Type)
10388
return rules.UnsupportedField("mounts", "mounts are only supported for stdio (containerized) servers", jsonPath, "Remove the 'mounts' field from HTTP server configuration; mounts only apply to stdio servers")
10489
}
10590

@@ -114,17 +99,17 @@ func validateStandardServerConfig(name string, server *StdinServerConfig, jsonPa
11499
if server.ToolTimeout != nil && *server.ToolTimeout != 0 {
115100
toolTimeoutField := server.toolTimeoutField()
116101
if err := rules.TimeoutMinimum(*server.ToolTimeout, ToolTimeoutMin, toolTimeoutField, jsonPath+"."+toolTimeoutField); err != nil {
117-
logValidateServerFailed(name, server.Type, fmt.Sprintf("%s %d is below minimum %d", toolTimeoutField, *server.ToolTimeout, ToolTimeoutMin))
102+
logValidation.Printf("Validation failed: %s, name=%s, type=%s", fmt.Sprintf("%s %d is below minimum %d", toolTimeoutField, *server.ToolTimeout, ToolTimeoutMin), name, server.Type)
118103
return err
119104
}
120105
}
121106

122107
if err := validateToolResponseFilters(server.ToolResponseFilters, jsonPath+".tool_response_filters"); err != nil {
123-
logValidateServerFailed(name, server.Type, fmt.Sprintf("tool_response_filters invalid: %v", err))
108+
logValidation.Printf("Validation failed: %s, name=%s, type=%s", fmt.Sprintf("tool_response_filters invalid: %v", err), name, server.Type)
124109
return err
125110
}
126111

127-
logValidateServerPassed(name, server.Type)
112+
logValidation.Printf("Server config validation passed: name=%s, type=%s", name, server.Type)
128113
return nil
129114
}
130115

@@ -163,7 +148,7 @@ func validateServerAuth(auth *AuthConfig, serverType, name, jsonPath string) err
163148
return nil
164149
}
165150
if serverType != "http" {
166-
logValidateServerFailed(name, serverType, fmt.Sprintf("auth is set on non-HTTP server type: %s", serverType))
151+
logValidation.Printf("Validation failed: %s, name=%s, type=%s", fmt.Sprintf("auth is set on non-HTTP server type: %s", serverType), name, serverType)
167152
return rules.UnsupportedField(
168153
"auth",
169154
fmt.Sprintf("server type %q", serverType),
@@ -179,20 +164,20 @@ func validateAuthConfig(auth *AuthConfig, serverName, jsonPath string) error {
179164
logValidation.Printf("Validating auth config: server=%s, type=%s", serverName, auth.Type)
180165

181166
if auth.Type == "" {
182-
logValidateServerFailed(serverName, "http", "auth.type is empty")
167+
logValidation.Printf("Validation failed: %s, name=%s, type=%s", "auth.type is empty", serverName, "http")
183168
return rules.MissingRequired("type", "auth", authPath, "Specify the authentication type (currently only \"github-oidc\" is supported)")
184169
}
185170

186171
if auth.Type != "github-oidc" {
187-
logValidateServerFailed(serverName, "http", fmt.Sprintf("unsupported auth.type: %s", auth.Type))
172+
logValidation.Printf("Validation failed: %s, name=%s, type=%s", fmt.Sprintf("unsupported auth.type: %s", auth.Type), serverName, "http")
188173
return rules.UnsupportedType("type", auth.Type, authPath, fmt.Sprintf("Unsupported auth type %q. Currently only \"github-oidc\" is supported", auth.Type))
189174
}
190175

191176
// Fail-fast: check that required OIDC environment variables are present.
192177
// This catches misconfigurations at config-load time rather than deferring
193178
// the error to the first request against this server.
194179
if os.Getenv("ACTIONS_ID_TOKEN_REQUEST_URL") == "" {
195-
logValidateServerFailed(serverName, "http", "ACTIONS_ID_TOKEN_REQUEST_URL is not set")
180+
logValidation.Printf("Validation failed: %s, name=%s, type=%s", "ACTIONS_ID_TOKEN_REQUEST_URL is not set", serverName, "http")
196181
return rules.MissingRequired(
197182
"ACTIONS_ID_TOKEN_REQUEST_URL", "github-oidc", authPath,
198183
oidc.ErrMissingOIDCEnvVar(serverName).Error())
@@ -461,7 +446,7 @@ func validateTOMLStdioContainerization(servers map[string]*ServerConfig) error {
461446

462447
// Check if command is Docker
463448
if cfg.Command != "docker" {
464-
logValidateServerFailed(name, "stdio", fmt.Sprintf("stdio server using non-Docker command, command=%s", cfg.Command))
449+
logValidation.Printf("Validation failed: %s, name=%s, type=%s", fmt.Sprintf("stdio server using non-Docker command, command=%s", cfg.Command), name, "stdio")
465450
return fmt.Errorf(
466451
"server '%s': stdio servers must use containerized execution (command must be 'docker', got '%s'). "+
467452
"This is required by MCP Gateway Specification Section 3.2.1 (Containerization Requirement). "+

0 commit comments

Comments
 (0)