fix(auth): require Bearer on /mcp regardless of DefaultToken
All checks were successful
CD / Lint / Test / Vet (push) Successful in 7s
CD / Build & Import (push) Successful in 12s
CD / Deploy via GitOps (push) Successful in 4s

Previously BearerMiddleware allowed requests with no Authorization
header to pass through whenever GITEA_MCP_DEFAULT_TOKEN was set. The
intent was "fall back to the service PAT for upstream Gitea calls,"
but the side effect was that anyone could hit /mcp anonymously and the
server would happily proxy requests as the service account.

Drop that path. Auth on /mcp now requires either:
  - a valid Dex-issued JWT, or
  - a Bearer matching GITEA_MCP_STATIC_TOKEN.

The Gitea service PAT (GITEA_MCP_DEFAULT_TOKEN) is no longer wired
into BearerMiddleware at all — it stays an upstream-client concern,
used by gitea.NewClient for outbound API calls only. This decouples
"can this caller invoke a tool" from "what credentials does the tool
use against Gitea".

Tests updated: drop the NoAuthHeader_WithDefault permissive case, add
NoAuthHeader_RejectsEvenWhenStaticConfigured to lock in the new
behavior.

Closes part of mathias/infra#2.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Mathias
2026-05-12 14:44:38 +02:00
parent 9987522f1a
commit 3795800461
3 changed files with 43 additions and 51 deletions

View File

