package mcp_test import ( "bytes" "encoding/json" "net/http" "net/http/httptest" "testing" "gitea.d-ma.be/mathias/gitea-mcp/internal/mcp" "gitea.d-ma.be/mathias/gitea-mcp/internal/registry" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func newServer(t *testing.T) *mcp.Server { t.Helper() reg := registry.New() return mcp.NewServer(mcp.ServerOptions{ Registry: reg, Sessions: mcp.NewSessionStore(), }) } func postJSON(t *testing.T, srv http.Handler, body any, sessionID string) *httptest.ResponseRecorder { t.Helper() b, _ := json.Marshal(body) req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(b)) req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json, text/event-stream") if sessionID != "" { req.Header.Set("Mcp-Session-Id", sessionID) } rr := httptest.NewRecorder() srv.ServeHTTP(rr, req) return rr } func TestInitialize(t *testing.T) { srv := newServer(t) rr := postJSON(t, srv, map[string]any{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": map[string]any{"protocolVersion": "2025-06-18"}, }, "") require.Equal(t, http.StatusOK, rr.Code) sid := rr.Header().Get("Mcp-Session-Id") assert.NotEmpty(t, sid) var resp map[string]any require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &resp)) result := resp["result"].(map[string]any) assert.Equal(t, mcp.ProtocolVersion, result["protocolVersion"]) si := result["serverInfo"].(map[string]any) assert.Equal(t, "gitea-mcp", si["name"]) } func TestPostWithoutSessionAccepted(t *testing.T) { // gitea-mcp tools are stateless single-shot; Mcp-Session-Id is advisory. // claude.ai's MCP transport proxy is observed to not propagate the // session header reliably, so non-initialize calls must work without it. srv := newServer(t) rr := postJSON(t, srv, map[string]any{ "jsonrpc": "2.0", "id": 2, "method": "tools/list", }, "") require.Equal(t, http.StatusOK, rr.Code) var resp map[string]any require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &resp)) result := resp["result"].(map[string]any) assert.Contains(t, result, "tools") } func TestServerWithOriginAllowlistRejectsBadOrigin(t *testing.T) { srv := mcp.OriginAllowlist([]string{"https://claude.ai"})(newServer(t)) body, _ := json.Marshal(map[string]any{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": map[string]any{"protocolVersion": "2025-06-18"}, }) req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(body)) req.Header.Set("Content-Type", "application/json") req.Header.Set("Origin", "https://evil.example") rr := httptest.NewRecorder() srv.ServeHTTP(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) } func TestToolsListAfterInitialize(t *testing.T) { srv := newServer(t) init := postJSON(t, srv, map[string]any{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": map[string]any{"protocolVersion": "2025-06-18"}, }, "") sid := init.Header().Get("Mcp-Session-Id") rr := postJSON(t, srv, map[string]any{ "jsonrpc": "2.0", "id": 2, "method": "tools/list", }, sid) require.Equal(t, http.StatusOK, rr.Code) var resp map[string]any require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &resp)) result := resp["result"].(map[string]any) assert.Contains(t, result, "tools") } func TestPostBodyTooLarge(t *testing.T) { srv := newServer(t) // 2 MiB of 'a' characters — exceeds the 1 MiB cap. payload := bytes.Repeat([]byte("a"), 2<<20) req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(payload)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() srv.ServeHTTP(rr, req) assert.NotEqual(t, http.StatusOK, rr.Code, "oversized body must not return 200") assert.Equal(t, http.StatusBadRequest, rr.Code) } func TestHEADReturnsMCPProtocolVersionHeader(t *testing.T) { srv := newServer(t) req := httptest.NewRequest(http.MethodHead, "/mcp", nil) rr := httptest.NewRecorder() srv.ServeHTTP(rr, req) require.Equal(t, http.StatusOK, rr.Code) assert.Equal(t, mcp.ProtocolVersion, rr.Header().Get("MCP-Protocol-Version")) } func TestToolsCallToolNotFound(t *testing.T) { srv := newServer(t) // Initialize to get a session ID. init := postJSON(t, srv, map[string]any{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": map[string]any{"protocolVersion": "2025-06-18"}, }, "") sid := init.Header().Get("Mcp-Session-Id") rr := postJSON(t, srv, map[string]any{ "jsonrpc": "2.0", "id": 2, "method": "tools/call", "params": map[string]any{"name": "nonexistent", "arguments": map[string]any{}}, }, sid) require.Equal(t, http.StatusOK, rr.Code) var resp map[string]any require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &resp)) rpcErr, ok := resp["error"].(map[string]any) require.True(t, ok, "expected error field in response") code := int(rpcErr["code"].(float64)) assert.Equal(t, -32002, code, "expected CodeNotFound (-32002) for missing tool") assert.NotEmpty(t, rpcErr["message"]) }