package auth import ( "context" "encoding/json" "fmt" "net/http" "time" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" ) // JWTValidator validates bearer tokens as JWTs issued by a Dex OIDC server. // A nil JWTValidator always returns false — JWT validation is disabled. type JWTValidator struct { issuer string aud string cache *jwk.Cache jwksURI string } // NewJWTValidator creates a validator by fetching the OIDC discovery document // from issuerURL. Returns nil, nil when issuerURL is empty (disabled). func NewJWTValidator(ctx context.Context, issuerURL, audience string) (*JWTValidator, error) { if issuerURL == "" { return nil, nil } resp, err := http.Get(issuerURL + "/.well-known/openid-configuration") if err != nil { return nil, fmt.Errorf("fetch oidc discovery: %w", err) } defer func() { _ = resp.Body.Close() }() var doc struct { JWKSURI string `json:"jwks_uri"` } if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil { return nil, fmt.Errorf("decode oidc discovery: %w", err) } cache := jwk.NewCache(ctx) if err := cache.Register(doc.JWKSURI, jwk.WithRefreshInterval(time.Hour)); err != nil { return nil, fmt.Errorf("register jwks uri: %w", err) } // warm the cache immediately so first request doesn't block if _, err := cache.Refresh(ctx, doc.JWKSURI); err != nil { return nil, fmt.Errorf("warm jwks cache: %w", err) } return &JWTValidator{ issuer: issuerURL, aud: audience, cache: cache, jwksURI: doc.JWKSURI, }, nil } // Validate returns true if rawToken is a valid JWT signed by the OIDC server. func (v *JWTValidator) Validate(ctx context.Context, rawToken string) bool { if v == nil { return false } keySet, err := v.cache.Get(ctx, v.jwksURI) if err != nil { return false } opts := []jwt.ParseOption{ jwt.WithKeySet(keySet), jwt.WithIssuer(v.issuer), jwt.WithValidate(true), } if v.aud != "" { opts = append(opts, jwt.WithAudience(v.aud)) } _, err = jwt.Parse([]byte(rawToken), opts...) return err == nil }