Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions internal/proxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (h *proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
graphQLBody, err = io.ReadAll(r.Body)
r.Body.Close()
if err != nil {
http.Error(w, "failed to read request body", http.StatusBadRequest)
httputil.WriteErrorResponse(w, http.StatusBadRequest, "bad_request", "failed to read request body")
return
}
Comment on lines +83 to 85

Expand Down Expand Up @@ -116,7 +116,7 @@ func (h *proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if match == nil {
// Unknown REST endpoint — fail closed: deny rather than risk leaking unfiltered data
logHandler.Printf("unknown REST endpoint %s, blocking request", rawPath)
http.Error(w, "access denied: unrecognized endpoint", http.StatusForbidden)
httputil.WriteErrorResponse(w, http.StatusForbidden, "forbidden", "access denied: unrecognized endpoint")
return
}
toolName = match.ToolName
Expand Down Expand Up @@ -152,7 +152,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa
errMsg := "returning 503: proxy enforcement not configured (no --policy flag provided)"
logHandler.Print(errMsg)
logger.LogError("proxy", "%s", errMsg)
http.Error(w, "proxy enforcement not configured", http.StatusServiceUnavailable)
httputil.WriteErrorResponse(w, http.StatusServiceUnavailable, "service_unavailable", "proxy enforcement not configured")
return
}

Expand All @@ -166,7 +166,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa
if err != nil {
logHandler.Printf("[DIFC] Phase 1 failed: %v", err)
// On labeling failure, fail closed to prevent enforcement bypass
http.Error(w, "resource labeling failed", http.StatusBadGateway)
httputil.WriteErrorResponse(w, http.StatusBadGateway, "bad_gateway", "resource labeling failed")
return
}

Expand Down Expand Up @@ -332,7 +332,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa
} else {
filteredJSON, err := json.Marshal(finalData)
if err != nil {
http.Error(w, "failed to serialize filtered response", http.StatusInternalServerError)
httputil.WriteErrorResponse(w, http.StatusInternalServerError, "internal_error", "failed to serialize filtered response")
return
}
copyResponseHeaders(w, resp)
Expand Down Expand Up @@ -401,13 +401,13 @@ func (h *proxyHandler) forwardAndReadBody(
) (*http.Response, []byte) {
resp, err := h.server.forwardToGitHub(ctx, method, path, body, contentType, clientAuth)
if err != nil {
http.Error(w, "upstream request failed", http.StatusBadGateway)
httputil.WriteErrorResponse(w, http.StatusBadGateway, "bad_gateway", "upstream request failed")
return nil, nil
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
http.Error(w, "failed to read upstream response", http.StatusBadGateway)
httputil.WriteErrorResponse(w, http.StatusBadGateway, "bad_gateway", "failed to read upstream response")
return nil, nil
}
return resp, respBody
Expand Down
21 changes: 18 additions & 3 deletions internal/proxy/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,22 @@ func mockUpstream(t *testing.T, status int, body interface{}) *httptest.Server {
}))
}

func assertJSONErrorResponse(t *testing.T, resp *http.Response, wantStatus int, wantCode, wantMessage string) {
t.Helper()

require.NotNil(t, resp)
assert.Equal(t, wantStatus, resp.StatusCode)
assert.Equal(t, "application/json", resp.Header.Get("Content-Type"))

var got struct {
Error string `json:"error"`
Message string `json:"message"`
}
require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
assert.Equal(t, wantCode, got.Error)
assert.Equal(t, wantMessage, got.Message)
}

// ─── ServeHTTP: health check ─────────────────────────────────────────────────

func TestServeHTTP_HealthCheck(t *testing.T) {
Expand Down Expand Up @@ -112,8 +128,7 @@ func TestServeHTTP_UnknownRESTEndpointBlocked(t *testing.T) {
w := httptest.NewRecorder()
h.ServeHTTP(w, req)

assert.Equal(t, http.StatusForbidden, w.Code)
assert.Contains(t, w.Body.String(), "access denied")
assertJSONErrorResponse(t, w.Result(), http.StatusForbidden, "forbidden", "access denied: unrecognized endpoint")
}

// ─── ServeHTTP: /api/v3 GH-host prefix is stripped ───────────────────────────
Expand Down Expand Up @@ -456,7 +471,7 @@ func TestForwardAndReadBody_NetworkError(t *testing.T) {

assert.Nil(t, resp)
assert.Nil(t, body)
assert.Equal(t, http.StatusBadGateway, w.Code)
assertJSONErrorResponse(t, w.Result(), http.StatusBadGateway, "bad_gateway", "upstream request failed")
}

// ─── ServeHTTP: search query param is passed to args ─────────────────────────
Expand Down
Loading