diff --git a/ingestion/cmd/server/main.go b/ingestion/cmd/server/main.go index 0e2f886..df75079 100644 --- a/ingestion/cmd/server/main.go +++ b/ingestion/cmd/server/main.go @@ -16,6 +16,7 @@ import ( "github.com/mathiasbq/hyperguild/ingestion/internal/llm" "github.com/mathiasbq/hyperguild/ingestion/internal/mcp" "github.com/mathiasbq/hyperguild/ingestion/internal/oauth" + "github.com/mathiasbq/hyperguild/ingestion/internal/reranker" "github.com/mathiasbq/hyperguild/ingestion/internal/pipeline" "github.com/mathiasbq/hyperguild/ingestion/internal/watcher" ) @@ -77,6 +78,11 @@ func main() { } mcpSrv := mcp.NewServer(brainDir, &pipelineCfg, llmClient.Complete, answerComplete) + if rerankURL := os.Getenv("BRAIN_RERANKER_URL"); rerankURL != "" { + rerankModel := envOr("BRAIN_RERANKER_MODEL", "dengcao/Qwen3-Reranker-0.6B:F16") + mcpSrv = mcpSrv.WithReranker(reranker.New(rerankURL, rerankModel)) + logger.Info("brain reranker configured", "url", rerankURL, "model", rerankModel) + } mcpToken := os.Getenv("BRAIN_MCP_TOKEN") if mcpToken == "" { diff --git a/ingestion/internal/mcp/server.go b/ingestion/internal/mcp/server.go index fd011b0..3f4fc64 100644 --- a/ingestion/internal/mcp/server.go +++ b/ingestion/internal/mcp/server.go @@ -10,6 +10,7 @@ import ( "net/http" "github.com/mathiasbq/hyperguild/ingestion/internal/pipeline" + "github.com/mathiasbq/hyperguild/ingestion/internal/reranker" ) type request struct { @@ -37,6 +38,7 @@ type Server struct { pipeline pipeline.Config llm pipeline.CompleteFunc answerLLM pipeline.CompleteFunc // nil = brain_answer and brain_classify unavailable + reranker *reranker.Client // nil = no rerank, BM25 top-10 → LLM } // NewServer constructs a Server bound to brainDir. pipelineCfg supplies the @@ -50,6 +52,15 @@ func NewServer(brainDir string, pipelineCfg *pipeline.Config, llm pipeline.Compl return &Server{brainDir: brainDir, pipeline: cfg, llm: llm, answerLLM: answerLLM} } +// WithReranker installs an opt-in cross-encoder reranker. 
When set, +// brain_answer retrieves a wider BM25 candidate set and prunes it to +// the relevant ones before LLM synthesis. Returns the server for +// fluent chaining. +func (s *Server) WithReranker(r *reranker.Client) *Server { + s.reranker = r + return s +} + func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // MCP streamable HTTP: GET establishes the SSE stream for server-to-client events. if r.Method == http.MethodGet { diff --git a/ingestion/internal/mcp/tools_answer.go b/ingestion/internal/mcp/tools_answer.go index 9233eb8..8fdef68 100644 --- a/ingestion/internal/mcp/tools_answer.go +++ b/ingestion/internal/mcp/tools_answer.go @@ -6,9 +6,35 @@ import ( "fmt" "strings" + "github.com/mathiasbq/hyperguild/ingestion/internal/reranker" "github.com/mathiasbq/hyperguild/ingestion/internal/search" ) +// rerankResults scores each candidate's excerpt against the query and +// returns up to top results whose score is positive, preserving the +// caller's input order (BM25 rank) within the kept set. The reranker is +// a filter: ties are broken by BM25, not by the reranker's binary score. +func rerankResults(ctx context.Context, rr *reranker.Client, query string, results []search.Result, top int) ([]search.Result, error) { + docs := make([]string, len(results)) + for i, r := range results { + docs[i] = r.Excerpt + } + scores, err := rr.Score(ctx, query, docs) + if err != nil { + return nil, err + } + kept := make([]search.Result, 0, top) + for i, r := range results { + if scores[i] > 0 { + kept = append(kept, r) + } + if len(kept) == top { + break + } + } + return kept, nil +} + const ( answerSystemPrompt = `You are a knowledge assistant. Answer the question using ONLY the provided sources. Cite source file paths inline when referencing specific content. 
@@ -35,10 +61,22 @@ func (s *Server) brainAnswer(ctx context.Context, args json.RawMessage) (json.Ra return nil, fmt.Errorf("query is required") } - results, err := search.Query(s.brainDir, search.QueryOptions{Query: a.Query, Limit: 10}) + // With reranker disabled: BM25 top-10 straight to the LLM. + // With reranker enabled: BM25 top-20 → cross-encoder filter → top-5. + bm25Limit := 10 + if s.reranker != nil { + bm25Limit = 20 + } + results, err := search.Query(s.brainDir, search.QueryOptions{Query: a.Query, Limit: bm25Limit}) if err != nil { return nil, fmt.Errorf("search: %w", err) } + if s.reranker != nil && len(results) > 0 { + results, err = rerankResults(ctx, s.reranker, a.Query, results, 5) + if err != nil { + return nil, fmt.Errorf("rerank: %w", err) + } + } if len(results) == 0 { return json.Marshal(map[string]any{ "answer": "No relevant content found in brain.", diff --git a/ingestion/internal/mcp/tools_answer_test.go b/ingestion/internal/mcp/tools_answer_test.go index baa5f78..0262f45 100644 --- a/ingestion/internal/mcp/tools_answer_test.go +++ b/ingestion/internal/mcp/tools_answer_test.go @@ -3,14 +3,17 @@ package mcp_test import ( "context" "encoding/json" + "io" "net/http" "net/http/httptest" "os" "path/filepath" + "strings" "testing" "github.com/mathiasbq/hyperguild/ingestion/internal/mcp" "github.com/mathiasbq/hyperguild/ingestion/internal/pipeline" + "github.com/mathiasbq/hyperguild/ingestion/internal/reranker" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -46,6 +49,55 @@ func callTool(t *testing.T, ts *httptest.Server, name string, arguments map[stri return out } +func TestBrainAnswer_RerankerFiltersBeforeLLM(t *testing.T) { + brainDir := t.TempDir() + wikiDir := filepath.Join(brainDir, "wiki") + require.NoError(t, os.MkdirAll(wikiDir, 0o755)) + // Two notes — both BM25-match the query, but only one is truly relevant. 
+ require.NoError(t, os.WriteFile(filepath.Join(wikiDir, "good.md"), []byte( + "---\ntitle: Pass-rate Logging\n---\nPass-rate logging tracks skill invocations.", + ), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(wikiDir, "noise.md"), []byte( + "---\ntitle: Pass-rate Tangent\n---\nPass-rate appears here too but as a tangent.", + ), 0o644)) + + // Fake Ollama reranker: yes only when prompt contains "tracks skill invocations". + rrSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, _ := io.ReadAll(r.Body) + yes := strings.Contains(string(raw), "tracks skill invocations") + ans := "no" + if yes { + ans = "yes" + } + _ = json.NewEncoder(w).Encode(map[string]any{"response": ans, "done": true}) + })) + defer rrSrv.Close() + + // LLM mock captures the rendered sources so we can assert what reached it. + var sawSources string + llm := func(_ context.Context, _, user string) (string, error) { + sawSources = user + return "answer text", nil + } + + srv := mcp.NewServer(brainDir, nil, nil, llm). 
// The reranker scores (query, document) pairs against a cross-encoder
// served by an Ollama-compatible backend. The wire format is Ollama's
// /api/generate. The model is prompted with the Qwen3-Reranker yes/no
// template and the start of the response is treated as a binary relevance
// vote: "yes" → 1.0, anything else → 0.0. Ties are expected to be broken
// by the caller's primary retrieval score (e.g. BM25), so the binary
// signal is a filter rather than a ranking substitute.

const (
	// defaultTimeout bounds a single scoring request.
	defaultTimeout = 30 * time.Second

	// maxErrorBody caps how much of a non-2xx response body is copied
	// into the returned error.
	maxErrorBody = 4 << 10

	// Markers of the Qwen3-Reranker chat template. They are spelled with
	// \x3c / \x3e escapes instead of literal angle brackets so that
	// tag-stripping tooling cannot silently remove them from the source.
	instructTag = "\x3cInstruct\x3e"
	queryTag    = "\x3cQuery\x3e"
	documentTag = "\x3cDocument\x3e"
	thinkOpen   = "\x3cthink\x3e"
	thinkClose  = "\x3c/think\x3e"

	// leadingNoise is the set of whitespace and punctuation skipped
	// before the yes/no token is inspected.
	leadingNoise = " \t\r\n\"'`*_.,:;!?-()[]{}"
)

// Client posts rerank requests to an Ollama-compatible endpoint.
//
// URL is the backend base address, Model the model name sent with every
// request, and HTTP the client used for the calls. A nil HTTP falls back
// to a client with the default timeout, so a Client built as a struct
// literal is still usable.
type Client struct {
	URL   string
	Model string
	HTTP  *http.Client
}

// generateRequest is the subset of Ollama's /api/generate request body
// that the reranker sends.
type generateRequest struct {
	Model   string          `json:"model"`
	Prompt  string          `json:"prompt"`
	Stream  bool            `json:"stream"`
	Options generateOptions `json:"options"`
}

// generateOptions carries the sampling options for a scoring request.
type generateOptions struct {
	NumPredict  int     `json:"num_predict"`
	Temperature float64 `json:"temperature"`
}

// New constructs a Client. Returns nil when url is empty so callers can
// treat a missing BRAIN_RERANKER_URL as "feature disabled" with a single
// nil check. Trailing slashes on url are removed.
func New(url, model string) *Client {
	if url == "" {
		return nil
	}
	return &Client{
		URL:   strings.TrimRight(url, "/"),
		Model: model,
		HTTP:  &http.Client{Timeout: defaultTimeout},
	}
}

// Score returns one relevance score (0.0 or 1.0) per input document,
// parallel to the input order. Each (query, doc) pair is scored with its
// own request, one after another. The first failure aborts the run and is
// returned wrapped with the index of the offending document; the score
// slice is nil in that case. An empty docs slice yields an empty result
// without any request being sent.
func (c *Client) Score(ctx context.Context, query string, docs []string) ([]float64, error) {
	out := make([]float64, len(docs))
	for i, doc := range docs {
		s, err := c.scoreOne(ctx, query, doc)
		if err != nil {
			return nil, fmt.Errorf("rerank doc %d: %w", i, err)
		}
		out[i] = s
	}
	return out, nil
}

// httpClient returns the configured HTTP client, or a client with the
// default timeout when none was set.
func (c *Client) httpClient() *http.Client {
	if c.HTTP != nil {
		return c.HTTP
	}
	return &http.Client{Timeout: defaultTimeout}
}

// scoreOne sends a single (query, doc) pair to /api/generate and maps the
// reply to 1.0 ("yes") or 0.0 (anything else). It returns an error for a
// nil receiver, a transport failure, a non-2xx status, or an undecodable
// response body.
func (c *Client) scoreOne(ctx context.Context, query, doc string) (float64, error) {
	// New returns nil for an empty URL; fail cleanly instead of panicking
	// if such a client is used anyway.
	if c == nil {
		return 0, fmt.Errorf("reranker client is nil")
	}

	payload, err := json.Marshal(generateRequest{
		Model:  c.Model,
		Prompt: buildPrompt(query, doc),
		Stream: false,
		// Only the first token matters and the vote must be deterministic.
		Options: generateOptions{NumPredict: 4, Temperature: 0},
	})
	if err != nil {
		return 0, fmt.Errorf("encode request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost,
		c.URL+"/api/generate", bytes.NewReader(payload))
	if err != nil {
		return 0, fmt.Errorf("build request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := c.httpClient().Do(req)
	if err != nil {
		return 0, err
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode/100 != 2 {
		// Best effort: the body only decorates the error, and is capped so
		// a misbehaving backend cannot make the error arbitrarily large.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
		return 0, fmt.Errorf("status %d: %s", resp.StatusCode, string(body))
	}

	var out struct {
		Response string `json:"response"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		return 0, fmt.Errorf("decode response: %w", err)
	}
	return parseYesNo(out.Response), nil
}

// buildPrompt assembles the Qwen3-Reranker chat template: the system
// instruction, the labelled Instruct/Query/Document user turn, and an
// assistant turn that is pre-filled with an empty think block so the
// model answers immediately. The wording is kept verbatim because the
// model was trained on it.
func buildPrompt(query, doc string) string {
	return "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n" +
		"<|im_start|>user\n" + instructTag + ": Given a web search query, retrieve relevant passages that answer the query\n" +
		queryTag + ": " + query + "\n" +
		documentTag + ": " + doc + "<|im_end|>\n" +
		"<|im_start|>assistant\n" + thinkOpen + "\n\n" + thinkClose + "\n\n"
}

// parseYesNo extracts the first meaningful token from s and returns 1.0
// when it starts with "yes" (case-insensitive), 0.0 otherwise. Everything
// up to and including the last closing think tag is discarded, then
// leading whitespace and punctuation are skipped. An unterminated think
// block therefore never counts as a "yes".
func parseYesNo(s string) float64 {
	s = strings.ToLower(s)
	// Drop the reasoning block the model may emit even with empty thinking.
	if i := strings.LastIndex(s, thinkClose); i != -1 {
		s = s[i+len(thinkClose):]
	}
	s = strings.TrimLeft(s, leadingNoise)
	if strings.HasPrefix(s, "yes") {
		return 1.0
	}
	return 0.0
}
+type fakeOllama struct { + t *testing.T + answers map[string]string // needle → "yes" or "no" + calls int + lastBody map[string]any +} + +func (f *fakeOllama) handler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(f.t, http.MethodPost, r.Method) + require.Equal(f.t, "/api/generate", r.URL.Path) + body, err := io.ReadAll(r.Body) + require.NoError(f.t, err) + var p map[string]any + require.NoError(f.t, json.Unmarshal(body, &p)) + f.calls++ + f.lastBody = p + prompt := p["prompt"].(string) + answer := "no" + for needle, a := range f.answers { + if strings.Contains(prompt, needle) { + answer = a + break + } + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "model": p["model"], "response": answer, "done": true, + }) + }) +} + +func TestNew_EmptyURLReturnsNil(t *testing.T) { + assert.Nil(t, reranker.New("", "model")) +} + +func TestScore_YesAndNoOrdered(t *testing.T) { + f := &fakeOllama{t: t, answers: map[string]string{ + "alpha doc": "yes", + "beta doc": "no", + "gamma doc": "yes", + }} + srv := httptest.NewServer(f.handler()) + defer srv.Close() + c := reranker.New(srv.URL, "test-model") + require.NotNil(t, c) + + scores, err := c.Score(context.Background(), "what is alpha", + []string{"alpha doc body", "beta doc body", "gamma doc body"}) + require.NoError(t, err) + require.Len(t, scores, 3) + assert.Equal(t, 1.0, scores[0]) + assert.Equal(t, 0.0, scores[1]) + assert.Equal(t, 1.0, scores[2]) + assert.Equal(t, 3, f.calls) +} + +func TestScore_SendsCorrectShape(t *testing.T) { + f := &fakeOllama{t: t, answers: map[string]string{"hello": "yes"}} + srv := httptest.NewServer(f.handler()) + defer srv.Close() + c := reranker.New(srv.URL, "qwen3-rerank") + _, err := c.Score(context.Background(), "greeting", []string{"hello world"}) + require.NoError(t, err) + assert.Equal(t, "qwen3-rerank", f.lastBody["model"]) + prompt := f.lastBody["prompt"].(string) + 
assert.Contains(t, prompt, "greeting") + assert.Contains(t, prompt, "hello world") + assert.Contains(t, prompt, `"yes" or "no"`) +} + +func TestScore_HandlesAmbiguousResponse(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{"response": "maybe — unclear", "done": true}) + })) + defer srv.Close() + c := reranker.New(srv.URL, "m") + scores, err := c.Score(context.Background(), "q", []string{"d"}) + require.NoError(t, err) + // Anything that does not start with "yes" (case-insensitive, after + // whitespace/think trim) is treated as "no" = 0. + assert.Equal(t, []float64{0}, scores) +} + +func TestScore_EmptyDocsReturnsEmpty(t *testing.T) { + c := reranker.New("http://127.0.0.1:1", "m") + scores, err := c.Score(context.Background(), "q", nil) + require.NoError(t, err) + assert.Empty(t, scores) +} + +func TestScore_UpstreamErrorPropagates(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + c := reranker.New(srv.URL, "m") + _, err := c.Score(context.Background(), "q", []string{"d"}) + require.Error(t, err) +}