feat(auth): JWT-or-static middleware + /.well-known/oauth-protected-resource (issue #5)
- internal/auth/jwt.go: JWTValidator via lestrrat-go/jwx/v2, JWKS auto-refresh - internal/auth/bearer.go: replace Gitea PAT validation with JWT->static->default chain - internal/gitea/client.go: always use service PAT; remove TokenFromContext lookup - internal/config/config.go: add DexIssuerURL, MCPAudience, MCPResourceURL, StaticToken - cmd/gitea-mcp/main.go: wire validator, fix /.well-known to return real AS list - bearer_test.go: rewrite for new API
This commit is contained in:
@@ -1,55 +1,43 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type tokenKey struct{}
|
||||
|
||||
// BearerMiddleware validates the incoming bearer token as a Gitea PAT by
|
||||
// calling GET /api/v1/user. The validated token is stored in context for
|
||||
// downstream use by the Gitea client.
|
||||
// BearerMiddleware authenticates requests via one of three paths (in order):
|
||||
//
|
||||
// defaultToken, if non-empty, is used when no Authorization header is present
|
||||
// (e.g. claude.ai connectors which do not inject Bearer tokens).
|
||||
func BearerMiddleware(giteaBaseURL, defaultToken string, next http.Handler) http.Handler {
|
||||
hc := &http.Client{Timeout: 5 * time.Second}
|
||||
// 1. Bearer token is a valid JWT issued by the configured Dex OIDC server.
|
||||
// 2. Bearer token matches staticToken (constant-time compare).
|
||||
// 3. No Authorization header and defaultToken is set — allow through; the
|
||||
// Gitea client will use its service PAT for upstream calls.
|
||||
//
|
||||
// Any other case returns 401.
|
||||
func BearerMiddleware(jwtValidator *JWTValidator, staticToken, defaultToken string, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token, ok := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer ")
|
||||
if !ok || token == "" {
|
||||
if defaultToken == "" {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
bearer, hasBearer := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer ")
|
||||
hasBearer = hasBearer && bearer != ""
|
||||
|
||||
if !hasBearer {
|
||||
if defaultToken != "" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
token = defaultToken
|
||||
}
|
||||
req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, giteaBaseURL+"/api/v1/user", nil)
|
||||
if err != nil {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
req.Header.Set("Authorization", "token "+token)
|
||||
resp, err := hc.Do(req)
|
||||
if err != nil || resp.StatusCode != http.StatusOK {
|
||||
if resp != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
|
||||
if jwtValidator.Validate(r.Context(), bearer) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
ctx := context.WithValue(r.Context(), tokenKey{}, token)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
|
||||
if staticToken != "" && subtle.ConstantTimeCompare([]byte(bearer), []byte(staticToken)) == 1 {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
|
||||
// TokenFromContext returns the validated Gitea PAT stored by BearerMiddleware.
|
||||
func TokenFromContext(ctx context.Context) string {
|
||||
if v, ok := ctx.Value(tokenKey{}).(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -10,8 +10,13 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBearerMiddleware_NoAuthHeader(t *testing.T) {
|
||||
srv := httptest.NewServer(auth.BearerMiddleware("https://gitea.example.com", "",
|
||||
// helper: BearerMiddleware with no JWT validator and no static token
|
||||
func noJWTMiddleware(defaultToken string, next http.Handler) http.Handler {
|
||||
return auth.BearerMiddleware(nil, "", defaultToken, next)
|
||||
}
|
||||
|
||||
func TestBearerMiddleware_NoAuthHeader_NoDefault(t *testing.T) {
|
||||
srv := httptest.NewServer(noJWTMiddleware("",
|
||||
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
@@ -24,20 +29,11 @@ func TestBearerMiddleware_NoAuthHeader(t *testing.T) {
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestBearerMiddleware_NoAuthHeaderWithDefault(t *testing.T) {
|
||||
const defaultToken = "default-pat"
|
||||
|
||||
giteaMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "token "+defaultToken, r.Header.Get("Authorization"))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer giteaMock.Close()
|
||||
|
||||
func TestBearerMiddleware_NoAuthHeader_WithDefault(t *testing.T) {
|
||||
called := false
|
||||
srv := httptest.NewServer(auth.BearerMiddleware(giteaMock.URL, defaultToken,
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
srv := httptest.NewServer(noJWTMiddleware("default-pat",
|
||||
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
assert.Equal(t, defaultToken, auth.TokenFromContext(r.Context()))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
))
|
||||
@@ -50,51 +46,19 @@ func TestBearerMiddleware_NoAuthHeaderWithDefault(t *testing.T) {
|
||||
assert.True(t, called)
|
||||
}
|
||||
|
||||
func TestBearerMiddleware_InvalidToken(t *testing.T) {
|
||||
// Mock Gitea that rejects the token
|
||||
giteaMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}))
|
||||
defer giteaMock.Close()
|
||||
|
||||
srv := httptest.NewServer(auth.BearerMiddleware(giteaMock.URL, "",
|
||||
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
))
|
||||
defer srv.Close()
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, srv.URL+"/mcp", nil)
|
||||
req.Header.Set("Authorization", "Bearer bad-token")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestBearerMiddleware_ValidToken(t *testing.T) {
|
||||
const token = "valid-pat"
|
||||
|
||||
// Mock Gitea that accepts the token and returns a user
|
||||
giteaMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "token "+token, r.Header.Get("Authorization"))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer giteaMock.Close()
|
||||
|
||||
func TestBearerMiddleware_StaticToken_Valid(t *testing.T) {
|
||||
const staticToken = "my-static-token"
|
||||
called := false
|
||||
srv := httptest.NewServer(auth.BearerMiddleware(giteaMock.URL, "",
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
srv := httptest.NewServer(auth.BearerMiddleware(nil, staticToken, "",
|
||||
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
// Token must be available in context for downstream Gitea client
|
||||
assert.Equal(t, token, auth.TokenFromContext(r.Context()))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
))
|
||||
defer srv.Close()
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, srv.URL+"/mcp", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("Authorization", "Bearer "+staticToken)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
@@ -102,7 +66,34 @@ func TestBearerMiddleware_ValidToken(t *testing.T) {
|
||||
assert.True(t, called)
|
||||
}
|
||||
|
||||
func TestTokenFromContext_Empty(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
assert.Equal(t, "", auth.TokenFromContext(req.Context()))
|
||||
func TestBearerMiddleware_StaticToken_Invalid(t *testing.T) {
|
||||
srv := httptest.NewServer(auth.BearerMiddleware(nil, "correct-token", "",
|
||||
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
))
|
||||
defer srv.Close()
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, srv.URL+"/mcp", nil)
|
||||
req.Header.Set("Authorization", "Bearer wrong-token")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestBearerMiddleware_UnknownBearer_NoJWT(t *testing.T) {
|
||||
srv := httptest.NewServer(noJWTMiddleware("",
|
||||
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
))
|
||||
defer srv.Close()
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, srv.URL+"/mcp", nil)
|
||||
req.Header.Set("Authorization", "Bearer random-unknown-token")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
|
||||
79
internal/auth/jwt.go
Normal file
79
internal/auth/jwt.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
)
|
||||
|
||||
// JWTValidator validates bearer tokens as JWTs issued by a Dex OIDC server.
|
||||
// A nil JWTValidator always returns false — JWT validation is disabled.
|
||||
type JWTValidator struct {
|
||||
issuer string
|
||||
aud string
|
||||
cache *jwk.Cache
|
||||
jwksURI string
|
||||
}
|
||||
|
||||
// NewJWTValidator creates a validator by fetching the OIDC discovery document
|
||||
// from issuerURL. Returns nil, nil when issuerURL is empty (disabled).
|
||||
func NewJWTValidator(ctx context.Context, issuerURL, audience string) (*JWTValidator, error) {
|
||||
if issuerURL == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
resp, err := http.Get(issuerURL + "/.well-known/openid-configuration")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch oidc discovery: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
var doc struct {
|
||||
JWKSURI string `json:"jwks_uri"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil {
|
||||
return nil, fmt.Errorf("decode oidc discovery: %w", err)
|
||||
}
|
||||
|
||||
cache := jwk.NewCache(ctx)
|
||||
if err := cache.Register(doc.JWKSURI, jwk.WithRefreshInterval(time.Hour)); err != nil {
|
||||
return nil, fmt.Errorf("register jwks uri: %w", err)
|
||||
}
|
||||
// warm the cache immediately so first request doesn't block
|
||||
if _, err := cache.Refresh(ctx, doc.JWKSURI); err != nil {
|
||||
return nil, fmt.Errorf("warm jwks cache: %w", err)
|
||||
}
|
||||
|
||||
return &JWTValidator{
|
||||
issuer: issuerURL,
|
||||
aud: audience,
|
||||
cache: cache,
|
||||
jwksURI: doc.JWKSURI,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Validate returns true if rawToken is a valid JWT signed by the OIDC server.
|
||||
func (v *JWTValidator) Validate(ctx context.Context, rawToken string) bool {
|
||||
if v == nil {
|
||||
return false
|
||||
}
|
||||
keySet, err := v.cache.Get(ctx, v.jwksURI)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
opts := []jwt.ParseOption{
|
||||
jwt.WithKeySet(keySet),
|
||||
jwt.WithIssuer(v.issuer),
|
||||
jwt.WithValidate(true),
|
||||
}
|
||||
if v.aud != "" {
|
||||
opts = append(opts, jwt.WithAudience(v.aud))
|
||||
}
|
||||
_, err = jwt.Parse([]byte(rawToken), opts...)
|
||||
return err == nil
|
||||
}
|
||||
Reference in New Issue
Block a user