diff --git a/internal/tools/code_search.go b/internal/tools/code_search.go index 39c3795..7bee102 100644 --- a/internal/tools/code_search.go +++ b/internal/tools/code_search.go @@ -4,12 +4,21 @@ import ( "context" "encoding/json" "fmt" + "sort" + "sync" + "time" "gitea.d-ma.be/mathias/gitea-mcp/internal/allowlist" "gitea.d-ma.be/mathias/gitea-mcp/internal/gitea" "gitea.d-ma.be/mathias/gitea-mcp/internal/registry" ) +type semaphore chan struct{} + +func newSem(n int) semaphore { return make(semaphore, n) } +func (s semaphore) acquire() { s <- struct{}{} } +func (s semaphore) release() { <-s } + type CodeSearch struct { c *gitea.Client a *allowlist.Allowlist @@ -71,11 +80,13 @@ func (t *CodeSearch) Call(ctx context.Context, raw json.RawMessage) (json.RawMes args.Limit = 30 } - if args.Repo == "" { - // Phase 7.2: leave fan-out unimplemented — just error out for now. - return nil, fmt.Errorf("repo is required for single-repo search (org-wide fan-out lands in 7.3): %w", gitea.ErrValidation) + if args.Repo != "" { + return t.singleRepo(ctx, args) } + return t.fanOut(ctx, args) +} +func (t *CodeSearch) singleRepo(ctx context.Context, args codeSearchArgs) (json.RawMessage, error) { hits, err := t.c.SearchCode(ctx, args.Owner, args.Repo, args.Q, args.Page, args.Limit) if err != nil { return nil, err @@ -102,3 +113,79 @@ func (t *CodeSearch) Call(ctx context.Context, raw json.RawMessage) (json.RawMes } return textOK(out) } + +func (t *CodeSearch) fanOut(ctx context.Context, args codeSearchArgs) (json.RawMessage, error) { + repos, err := t.c.ListRepos(ctx, args.Owner, 1, 50) + if err != nil { + return nil, err + } + + type repoResult struct { + repo string + hits []gitea.CodeSearchHit + err error + } + resultsCh := make(chan repoResult, len(repos)) + sem := newSem(5) + var wg sync.WaitGroup + + for _, r := range repos { + repo := r // capture + wg.Add(1) + go func() { + defer wg.Done() + sem.acquire() + defer sem.release() + + rctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + hits, err := t.c.SearchCode(rctx, args.Owner, repo.Name, args.Q, 1, args.Limit) + resultsCh <- repoResult{repo: args.Owner + "/" + repo.Name, hits: hits, err: err} + }() + } + wg.Wait() + close(resultsCh) + + merged := make([]codeSearchResult, 0) + var partialRepos []string + for rr := range resultsCh { + if rr.err != nil { + partialRepos = append(partialRepos, rr.repo) + continue + } + for _, h := range rr.hits { + score := h.Score + if score == 0 { + score = 1.0 + } + merged = append(merged, codeSearchResult{ + Repo: rr.repo, Path: h.Path, Snippet: h.Snippet, Score: score, HTMLURL: h.HTMLURL, + }) + } + } + + // Sort by score desc, then by repo+path for determinism. + sort.Slice(merged, func(i, j int) bool { + if merged[i].Score != merged[j].Score { + return merged[i].Score > merged[j].Score + } + if merged[i].Repo != merged[j].Repo { + return merged[i].Repo < merged[j].Repo + } + return merged[i].Path < merged[j].Path + }) + if len(merged) > args.Limit { + merged = merged[:args.Limit] + } + + out := map[string]any{ + "results": merged, + "partial": len(partialRepos) > 0, + } + if len(partialRepos) > 0 { + sort.Strings(partialRepos) + out["partial_repos"] = partialRepos + } + return textOK(out) +} diff --git a/internal/tools/code_search_test.go b/internal/tools/code_search_test.go index 4c541f4..962fd57 100644 --- a/internal/tools/code_search_test.go +++ b/internal/tools/code_search_test.go @@ -6,6 +6,7 @@ import ( "errors" "net/http" "net/http/httptest" + "strings" "testing" "gitea.d-ma.be/mathias/gitea-mcp/internal/allowlist" @@ -65,9 +66,122 @@ func TestCodeSearchRequiresQ(t *testing.T) { assert.True(t, errors.Is(err, gitea.ErrValidation)) } -func TestCodeSearchFanOutNotYetImplemented(t *testing.T) { - tool := tools.NewCodeSearch(gitea.NewClient("http://unused", ""), allowlist.New([]string{"mathias"})) - _, err := tool.Call(context.Background(), json.RawMessage(`{"q":"foo","owner":"mathias"}`)) - require.Error(t, err) - assert.True(t, errors.Is(err, gitea.ErrValidation)) +func TestCodeSearchFanOutHappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch r.URL.Path { + case "/api/v1/users/mathias/repos": + _, _ = w.Write([]byte(`[ + {"name":"infra","full_name":"mathias/infra","default_branch":"main"}, + {"name":"gitea-mcp","full_name":"mathias/gitea-mcp","default_branch":"main"} + ]`)) + case "/api/v1/repos/mathias/infra/search": + _, _ = w.Write([]byte(`{"data":[{"path":"main.go","snippet":"infra hit","html_url":"http://x/infra/main.go","score":2.0}],"ok":true}`)) + case "/api/v1/repos/mathias/gitea-mcp/search": + _, _ = w.Write([]byte(`{"data":[{"path":"cmd/main.go","snippet":"gitea-mcp hit","html_url":"http://x/gitea-mcp/main.go","score":1.0}],"ok":true}`)) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + + tool := tools.NewCodeSearch(gitea.NewClient(srv.URL, "tok"), allowlist.New([]string{"mathias"})) + out, err := tool.Call(context.Background(), json.RawMessage(`{"q":"hit","owner":"mathias"}`)) + require.NoError(t, err) + + var result struct { + Results []struct { + Repo string `json:"repo"` + Path string `json:"path"` + Snippet string `json:"snippet"` + } `json:"results"` + Partial bool `json:"partial"` + } + require.NoError(t, json.Unmarshal(out, &result)) + assert.False(t, result.Partial) + require.Len(t, result.Results, 2) + + repos := make([]string, 0, 2) + for _, r := range result.Results { + repos = append(repos, r.Repo) + } + assert.Contains(t, repos, "mathias/infra") + assert.Contains(t, repos, "mathias/gitea-mcp") +} + +func TestCodeSearchFanOutPartialFailure(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch r.URL.Path { + case "/api/v1/users/mathias/repos": + _, _ = w.Write([]byte(`[ + {"name":"infra","full_name":"mathias/infra","default_branch":"main"}, + {"name":"broken","full_name":"mathias/broken","default_branch":"main"} + ]`)) + case "/api/v1/repos/mathias/infra/search": + _, _ = w.Write([]byte(`{"data":[{"path":"main.go","snippet":"infra hit","html_url":"http://x/infra/main.go","score":1.0}],"ok":true}`)) + case "/api/v1/repos/mathias/broken/search": + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"message":"internal error"}`)) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + + tool := tools.NewCodeSearch(gitea.NewClient(srv.URL, "tok"), allowlist.New([]string{"mathias"})) + out, err := tool.Call(context.Background(), json.RawMessage(`{"q":"hit","owner":"mathias"}`)) + require.NoError(t, err) + + var result struct { + Results []struct{ Repo string `json:"repo"` } `json:"results"` + Partial bool `json:"partial"` + PartialRepos []string `json:"partial_repos"` + } + require.NoError(t, json.Unmarshal(out, &result)) + assert.True(t, result.Partial) + require.Len(t, result.PartialRepos, 1) + assert.Equal(t, "mathias/broken", result.PartialRepos[0]) + require.Len(t, result.Results, 1) + assert.Equal(t, "mathias/infra", result.Results[0].Repo) +} + +func TestCodeSearchFanOutSortsByScore(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch r.URL.Path { + case "/api/v1/users/mathias/repos": + _, _ = w.Write([]byte(`[ + {"name":"alpha","full_name":"mathias/alpha","default_branch":"main"}, + {"name":"beta","full_name":"mathias/beta","default_branch":"main"} + ]`)) + case "/api/v1/repos/mathias/alpha/search": + // low score + _, _ = w.Write([]byte(`{"data":[{"path":"a.go","snippet":"low","html_url":"http://x/alpha/a.go","score":1.0}],"ok":true}`)) + case "/api/v1/repos/mathias/beta/search": + // high score + _, _ = w.Write([]byte(`{"data":[{"path":"b.go","snippet":"high","html_url":"http://x/beta/b.go","score":5.0}],"ok":true}`)) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + + tool := tools.NewCodeSearch(gitea.NewClient(srv.URL, "tok"), allowlist.New([]string{"mathias"})) + out, err := tool.Call(context.Background(), json.RawMessage(`{"q":"something","owner":"mathias"}`)) + require.NoError(t, err) + + var result struct { + Results []struct { + Snippet string `json:"snippet"` + Score float64 `json:"score"` + } `json:"results"` + } + require.NoError(t, json.Unmarshal(out, &result)) + require.Len(t, result.Results, 2) + // First result must be the high-score one + assert.True(t, result.Results[0].Score > result.Results[1].Score, + "expected results sorted by score desc, got %v then %v", + result.Results[0].Score, result.Results[1].Score) + assert.True(t, strings.Contains(result.Results[0].Snippet, "high")) }