From ba5068648ba1452c93575c7ce4d92413ff03b3b3 Mon Sep 17 00:00:00 2001 From: Mathias Bergqvist Date: Mon, 4 May 2026 20:58:08 +0200 Subject: [PATCH] refactor(mcp): compose origin allowlist as middleware, remove duplication Co-Authored-By: Claude Sonnet 4.6 --- cmd/gitea-mcp/main.go | 7 +++---- internal/mcp/server.go | 23 ++--------------------- internal/mcp/server_test.go | 21 ++++++++++++++++++--- 3 files changed, 23 insertions(+), 28 deletions(-) diff --git a/cmd/gitea-mcp/main.go b/cmd/gitea-mcp/main.go index 1084555..774f28e 100644 --- a/cmd/gitea-mcp/main.go +++ b/cmd/gitea-mcp/main.go @@ -23,13 +23,12 @@ func main() { // Tool registration happens in Phase 6+; for now, registry is empty. mcpSrv := mcp.NewServer(mcp.ServerOptions{ - Registry: reg, - OriginAllowlist: cfg.OriginAllowlist, - Sessions: mcp.NewSessionStore(), + Registry: reg, + Sessions: mcp.NewSessionStore(), }) mux := http.NewServeMux() - mux.Handle("/mcp", mcpSrv) + mux.Handle("/mcp", mcp.OriginAllowlist(cfg.OriginAllowlist)(mcpSrv)) mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("ok")) diff --git a/internal/mcp/server.go b/internal/mcp/server.go index e81c97e..3a37c26 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -11,9 +11,8 @@ import ( const ProtocolVersion = "2025-06-18" type ServerOptions struct { - Registry *registry.Registry - OriginAllowlist []string - Sessions *SessionStore + Registry *registry.Registry + Sessions *SessionStore } type Server struct { @@ -28,24 +27,6 @@ func NewServer(opts ServerOptions) *Server { } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Origin allowlist (no-op when allowlist empty or Origin missing) - if len(s.opts.OriginAllowlist) > 0 { - origin := r.Header.Get("Origin") - if origin != "" { - ok := false - for _, a := range s.opts.OriginAllowlist { - if a == origin { - ok = true - break - } - } - if !ok { - http.Error(w, "origin not allowed", http.StatusForbidden) - return - } - } - } - switch r.Method { case http.MethodGet: s.handleGET(w, r) diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index 78e4a8e..bdce59e 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -17,9 +17,8 @@ func newServer(t *testing.T) *mcp.Server { t.Helper() reg := registry.New() return mcp.NewServer(mcp.ServerOptions{ - Registry: reg, - OriginAllowlist: nil, - Sessions: mcp.NewSessionStore(), + Registry: reg, + Sessions: mcp.NewSessionStore(), }) } @@ -68,6 +67,22 @@ func TestPostWithoutSessionRejected(t *testing.T) { require.Equal(t, http.StatusBadRequest, rr.Code) } +func TestServerWithOriginAllowlistRejectsBadOrigin(t *testing.T) { + srv := mcp.OriginAllowlist([]string{"https://claude.ai"})(newServer(t)) + body, _ := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]any{"protocolVersion": "2025-06-18"}, + }) + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Origin", "https://evil.example") + rr := httptest.NewRecorder() + srv.ServeHTTP(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) +} + func TestToolsListAfterInitialize(t *testing.T) { srv := newServer(t) init := postJSON(t, srv, map[string]any{