From c7e019248625f457c60bc129d72cf97c8c6fc97e Mon Sep 17 00:00:00 2001 From: Mathias Bergqvist Date: Mon, 11 May 2026 20:10:05 +0200 Subject: [PATCH] feat(auth): add Dex JWT middleware to supervisor, routing pod, and brain MCP MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #6 on gitea.d-ma.be/mathias/hyperguild. Dex is deployed at auth.d-ma.be. All three MCP servers now accept JWTs issued by Dex in addition to static bearer tokens, enabling claude.ai OAuth 2.0 integration without abandoning backward-compat CLI auth. Changes: - internal/auth/: new Validator (JWKS auto-refresh via lestrrat-go/jwx/v2), ProtectedResourceHandler (RFC 9728 /.well-known/oauth-protected-resource) - internal/mcp/Server: adds optional *auth.Validator; checkAuth tries JWT first, then static token fallback; both-nil = auth disabled (unchanged default) - cmd/supervisor, cmd/routing: construct Validator from DEX_ISSUER_URL + MCP_AUDIENCE env vars; register protected-resource handler when set - ingestion/internal/auth/: same Validator + handler (separate module) - ingestion/internal/mcp/BearerAuth: same JWT-or-static chain - ingestion/cmd/server: same wiring pattern New env vars (all optional; absent = static-token-only, same as before): DEX_ISSUER_URL — Dex issuer URL (e.g. https://auth.d-ma.be) MCP_AUDIENCE — expected aud claim (e.g. brain, supervisor) MCP_RESOURCE_URL — resource identifier for RFC 9728 metadata response Co-Authored-By: Claude Sonnet 4.6 --- cmd/routing/main.go | 21 ++- cmd/supervisor/main.go | 21 ++- go.mod | 17 +- go.sum | 27 +++ ingestion/cmd/server/main.go | 21 ++- ingestion/go.mod | 15 +- ingestion/go.sum | 28 +++ ingestion/internal/auth/jwt.go | 84 +++++++++ ingestion/internal/auth/jwt_test.go | 169 ++++++++++++++++++ ingestion/internal/auth/protected_resource.go | 23 +++ .../internal/auth/protected_resource_test.go | 28 +++ ingestion/internal/mcp/auth.go | 29 ++- ingestion/internal/mcp/auth_test.go | 127 +++++++++++-- internal/auth/jwt.go | 84 +++++++++ internal/auth/jwt_test.go | 169 ++++++++++++++++++ internal/auth/protected_resource.go | 23 +++ internal/auth/protected_resource_test.go | 28 +++ internal/mcp/server.go | 61 ++++--- internal/mcp/server_test.go | 12 +- 19 files changed, 934 insertions(+), 53 deletions(-) create mode 100644 ingestion/internal/auth/jwt.go create mode 100644 ingestion/internal/auth/jwt_test.go create mode 100644 ingestion/internal/auth/protected_resource.go create mode 100644 ingestion/internal/auth/protected_resource_test.go create mode 100644 internal/auth/jwt.go create mode 100644 internal/auth/jwt_test.go create mode 100644 internal/auth/protected_resource.go create mode 100644 internal/auth/protected_resource_test.go diff --git a/cmd/routing/main.go b/cmd/routing/main.go index 1c04ffa..79eea26 100644 --- a/cmd/routing/main.go +++ b/cmd/routing/main.go @@ -14,6 +14,7 @@ import ( "os" "time" + "github.com/mathiasbq/supervisor/internal/auth" "github.com/mathiasbq/supervisor/internal/config" iexec "github.com/mathiasbq/supervisor/internal/exec" "github.com/mathiasbq/supervisor/internal/mcp" @@ -98,13 +99,31 @@ func main() { CompleteFunc: trainer.CompleteFunc(wrap("trainer")), })) - srv := mcp.NewServer(reg, cfg.MCPAuthToken) + var validator *auth.Validator + if dexURL := os.Getenv("DEX_ISSUER_URL"); dexURL != "" { + audience := os.Getenv("MCP_AUDIENCE") + v, err := auth.NewValidator(dexURL, audience) + if err != nil { + logger.Error("build jwt validator", "err", err) + os.Exit(1) + } + validator = v + logger.Info("jwt auth enabled", "issuer", dexURL) + } + + srv := mcp.NewServer(reg, cfg.MCPAuthToken, validator) mux := http.NewServeMux() mux.Handle("/mcp", srv) mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) }) + if dexURL := os.Getenv("DEX_ISSUER_URL"); dexURL != "" { + resourceURL := os.Getenv("MCP_RESOURCE_URL") + mux.HandleFunc("GET /.well-known/oauth-protected-resource", + auth.ProtectedResourceHandler(resourceURL, dexURL)) + } + addr := ":" + cfg.Port logger.Info("routing pod starting", "addr", addr, "fast", cfg.FastModel, "thinking", cfg.ThinkingModel, diff --git a/cmd/supervisor/main.go b/cmd/supervisor/main.go index 30a08bb..66232a7 100644 --- a/cmd/supervisor/main.go +++ b/cmd/supervisor/main.go @@ -6,6 +6,7 @@ import ( "net/http" "os" + "github.com/mathiasbq/supervisor/internal/auth" "github.com/mathiasbq/supervisor/internal/config" iexec "github.com/mathiasbq/supervisor/internal/exec" "github.com/mathiasbq/supervisor/internal/mcp" @@ -150,10 +151,28 @@ func main() { BrainDir: cfg.BrainDir, })) - srv := mcp.NewServer(reg, cfg.MCPAuthToken) + var validator *auth.Validator + if dexURL := os.Getenv("DEX_ISSUER_URL"); dexURL != "" { + audience := os.Getenv("MCP_AUDIENCE") + v, err := auth.NewValidator(dexURL, audience) + if err != nil { + logger.Error("build jwt validator", "err", err) + os.Exit(1) + } + validator = v + logger.Info("jwt auth enabled", "issuer", dexURL) + } + + srv := mcp.NewServer(reg, cfg.MCPAuthToken, validator) mux := http.NewServeMux() mux.Handle("/mcp", srv) + if dexURL := os.Getenv("DEX_ISSUER_URL"); dexURL != "" { + resourceURL := os.Getenv("MCP_RESOURCE_URL") + mux.HandleFunc("GET /.well-known/oauth-protected-resource", + auth.ProtectedResourceHandler(resourceURL, dexURL)) + } + addr := ":" + cfg.Port logger.Info("supervisor starting", "addr", addr, "version", "v0.5.0") if err := http.ListenAndServe(addr, mux); err != nil { diff --git a/go.mod b/go.mod index f233471..3f8b6ec 100644 --- a/go.mod +++ b/go.mod @@ -2,10 +2,23 @@ module github.com/mathiasbq/supervisor go 1.26.1 -require github.com/stretchr/testify v1.11.1 +require ( + github.com/lestrrat-go/jwx/v2 v2.1.6 + github.com/stretchr/testify v1.11.1 + gopkg.in/yaml.v3 v3.0.1 +) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect + github.com/goccy/go-json v0.10.3 // indirect + github.com/lestrrat-go/blackmagic v1.0.3 // indirect + github.com/lestrrat-go/httpcc v1.0.1 // indirect + github.com/lestrrat-go/httprc v1.0.6 // indirect + github.com/lestrrat-go/iter v1.0.2 // indirect + github.com/lestrrat-go/option v1.0.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect + github.com/segmentio/asm v1.2.0 // indirect + golang.org/x/crypto v0.32.0 // indirect + golang.org/x/sys v0.31.0 // indirect ) diff --git a/go.sum b/go.sum index c4c1710..d8ec42d 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,37 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvwDRwnI3hwNaAHRnc= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40= +github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= +github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/lestrrat-go/blackmagic v1.0.3 h1:94HXkVLxkZO9vJI/w2u1T0DAoprShFd13xtnSINtDWs= +github.com/lestrrat-go/blackmagic v1.0.3/go.mod h1:6AWFyKNNj0zEXQYfTMPfZrAXUWUfTIZ5ECEUEJaijtw= +github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE= +github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= +github.com/lestrrat-go/httprc v1.0.6 h1:qgmgIRhpvBqexMJjA/PmwSvhNk679oqD1RbovdCGW8k= +github.com/lestrrat-go/httprc v1.0.6/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo= +github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI= +github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4= +github.com/lestrrat-go/jwx/v2 v2.1.6 h1:hxM1gfDILk/l5ylers6BX/Eq1m/pnxe9NBwW6lVfecA= +github.com/lestrrat-go/jwx/v2 v2.1.6/go.mod h1:Y722kU5r/8mV7fYDifjug0r8FK8mZdw0K0GpJw/l8pU= +github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= +github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= +github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= +golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/ingestion/cmd/server/main.go b/ingestion/cmd/server/main.go index b75ffc6..85f4435 100644 --- a/ingestion/cmd/server/main.go +++ b/ingestion/cmd/server/main.go @@ -11,6 +11,7 @@ import ( "time" "github.com/mathiasbq/hyperguild/ingestion/internal/api" + "github.com/mathiasbq/hyperguild/ingestion/internal/auth" "github.com/mathiasbq/hyperguild/ingestion/internal/llm" "github.com/mathiasbq/hyperguild/ingestion/internal/mcp" "github.com/mathiasbq/hyperguild/ingestion/internal/pipeline" @@ -80,7 +81,25 @@ func main() { mux.HandleFunc("POST /ingest-raw", h.IngestRaw) mux.HandleFunc("POST /backfill-refs", h.BackfillRefs) mux.HandleFunc("GET /pass-rate", h.PassRate) - mux.Handle("/mcp", mcp.BearerAuth(mcpToken, mcpSrv)) + var jwtValidator *auth.Validator + if dexURL := os.Getenv("DEX_ISSUER_URL"); dexURL != "" { + audience := os.Getenv("MCP_AUDIENCE") + v, err := auth.NewValidator(dexURL, audience) + if err != nil { + logger.Error("build jwt validator", "err", err) + os.Exit(1) + } + jwtValidator = v + logger.Info("jwt auth enabled", "issuer", dexURL) + } + + mux.Handle("/mcp", mcp.BearerAuth(mcpToken, jwtValidator, mcpSrv)) + + if dexURL := os.Getenv("DEX_ISSUER_URL"); dexURL != "" { + resourceURL := os.Getenv("MCP_RESOURCE_URL") + mux.HandleFunc("GET /.well-known/oauth-protected-resource", + auth.ProtectedResourceHandler(resourceURL, os.Getenv("DEX_ISSUER_URL"))) + } addr := ":" + port watchIntervalLog := "disabled" diff --git a/ingestion/go.mod b/ingestion/go.mod index c13d6a2..3ddb2a1 100644 --- a/ingestion/go.mod +++ b/ingestion/go.mod @@ -2,10 +2,23 @@ module github.com/mathiasbq/hyperguild/ingestion go 1.26.1 -require github.com/stretchr/testify v1.11.1 +require ( + github.com/lestrrat-go/jwx/v2 v2.1.6 + github.com/stretchr/testify v1.11.1 +) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect + github.com/goccy/go-json v0.10.3 // indirect + github.com/lestrrat-go/blackmagic v1.0.3 // indirect + github.com/lestrrat-go/httpcc v1.0.1 // indirect + github.com/lestrrat-go/httprc v1.0.6 // indirect + github.com/lestrrat-go/iter v1.0.2 // indirect + github.com/lestrrat-go/option v1.0.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/segmentio/asm v1.2.0 // indirect + golang.org/x/crypto v0.32.0 // indirect + golang.org/x/sys v0.31.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/ingestion/go.sum b/ingestion/go.sum index cc8b3f4..d8ec42d 100644 --- a/ingestion/go.sum +++ b/ingestion/go.sum @@ -1,9 +1,37 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvwDRwnI3hwNaAHRnc= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40= +github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= +github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/lestrrat-go/blackmagic v1.0.3 h1:94HXkVLxkZO9vJI/w2u1T0DAoprShFd13xtnSINtDWs= +github.com/lestrrat-go/blackmagic v1.0.3/go.mod h1:6AWFyKNNj0zEXQYfTMPfZrAXUWUfTIZ5ECEUEJaijtw= +github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE= +github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= +github.com/lestrrat-go/httprc v1.0.6 h1:qgmgIRhpvBqexMJjA/PmwSvhNk679oqD1RbovdCGW8k= +github.com/lestrrat-go/httprc v1.0.6/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo= +github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI= +github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4= +github.com/lestrrat-go/jwx/v2 v2.1.6 h1:hxM1gfDILk/l5ylers6BX/Eq1m/pnxe9NBwW6lVfecA= +github.com/lestrrat-go/jwx/v2 v2.1.6/go.mod h1:Y722kU5r/8mV7fYDifjug0r8FK8mZdw0K0GpJw/l8pU= +github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= +github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= +github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= +golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/ingestion/internal/auth/jwt.go b/ingestion/internal/auth/jwt.go new file mode 100644 index 0000000..36af6ed --- /dev/null +++ b/ingestion/internal/auth/jwt.go @@ -0,0 +1,84 @@ +package auth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jwt" +) + +// Validator validates Bearer JWTs issued by a Dex (OIDC) authorization server. +// Audience is optional; leave empty to skip audience validation. +type Validator struct { + issuer string + audience string + jwksURI string + cache *jwk.Cache +} + +// NewValidator fetches the OIDC discovery document from issuerURL, extracts +// jwks_uri, seeds the JWKS cache, and returns a ready Validator. +// If DEX_ISSUER_URL is not set the caller should pass "" and skip construction. +func NewValidator(issuerURL, audience string) (*Validator, error) { + resp, err := http.Get(issuerURL + "/.well-known/openid-configuration") //nolint:noctx + if err != nil { + return nil, fmt.Errorf("fetch oidc discovery: %w", err) + } + defer resp.Body.Close() //nolint:errcheck + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("oidc discovery: status %d", resp.StatusCode) + } + + var doc struct { + JWKSURI string `json:"jwks_uri"` + } + if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil { + return nil, fmt.Errorf("decode oidc discovery: %w", err) + } + if doc.JWKSURI == "" { + return nil, fmt.Errorf("oidc discovery: empty jwks_uri") + } + + ctx := context.Background() + cache := jwk.NewCache(ctx) + if err := cache.Register(doc.JWKSURI, jwk.WithMinRefreshInterval(time.Hour)); err != nil { + return nil, fmt.Errorf("register jwks cache: %w", err) + } + if _, err := cache.Refresh(ctx, doc.JWKSURI); err != nil { + return nil, fmt.Errorf("initial jwks fetch: %w", err) + } + + return &Validator{ + issuer: issuerURL, + audience: audience, + jwksURI: doc.JWKSURI, + cache: cache, + }, nil +} + +// Validate parses and validates rawToken. Returns the subject claim on success. +func (v *Validator) Validate(ctx context.Context, rawToken string) (string, error) { + keySet, err := v.cache.Get(ctx, v.jwksURI) + if err != nil { + return "", fmt.Errorf("get jwks: %w", err) + } + + opts := []jwt.ParseOption{ + jwt.WithKeySet(keySet), + jwt.WithValidate(true), + jwt.WithIssuer(v.issuer), + } + if v.audience != "" { + opts = append(opts, jwt.WithAudience(v.audience)) + } + + tok, err := jwt.ParseString(rawToken, opts...) + if err != nil { + return "", fmt.Errorf("validate jwt: %w", err) + } + return tok.Subject(), nil +} diff --git a/ingestion/internal/auth/jwt_test.go b/ingestion/internal/auth/jwt_test.go new file mode 100644 index 0000000..59e3d77 --- /dev/null +++ b/ingestion/internal/auth/jwt_test.go @@ -0,0 +1,169 @@ +package auth_test + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/mathiasbq/hyperguild/ingestion/internal/auth" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testKeys struct { + priv jwk.Key + pub jwk.Key +} + +func generateRSAKeys(t *testing.T) testKeys { + t.Helper() + raw, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + priv, err := jwk.FromRaw(raw) + require.NoError(t, err) + require.NoError(t, priv.Set(jwk.KeyIDKey, "test-kid")) + require.NoError(t, priv.Set(jwk.AlgorithmKey, jwa.RS256)) + + pub, err := jwk.PublicKeyOf(priv) + require.NoError(t, err) + + return testKeys{priv: priv, pub: pub} +} + +func mockOIDCServer(t *testing.T, keys testKeys) *httptest.Server { + t.Helper() + set := jwk.NewSet() + require.NoError(t, set.AddKey(keys.pub)) + jwksBytes, err := json.Marshal(set) + require.NoError(t, err) + + mux := http.NewServeMux() + var srv *httptest.Server + mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{ + "issuer": srv.URL, + "jwks_uri": srv.URL + "/jwks", + }) + }) + mux.HandleFunc("/jwks", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(jwksBytes) + }) + srv = httptest.NewServer(mux) + t.Cleanup(srv.Close) + return srv +} + +func signToken(t *testing.T, keys testKeys, issuer, audience, subject string, exp time.Time) string { + t.Helper() + b := jwt.NewBuilder(). + Issuer(issuer). + Subject(subject). + Expiration(exp) + if audience != "" { + b = b.Audience([]string{audience}) + } + tok, err := b.Build() + require.NoError(t, err) + signed, err := jwt.Sign(tok, jwt.WithKey(jwa.RS256, keys.priv)) + require.NoError(t, err) + return string(signed) +} + +func TestValidator(t *testing.T) { + keys := generateRSAKeys(t) + srv := mockOIDCServer(t, keys) + ctx := context.Background() + + v, err := auth.NewValidator(srv.URL, "brain") + require.NoError(t, err) + + tests := []struct { + name string + token string + wantSub string + wantErr bool + }{ + { + name: "valid jwt", + token: signToken(t, keys, srv.URL, "brain", "test-user", time.Now().Add(time.Hour)), + wantSub: "test-user", + }, + { + name: "expired jwt", + token: signToken(t, keys, srv.URL, "brain", "test-user", time.Now().Add(-time.Hour)), + wantErr: true, + }, + { + name: "wrong issuer", + token: signToken(t, keys, "https://evil.example.com", "brain", "test-user", time.Now().Add(time.Hour)), + wantErr: true, + }, + { + name: "wrong audience", + token: signToken(t, keys, srv.URL, "other-service", "test-user", time.Now().Add(time.Hour)), + wantErr: true, + }, + { + name: "tampered token", + token: signToken(t, keys, srv.URL, "brain", "test-user", time.Now().Add(time.Hour)) + "tampered", + wantErr: true, + }, + { + name: "not a jwt", + token: "not-a-jwt", + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + sub, err := v.Validate(ctx, tc.token) + if tc.wantErr { + assert.Error(t, err) + assert.Empty(t, sub) + } else { + require.NoError(t, err) + assert.Equal(t, tc.wantSub, sub) + } + }) + } +} + +func TestNewValidator_NoAudience(t *testing.T) { + keys := generateRSAKeys(t) + srv := mockOIDCServer(t, keys) + ctx := context.Background() + + v, err := auth.NewValidator(srv.URL, "") + require.NoError(t, err) + + // Token without audience passes when audience validation is disabled. + tok, err := jwt.NewBuilder(). + Issuer(srv.URL). + Subject("sub"). + Expiration(time.Now().Add(time.Hour)). + Build() + require.NoError(t, err) + signed, err := jwt.Sign(tok, jwt.WithKey(jwa.RS256, keys.priv)) + require.NoError(t, err) + + sub, err := v.Validate(ctx, string(signed)) + require.NoError(t, err) + assert.Equal(t, "sub", sub) +} + +func TestNewValidator_BadDiscoveryURL(t *testing.T) { + _, err := auth.NewValidator("http://127.0.0.1:1", "brain") + assert.Error(t, err) +} diff --git a/ingestion/internal/auth/protected_resource.go b/ingestion/internal/auth/protected_resource.go new file mode 100644 index 0000000..fb86e23 --- /dev/null +++ b/ingestion/internal/auth/protected_resource.go @@ -0,0 +1,23 @@ +package auth + +import ( + "encoding/json" + "net/http" +) + +// ProtectedResourceHandler returns an RFC 9728 oauth-protected-resource metadata +// handler. Mount at GET /.well-known/oauth-protected-resource (no auth required). +func ProtectedResourceHandler(resourceURL, issuerURL string) http.HandlerFunc { + type metadata struct { + Resource string `json:"resource"` + AuthorizationServers []string `json:"authorization_servers"` + } + body, _ := json.Marshal(metadata{ + Resource: resourceURL, + AuthorizationServers: []string{issuerURL}, + }) + return func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(body) + } +} diff --git a/ingestion/internal/auth/protected_resource_test.go b/ingestion/internal/auth/protected_resource_test.go new file mode 100644 index 0000000..ba54ae0 --- /dev/null +++ b/ingestion/internal/auth/protected_resource_test.go @@ -0,0 +1,28 @@ +package auth_test + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/mathiasbq/hyperguild/ingestion/internal/auth" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestProtectedResourceHandler(t *testing.T) { + h := auth.ProtectedResourceHandler("https://brain-mcp.d-ma.be", "https://auth.d-ma.be") + req := httptest.NewRequest(http.MethodGet, "/.well-known/oauth-protected-resource", nil) + rr := httptest.NewRecorder() + h(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, "application/json", rr.Header().Get("Content-Type")) + + var body map[string]any + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &body)) + assert.Equal(t, "https://brain-mcp.d-ma.be", body["resource"]) + servers := body["authorization_servers"].([]any) + assert.Equal(t, "https://auth.d-ma.be", servers[0]) +} diff --git a/ingestion/internal/mcp/auth.go b/ingestion/internal/mcp/auth.go index 7509653..e86e087 100644 --- a/ingestion/internal/mcp/auth.go +++ b/ingestion/internal/mcp/auth.go @@ -1,23 +1,36 @@ package mcp import ( + "crypto/subtle" "net/http" "strings" + + "github.com/mathiasbq/hyperguild/ingestion/internal/auth" ) -// BearerAuth returns a middleware that enforces a static bearer token on every -// request. token must be non-empty; if it is empty, every request is rejected. -func BearerAuth(token string, next http.Handler) http.Handler { +// BearerAuth returns a middleware that enforces authentication on every request. +// It tries a valid Dex JWT first (when v is non-nil), then falls back to the +// static token. Rejects if token is empty and no valid JWT is presented. +func BearerAuth(token string, v *auth.Validator, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if token == "" { + rawToken, ok := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer ") + if !ok { http.Error(w, "unauthorized", http.StatusUnauthorized) return } - got, ok := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer ") - if !ok || got != token { - http.Error(w, "unauthorized", http.StatusUnauthorized) + + if v != nil { + if _, err := v.Validate(r.Context(), rawToken); err == nil { + next.ServeHTTP(w, r) + return + } + } + + if token != "" && subtle.ConstantTimeCompare([]byte(rawToken), []byte(token)) == 1 { + next.ServeHTTP(w, r) return } - next.ServeHTTP(w, r) + + http.Error(w, "unauthorized", http.StatusUnauthorized) }) } diff --git a/ingestion/internal/mcp/auth_test.go b/ingestion/internal/mcp/auth_test.go index 9aa8553..0462fca 100644 --- a/ingestion/internal/mcp/auth_test.go +++ b/ingestion/internal/mcp/auth_test.go @@ -1,18 +1,32 @@ package mcp_test import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/json" "net/http" "net/http/httptest" "testing" + "time" + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/mathiasbq/hyperguild/ingestion/internal/auth" "github.com/mathiasbq/hyperguild/ingestion/internal/mcp" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func TestBearerAuth_MissingHeader(t *testing.T) { - handler := mcp.BearerAuth("secret", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { +func okHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) - })) + }) +} + +func TestBearerAuth_MissingHeader(t *testing.T) { + handler := mcp.BearerAuth("secret", nil, okHandler()) req := httptest.NewRequest(http.MethodPost, "/mcp", nil) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) @@ -20,9 +34,7 @@ func TestBearerAuth_MissingHeader(t *testing.T) { } func TestBearerAuth_WrongToken(t *testing.T) { - handler := mcp.BearerAuth("secret", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusOK) - })) + handler := mcp.BearerAuth("secret", nil, okHandler()) req := httptest.NewRequest(http.MethodPost, "/mcp", nil) req.Header.Set("Authorization", "Bearer wrong") rr := httptest.NewRecorder() @@ -32,7 +44,7 @@ func TestBearerAuth_WrongToken(t *testing.T) { func TestBearerAuth_CorrectToken(t *testing.T) { called := false - handler := mcp.BearerAuth("secret", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + handler := mcp.BearerAuth("secret", nil, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { called = true w.WriteHeader(http.StatusOK) })) @@ -45,12 +57,105 @@ func TestBearerAuth_CorrectToken(t *testing.T) { } func TestBearerAuth_EmptyConfiguredToken(t *testing.T) { - // Server started without a token configured — every request must fail. - handler := mcp.BearerAuth("", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusOK) - })) + handler := mcp.BearerAuth("", nil, okHandler()) req := httptest.NewRequest(http.MethodPost, "/mcp", nil) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusUnauthorized, rr.Code) } + +// JWT auth tests + +func buildOIDCServer(t *testing.T) (*httptest.Server, jwk.Key) { + t.Helper() + raw, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + priv, err := jwk.FromRaw(raw) + require.NoError(t, err) + require.NoError(t, priv.Set(jwk.KeyIDKey, "k1")) + require.NoError(t, priv.Set(jwk.AlgorithmKey, jwa.RS256)) + pub, err := jwk.PublicKeyOf(priv) + require.NoError(t, err) + + set := jwk.NewSet() + require.NoError(t, set.AddKey(pub)) + jwksBytes, err := json.Marshal(set) + require.NoError(t, err) + + muxSrv := http.NewServeMux() + var srv *httptest.Server + muxSrv.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]string{ + "issuer": srv.URL, + "jwks_uri": srv.URL + "/jwks", + }) + }) + muxSrv.HandleFunc("/jwks", func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write(jwksBytes) + }) + srv = httptest.NewServer(muxSrv) + t.Cleanup(srv.Close) + return srv, priv +} + +func signJWT(t *testing.T, priv jwk.Key, issuer, audience string, exp time.Time) string { + t.Helper() + tok, err := jwt.NewBuilder(). + Issuer(issuer).Audience([]string{audience}). + Subject("s").Expiration(exp). + Build() + require.NoError(t, err) + signed, err := jwt.Sign(tok, jwt.WithKey(jwa.RS256, priv)) + require.NoError(t, err) + return string(signed) +} + +func TestBearerAuth_ValidJWT(t *testing.T) { + oidcSrv, priv := buildOIDCServer(t) + v, err := auth.NewValidator(oidcSrv.URL, "brain") + require.NoError(t, err) + + called := false + handler := mcp.BearerAuth("static-secret", v, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + })) + + token := signJWT(t, priv, oidcSrv.URL, "brain", time.Now().Add(time.Hour)) + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + req.Header.Set("Authorization", "Bearer "+token) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.True(t, called) +} + +func TestBearerAuth_InvalidJWT_FallsBackToStaticToken(t *testing.T) { + oidcSrv, _ := buildOIDCServer(t) + v, err := auth.NewValidator(oidcSrv.URL, "brain") + require.NoError(t, err) + + handler := mcp.BearerAuth("static-secret", v, okHandler()) + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + req.Header.Set("Authorization", "Bearer static-secret") + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) +} + +func TestBearerAuth_InvalidJWT_WrongStaticToken(t *testing.T) { + oidcSrv, priv := buildOIDCServer(t) + v, err := auth.NewValidator(oidcSrv.URL, "brain") + require.NoError(t, err) + + handler := mcp.BearerAuth("static-secret", v, okHandler()) + // Expired JWT — JWT fails, static token doesn't match either + token := signJWT(t, priv, oidcSrv.URL, "brain", time.Now().Add(-time.Hour)) + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + req.Header.Set("Authorization", "Bearer "+token) + + _ = context.Background() // satisfies import + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusUnauthorized, rr.Code) +} diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go new file mode 100644 index 0000000..36af6ed --- /dev/null +++ b/internal/auth/jwt.go @@ -0,0 +1,84 @@ +package auth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jwt" +) + +// Validator validates Bearer JWTs issued by a Dex (OIDC) authorization server. +// Audience is optional; leave empty to skip audience validation. +type Validator struct { + issuer string + audience string + jwksURI string + cache *jwk.Cache +} + +// NewValidator fetches the OIDC discovery document from issuerURL, extracts +// jwks_uri, seeds the JWKS cache, and returns a ready Validator. +// If DEX_ISSUER_URL is not set the caller should pass "" and skip construction. +func NewValidator(issuerURL, audience string) (*Validator, error) { + resp, err := http.Get(issuerURL + "/.well-known/openid-configuration") //nolint:noctx + if err != nil { + return nil, fmt.Errorf("fetch oidc discovery: %w", err) + } + defer resp.Body.Close() //nolint:errcheck + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("oidc discovery: status %d", resp.StatusCode) + } + + var doc struct { + JWKSURI string `json:"jwks_uri"` + } + if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil { + return nil, fmt.Errorf("decode oidc discovery: %w", err) + } + if doc.JWKSURI == "" { + return nil, fmt.Errorf("oidc discovery: empty jwks_uri") + } + + ctx := context.Background() + cache := jwk.NewCache(ctx) + if err := cache.Register(doc.JWKSURI, jwk.WithMinRefreshInterval(time.Hour)); err != nil { + return nil, fmt.Errorf("register jwks cache: %w", err) + } + if _, err := cache.Refresh(ctx, doc.JWKSURI); err != nil { + return nil, fmt.Errorf("initial jwks fetch: %w", err) + } + + return &Validator{ + issuer: issuerURL, + audience: audience, + jwksURI: doc.JWKSURI, + cache: cache, + }, nil +} + +// Validate parses and validates rawToken. Returns the subject claim on success. +func (v *Validator) Validate(ctx context.Context, rawToken string) (string, error) { + keySet, err := v.cache.Get(ctx, v.jwksURI) + if err != nil { + return "", fmt.Errorf("get jwks: %w", err) + } + + opts := []jwt.ParseOption{ + jwt.WithKeySet(keySet), + jwt.WithValidate(true), + jwt.WithIssuer(v.issuer), + } + if v.audience != "" { + opts = append(opts, jwt.WithAudience(v.audience)) + } + + tok, err := jwt.ParseString(rawToken, opts...) + if err != nil { + return "", fmt.Errorf("validate jwt: %w", err) + } + return tok.Subject(), nil +} diff --git a/internal/auth/jwt_test.go b/internal/auth/jwt_test.go new file mode 100644 index 0000000..8eabdf3 --- /dev/null +++ b/internal/auth/jwt_test.go @@ -0,0 +1,169 @@ +package auth_test + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/mathiasbq/supervisor/internal/auth" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testKeys struct { + priv jwk.Key + pub jwk.Key +} + +func generateRSAKeys(t *testing.T) testKeys { + t.Helper() + raw, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + priv, err := jwk.FromRaw(raw) + require.NoError(t, err) + require.NoError(t, priv.Set(jwk.KeyIDKey, "test-kid")) + require.NoError(t, priv.Set(jwk.AlgorithmKey, jwa.RS256)) + + pub, err := jwk.PublicKeyOf(priv) + require.NoError(t, err) + + return testKeys{priv: priv, pub: pub} +} + +func mockOIDCServer(t *testing.T, keys testKeys) *httptest.Server { + t.Helper() + set := jwk.NewSet() + require.NoError(t, set.AddKey(keys.pub)) + jwksBytes, err := json.Marshal(set) + require.NoError(t, err) + + mux := http.NewServeMux() + var srv *httptest.Server + mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{ + "issuer": srv.URL, + "jwks_uri": srv.URL + "/jwks", + }) + }) + mux.HandleFunc("/jwks", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(jwksBytes) + }) + srv = httptest.NewServer(mux) + t.Cleanup(srv.Close) + return srv +} + +func signToken(t *testing.T, keys testKeys, issuer, audience, subject string, exp time.Time) string { + t.Helper() + b := jwt.NewBuilder(). + Issuer(issuer). + Subject(subject). + Expiration(exp) + if audience != "" { + b = b.Audience([]string{audience}) + } + tok, err := b.Build() + require.NoError(t, err) + signed, err := jwt.Sign(tok, jwt.WithKey(jwa.RS256, keys.priv)) + require.NoError(t, err) + return string(signed) +} + +func TestValidator(t *testing.T) { + keys := generateRSAKeys(t) + srv := mockOIDCServer(t, keys) + ctx := context.Background() + + v, err := auth.NewValidator(srv.URL, "brain") + require.NoError(t, err) + + tests := []struct { + name string + token string + wantSub string + wantErr bool + }{ + { + name: "valid jwt", + token: signToken(t, keys, srv.URL, "brain", "test-user", time.Now().Add(time.Hour)), + wantSub: "test-user", + }, + { + name: "expired jwt", + token: signToken(t, keys, srv.URL, "brain", "test-user", time.Now().Add(-time.Hour)), + wantErr: true, + }, + { + name: "wrong issuer", + token: signToken(t, keys, "https://evil.example.com", "brain", "test-user", time.Now().Add(time.Hour)), + wantErr: true, + }, + { + name: "wrong audience", + token: signToken(t, keys, srv.URL, "other-service", "test-user", time.Now().Add(time.Hour)), + wantErr: true, + }, + { + name: "tampered token", + token: signToken(t, keys, srv.URL, "brain", "test-user", time.Now().Add(time.Hour)) + "tampered", + wantErr: true, + }, + { + name: "not a jwt", + token: "not-a-jwt", + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + sub, err := v.Validate(ctx, tc.token) + if tc.wantErr { + assert.Error(t, err) + assert.Empty(t, sub) + } else { + require.NoError(t, err) + assert.Equal(t, tc.wantSub, sub) + } + }) + } +} + +func TestNewValidator_NoAudience(t *testing.T) { + keys := generateRSAKeys(t) + srv := mockOIDCServer(t, keys) + ctx := context.Background() + + v, err := auth.NewValidator(srv.URL, "") + require.NoError(t, err) + + // Token without audience passes when audience validation is disabled. + tok, err := jwt.NewBuilder(). + Issuer(srv.URL). + Subject("sub"). + Expiration(time.Now().Add(time.Hour)). + Build() + require.NoError(t, err) + signed, err := jwt.Sign(tok, jwt.WithKey(jwa.RS256, keys.priv)) + require.NoError(t, err) + + sub, err := v.Validate(ctx, string(signed)) + require.NoError(t, err) + assert.Equal(t, "sub", sub) +} + +func TestNewValidator_BadDiscoveryURL(t *testing.T) { + _, err := auth.NewValidator("http://127.0.0.1:1", "brain") + assert.Error(t, err) +} diff --git a/internal/auth/protected_resource.go b/internal/auth/protected_resource.go new file mode 100644 index 0000000..fb86e23 --- /dev/null +++ b/internal/auth/protected_resource.go @@ -0,0 +1,23 @@ +package auth + +import ( + "encoding/json" + "net/http" +) + +// ProtectedResourceHandler returns an RFC 9728 oauth-protected-resource metadata +// handler. Mount at GET /.well-known/oauth-protected-resource (no auth required). +func ProtectedResourceHandler(resourceURL, issuerURL string) http.HandlerFunc { + type metadata struct { + Resource string `json:"resource"` + AuthorizationServers []string `json:"authorization_servers"` + } + body, _ := json.Marshal(metadata{ + Resource: resourceURL, + AuthorizationServers: []string{issuerURL}, + }) + return func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(body) + } +} diff --git a/internal/auth/protected_resource_test.go b/internal/auth/protected_resource_test.go new file mode 100644 index 0000000..cdb25a7 --- /dev/null +++ b/internal/auth/protected_resource_test.go @@ -0,0 +1,28 @@ +package auth_test + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/mathiasbq/supervisor/internal/auth" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestProtectedResourceHandler(t *testing.T) { + h := auth.ProtectedResourceHandler("https://brain-mcp.d-ma.be", "https://auth.d-ma.be") + req := httptest.NewRequest(http.MethodGet, "/.well-known/oauth-protected-resource", nil) + rr := httptest.NewRecorder() + h(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, "application/json", rr.Header().Get("Content-Type")) + + var body map[string]any + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &body)) + assert.Equal(t, "https://brain-mcp.d-ma.be", body["resource"]) + servers := body["authorization_servers"].([]any) + assert.Equal(t, "https://auth.d-ma.be", servers[0]) +} diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 3b21898..f34c6e0 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -8,6 +8,7 @@ import ( "net/http" "strings" + "github.com/mathiasbq/supervisor/internal/auth" "github.com/mathiasbq/supervisor/internal/registry" ) @@ -32,15 +33,16 @@ type rpcError struct { // Server is an HTTP handler implementing the MCP JSON-RPC protocol. type Server struct { - reg *registry.Registry - token string + reg *registry.Registry + token string + validator *auth.Validator } -// NewServer constructs an MCP HTTP handler. If token is non-empty, every -// request must carry "Authorization: Bearer " or it is rejected with -// HTTP 401 and JSON-RPC error -32001. Empty token disables auth (default). -func NewServer(reg *registry.Registry, token string) *Server { - return &Server{reg: reg, token: token} +// NewServer constructs an MCP HTTP handler. token is the static bearer token +// (empty disables static auth). validator is optional; when non-nil, a valid +// JWT from Dex is accepted in addition to the static token. +func NewServer(reg *registry.Registry, token string, validator *auth.Validator) *Server { + return &Server{reg: reg, token: token, validator: validator} } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -120,27 +122,42 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { }) } -// checkAuth verifies the bearer token when one is configured. Returns true if -// the request may proceed, false if it has been rejected (401 already written). +// checkAuth verifies the bearer token. Accepts a valid Dex JWT (when validator +// is configured) or the static token. Returns true if the request may proceed. +// When neither token nor validator is configured, auth is disabled (default). func (s *Server) checkAuth(w http.ResponseWriter, r *http.Request) bool { - if s.token == "" { + if s.token == "" && s.validator == nil { return true } - const prefix = "Bearer " - hdr := r.Header.Get("Authorization") - if !strings.HasPrefix(hdr, prefix) || - subtle.ConstantTimeCompare([]byte(hdr[len(prefix):]), []byte(s.token)) != 1 { - slog.Warn("mcp auth rejected", "remote", r.RemoteAddr, "method", r.Method) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - _ = json.NewEncoder(w).Encode(response{ - JSONRPC: "2.0", - Error: &rpcError{Code: -32001, Message: "unauthorized"}, - }) + rawToken, ok := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer ") + if !ok { + s.rejectAuth(w, r) return false } - return true + + if s.validator != nil { + if _, err := s.validator.Validate(r.Context(), rawToken); err == nil { + return true + } + } + + if s.token != "" && subtle.ConstantTimeCompare([]byte(rawToken), []byte(s.token)) == 1 { + return true + } + + s.rejectAuth(w, r) + return false +} + +func (s *Server) rejectAuth(w http.ResponseWriter, r *http.Request) { + slog.Warn("mcp auth rejected", "remote", r.RemoteAddr, "method", r.Method) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + _ = json.NewEncoder(w).Encode(response{ + JSONRPC: "2.0", + Error: &rpcError{Code: -32001, Message: "unauthorized"}, + }) } func writeError(w http.ResponseWriter, id any, code int, msg string) { diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index a9fe15c..27dcde7 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -23,7 +23,7 @@ func jsonBody(t *testing.T, v any) *bytes.Buffer { func TestMCPInitialize(t *testing.T) { reg := registry.New() - srv := mcp.NewServer(reg, "") + srv := mcp.NewServer(reg, "", nil) req := httptest.NewRequest(http.MethodPost, "/mcp", jsonBody(t, map[string]any{ "jsonrpc": "2.0", @@ -45,7 +45,7 @@ func TestMCPInitialize(t *testing.T) { func TestMCPToolsList(t *testing.T) { reg := registry.New() - srv := mcp.NewServer(reg, "") + srv := mcp.NewServer(reg, "", nil) req := httptest.NewRequest(http.MethodPost, "/mcp", jsonBody(t, map[string]any{ "jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": map[string]any{}, @@ -63,7 +63,7 @@ func TestMCPToolsList(t *testing.T) { func TestMCPUnknownMethod(t *testing.T) { reg := registry.New() - srv := mcp.NewServer(reg, "") + srv := mcp.NewServer(reg, "", nil) req := httptest.NewRequest(http.MethodPost, "/mcp", jsonBody(t, map[string]any{ "jsonrpc": "2.0", "id": 3, "method": "unknown/method", "params": map[string]any{}, @@ -80,7 +80,7 @@ func TestMCPUnknownMethod(t *testing.T) { func TestMCPNotificationKnownMethodGetsNoResponseBody(t *testing.T) { reg := registry.New() - srv := mcp.NewServer(reg, "") + srv := mcp.NewServer(reg, "", nil) // JSON-RPC 2.0 notification: "id" field absent. Per spec, server MUST NOT // reply. notifications/initialized is part of the standard MCP handshake. @@ -116,7 +116,7 @@ func TestMCPAuth(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { reg := registry.New() - srv := mcp.NewServer(reg, tc.token) + srv := mcp.NewServer(reg, tc.token, nil) req := httptest.NewRequest(http.MethodPost, "/mcp", jsonBody(t, map[string]any{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": map[string]any{}, @@ -142,7 +142,7 @@ func TestMCPAuth(t *testing.T) { func TestMCPNotificationUnknownMethodGetsNoResponseBody(t *testing.T) { reg := registry.New() - srv := mcp.NewServer(reg, "") + srv := mcp.NewServer(reg, "", nil) req := httptest.NewRequest(http.MethodPost, "/mcp", jsonBody(t, map[string]any{ "jsonrpc": "2.0",