diff --git a/ingestion/cmd/server/main.go b/ingestion/cmd/server/main.go index 01a0dd0..92f1398 100644 --- a/ingestion/cmd/server/main.go +++ b/ingestion/cmd/server/main.go @@ -57,6 +57,12 @@ func main() { mcpSrv := mcp.NewServer(brainDir, &pipelineCfg, llmClient.Complete) + mcpToken := os.Getenv("BRAIN_MCP_TOKEN") + if mcpToken == "" { + logger.Error("BRAIN_MCP_TOKEN not set") + os.Exit(1) + } + ctx := context.Background() if watchInterval > 0 { watcher.Start(ctx, watcher.Config{ @@ -74,7 +80,7 @@ func main() { mux.HandleFunc("POST /ingest-raw", h.IngestRaw) mux.HandleFunc("POST /backfill-refs", h.BackfillRefs) mux.HandleFunc("GET /pass-rate", h.PassRate) - mux.Handle("POST /mcp", mcpSrv) + mux.Handle("POST /mcp", mcp.BearerAuth(mcpToken, mcpSrv)) addr := ":" + port watchIntervalLog := "disabled" diff --git a/ingestion/internal/mcp/auth.go b/ingestion/internal/mcp/auth.go new file mode 100644 index 0000000..99b62f5 --- /dev/null +++ b/ingestion/internal/mcp/auth.go @@ -0,0 +1,23 @@ +package mcp + +import ( + "net/http" + "strings" +) + +// BearerAuth returns a middleware that enforces a static bearer token on every +// request. token must be non-empty; if it is empty, every request is rejected. +func BearerAuth(token string, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if token == "" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + got := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") + if got != token { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/ingestion/internal/mcp/auth_test.go b/ingestion/internal/mcp/auth_test.go new file mode 100644 index 0000000..9aa8553 --- /dev/null +++ b/ingestion/internal/mcp/auth_test.go @@ -0,0 +1,56 @@ +package mcp_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/mathiasbq/hyperguild/ingestion/internal/mcp" + "github.com/stretchr/testify/assert" +) + +func TestBearerAuth_MissingHeader(t *testing.T) { + handler := mcp.BearerAuth("secret", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusUnauthorized, rr.Code) +} + +func TestBearerAuth_WrongToken(t *testing.T) { + handler := mcp.BearerAuth("secret", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + req.Header.Set("Authorization", "Bearer wrong") + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusUnauthorized, rr.Code) +} + +func TestBearerAuth_CorrectToken(t *testing.T) { + called := false + handler := mcp.BearerAuth("secret", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + })) + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + req.Header.Set("Authorization", "Bearer secret") + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.True(t, called) +} + +func TestBearerAuth_EmptyConfiguredToken(t *testing.T) { + // Server started without a token configured — every request must fail. + handler := mcp.BearerAuth("", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusUnauthorized, rr.Code) +}