package auth_test import ( "context" "crypto/rand" "crypto/rsa" "encoding/json" "net/http" "net/http/httptest" "testing" "time" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/mathiasbq/supervisor/internal/auth" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type testKeys struct { priv jwk.Key pub jwk.Key } func generateRSAKeys(t *testing.T) testKeys { t.Helper() raw, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) priv, err := jwk.FromRaw(raw) require.NoError(t, err) require.NoError(t, priv.Set(jwk.KeyIDKey, "test-kid")) require.NoError(t, priv.Set(jwk.AlgorithmKey, jwa.RS256)) pub, err := jwk.PublicKeyOf(priv) require.NoError(t, err) return testKeys{priv: priv, pub: pub} } func mockOIDCServer(t *testing.T, keys testKeys) *httptest.Server { t.Helper() set := jwk.NewSet() require.NoError(t, set.AddKey(keys.pub)) jwksBytes, err := json.Marshal(set) require.NoError(t, err) mux := http.NewServeMux() var srv *httptest.Server mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(map[string]string{ "issuer": srv.URL, "jwks_uri": srv.URL + "/jwks", }) }) mux.HandleFunc("/jwks", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") _, _ = w.Write(jwksBytes) }) srv = httptest.NewServer(mux) t.Cleanup(srv.Close) return srv } func signToken(t *testing.T, keys testKeys, issuer, audience, subject string, exp time.Time) string { t.Helper() b := jwt.NewBuilder(). Issuer(issuer). Subject(subject). Expiration(exp) if audience != "" { b = b.Audience([]string{audience}) } tok, err := b.Build() require.NoError(t, err) signed, err := jwt.Sign(tok, jwt.WithKey(jwa.RS256, keys.priv)) require.NoError(t, err) return string(signed) } func TestValidator(t *testing.T) { keys := generateRSAKeys(t) srv := mockOIDCServer(t, keys) ctx := context.Background() v, err := auth.NewValidator(srv.URL, "brain") require.NoError(t, err) tests := []struct { name string token string wantSub string wantErr bool }{ { name: "valid jwt", token: signToken(t, keys, srv.URL, "brain", "test-user", time.Now().Add(time.Hour)), wantSub: "test-user", }, { name: "expired jwt", token: signToken(t, keys, srv.URL, "brain", "test-user", time.Now().Add(-time.Hour)), wantErr: true, }, { name: "wrong issuer", token: signToken(t, keys, "https://evil.example.com", "brain", "test-user", time.Now().Add(time.Hour)), wantErr: true, }, { name: "wrong audience", token: signToken(t, keys, srv.URL, "other-service", "test-user", time.Now().Add(time.Hour)), wantErr: true, }, { name: "tampered token", token: signToken(t, keys, srv.URL, "brain", "test-user", time.Now().Add(time.Hour)) + "tampered", wantErr: true, }, { name: "not a jwt", token: "not-a-jwt", wantErr: true, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { sub, err := v.Validate(ctx, tc.token) if tc.wantErr { assert.Error(t, err) assert.Empty(t, sub) } else { require.NoError(t, err) assert.Equal(t, tc.wantSub, sub) } }) } } func TestNewValidator_NoAudience(t *testing.T) { keys := generateRSAKeys(t) srv := mockOIDCServer(t, keys) ctx := context.Background() v, err := auth.NewValidator(srv.URL, "") require.NoError(t, err) // Token without audience passes when audience validation is disabled. tok, err := jwt.NewBuilder(). Issuer(srv.URL). Subject("sub"). Expiration(time.Now().Add(time.Hour)). Build() require.NoError(t, err) signed, err := jwt.Sign(tok, jwt.WithKey(jwa.RS256, keys.priv)) require.NoError(t, err) sub, err := v.Validate(ctx, string(signed)) require.NoError(t, err) assert.Equal(t, "sub", sub) } func TestNewValidator_BadDiscoveryURL(t *testing.T) { _, err := auth.NewValidator("http://127.0.0.1:1", "brain") assert.Error(t, err) }