@@ -68,7 +68,7 @@ func main() {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle("/mcp", mcp.OriginAllowlist(cfg.OriginAllowlist)( mux.Handle("/mcp", mcp.OriginAllowlist(cfg.OriginAllowlist)(
auth.BearerMiddleware(jwtValidator, cfg.StaticToken, cfg.DefaultToken, auth.BearerMiddleware(jwtValidator, cfg.StaticToken,
auth.CallerMiddleware(mcpSrv), auth.CallerMiddleware(mcpSrv),
), ),
)) ))

View File

@@ -6,24 +6,23 @@ import (
"strings" "strings"
) )
// BearerMiddleware authenticates requests via one of three paths (in order): // BearerMiddleware authenticates requests via the Authorization header.
// //
// 1. Bearer token is a valid JWT issued by the configured Dex OIDC server. // A request is allowed when:
// 2. Bearer token matches staticToken (constant-time compare).
// 3. No Authorization header and defaultToken is set — allow through; the
// Gitea client will use its service PAT for upstream calls.
// //
// Any other case returns 401. // 1. The Bearer token is a valid JWT issued by the configured Dex OIDC server, or
func BearerMiddleware(jwtValidator *JWTValidator, staticToken, defaultToken string, next http.Handler) http.Handler { // 2. The Bearer token matches staticToken (constant-time compare).
//
// Any other case — including missing or empty Authorization header — returns 401.
//
// The Gitea service PAT is intentionally NOT used to authenticate the caller:
// it is only used by the Gitea client for upstream API calls. Decoupling the
// two prevents the MCP endpoint from being reachable anonymously when a service
// PAT happens to be configured.
func BearerMiddleware(jwtValidator *JWTValidator, staticToken string, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
bearer, hasBearer := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer ") bearer, hasBearer := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer ")
hasBearer = hasBearer && bearer != "" if !hasBearer || bearer == "" {
if !hasBearer {
if defaultToken != "" {
next.ServeHTTP(w, r)
return
}
http.Error(w, "unauthorized", http.StatusUnauthorized) http.Error(w, "unauthorized", http.StatusUnauthorized)
return return
} }

View File

@@ -10,17 +10,17 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// helper: BearerMiddleware with no JWT validator and no static token func okHandler(called *bool) http.Handler {
func noJWTMiddleware(defaultToken string, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
return auth.BearerMiddleware(nil, "", defaultToken, next) if called != nil {
*called = true
}
w.WriteHeader(http.StatusOK)
})
} }
func TestBearerMiddleware_NoAuthHeader_NoDefault(t *testing.T) { func TestBearerMiddleware_NoAuthHeader(t *testing.T) {
srv := httptest.NewServer(noJWTMiddleware("", srv := httptest.NewServer(auth.BearerMiddleware(nil, "", okHandler(nil)))
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}),
))
defer srv.Close() defer srv.Close()
resp, err := http.Post(srv.URL+"/mcp", "application/json", nil) resp, err := http.Post(srv.URL+"/mcp", "application/json", nil)
@@ -29,32 +29,33 @@ func TestBearerMiddleware_NoAuthHeader_NoDefault(t *testing.T) {
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
} }
func TestBearerMiddleware_NoAuthHeader_WithDefault(t *testing.T) { func TestBearerMiddleware_NoAuthHeader_RejectsEvenWhenStaticConfigured(t *testing.T) {
called := false // A configured staticToken must not allow unauthenticated callers through.
srv := httptest.NewServer(noJWTMiddleware("default-pat", srv := httptest.NewServer(auth.BearerMiddleware(nil, "any-static", okHandler(nil)))
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
}),
))
defer srv.Close() defer srv.Close()
resp, err := http.Post(srv.URL+"/mcp", "application/json", nil) resp, err := http.Post(srv.URL+"/mcp", "application/json", nil)
require.NoError(t, err) require.NoError(t, err)
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assert.True(t, called) }
func TestBearerMiddleware_EmptyBearer(t *testing.T) {
srv := httptest.NewServer(auth.BearerMiddleware(nil, "static", okHandler(nil)))
defer srv.Close()
req, _ := http.NewRequest(http.MethodPost, srv.URL+"/mcp", nil)
req.Header.Set("Authorization", "Bearer ")
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer func() { _ = resp.Body.Close() }()
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
} }
func TestBearerMiddleware_StaticToken_Valid(t *testing.T) { func TestBearerMiddleware_StaticToken_Valid(t *testing.T) {
const staticToken = "my-static-token" const staticToken = "my-static-token"
called := false called := false
srv := httptest.NewServer(auth.BearerMiddleware(nil, staticToken, "", srv := httptest.NewServer(auth.BearerMiddleware(nil, staticToken, okHandler(&called)))
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
}),
))
defer srv.Close() defer srv.Close()
req, _ := http.NewRequest(http.MethodPost, srv.URL+"/mcp", nil) req, _ := http.NewRequest(http.MethodPost, srv.URL+"/mcp", nil)
@@ -67,11 +68,7 @@ func TestBearerMiddleware_StaticToken_Valid(t *testing.T) {
} }
func TestBearerMiddleware_StaticToken_Invalid(t *testing.T) { func TestBearerMiddleware_StaticToken_Invalid(t *testing.T) {
srv := httptest.NewServer(auth.BearerMiddleware(nil, "correct-token", "", srv := httptest.NewServer(auth.BearerMiddleware(nil, "correct-token", okHandler(nil)))
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}),
))
defer srv.Close() defer srv.Close()
req, _ := http.NewRequest(http.MethodPost, srv.URL+"/mcp", nil) req, _ := http.NewRequest(http.MethodPost, srv.URL+"/mcp", nil)
@@ -82,12 +79,8 @@ func TestBearerMiddleware_StaticToken_Invalid(t *testing.T) {
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
} }
func TestBearerMiddleware_UnknownBearer_NoJWT(t *testing.T) { func TestBearerMiddleware_UnknownBearer_NoStatic_NoJWT(t *testing.T) {
srv := httptest.NewServer(noJWTMiddleware("", srv := httptest.NewServer(auth.BearerMiddleware(nil, "", okHandler(nil)))
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}),
))
defer srv.Close() defer srv.Close()
req, _ := http.NewRequest(http.MethodPost, srv.URL+"/mcp", nil) req, _ := http.NewRequest(http.MethodPost, srv.URL+"/mcp", nil)