feat(routing): cmd/routing binary
Wires Config → LiteLLMExecutor → Router → four skills (review, debug, retrospective, trainer) → Registry → MCP server with bearer auth and /healthz. Each skill's CompleteFunc is wrapped so the Router decides local-vs-Claude per call and logs every decision to the brain /mcp. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
116
cmd/routing/main.go
Normal file
116
cmd/routing/main.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/mathiasbq/supervisor/internal/config"
|
||||
iexec "github.com/mathiasbq/supervisor/internal/exec"
|
||||
"github.com/mathiasbq/supervisor/internal/mcp"
|
||||
"github.com/mathiasbq/supervisor/internal/registry"
|
||||
"github.com/mathiasbq/supervisor/internal/routing"
|
||||
"github.com/mathiasbq/supervisor/internal/skills/debug"
|
||||
"github.com/mathiasbq/supervisor/internal/skills/retrospective"
|
||||
"github.com/mathiasbq/supervisor/internal/skills/review"
|
||||
"github.com/mathiasbq/supervisor/internal/skills/trainer"
|
||||
)
|
||||
|
||||
func main() {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
slog.SetDefault(logger)
|
||||
|
||||
cfg, err := config.LoadRouting()
|
||||
if err != nil {
|
||||
logger.Error("config load failed", "err", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
configDir := envOr("SUPERVISOR_CONFIG_DIR", "/app/config/supervisor")
|
||||
mustRead := func(path string) string {
|
||||
b, err := os.ReadFile(configDir + "/" + path)
|
||||
if err != nil {
|
||||
logger.Error("read prompt failed", "path", path, "err", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
llm := iexec.NewLiteLLM(cfg.LiteLLMBaseURL, cfg.LiteLLMAPIKey, 0)
|
||||
|
||||
router := &routing.Router{
|
||||
Fetcher: routing.NewFetcher(cfg.BrainURL, "7d", time.Duration(cfg.PassRateTTLSeconds)*time.Second),
|
||||
Logger: routing.NewLogger(cfg.BrainURL),
|
||||
Policy: routing.Policy{Floor: cfg.RouteLocalFloor, Ceil: cfg.RouteLocalCeil},
|
||||
LocalModel: cfg.LocalModel,
|
||||
ClaudeModel: cfg.ClaudeModel,
|
||||
Complete: llm.Complete,
|
||||
}
|
||||
|
||||
// Skill packages call CompleteFunc(ctx, model, system, user) — no session_id
|
||||
// or project_root in the signature. Rather than modifying every skill's API
|
||||
// (and inflating Plan 6's blast radius), the routing pod logs every decision
|
||||
// under a fixed session_id "_routing". Operators query
|
||||
// `GET /pass-rate?skill=_routing&window=...` to inspect routing health.
|
||||
const routingSessionID = "_routing"
|
||||
wrap := func(skillName string) routing.CompleteFunc {
|
||||
return func(ctx context.Context, _, system, user string) (string, int64, error) {
|
||||
// The model param is ignored: the router picks the model based on policy.
|
||||
return router.Run(ctx, routing.RunInput{
|
||||
Skill: skillName,
|
||||
System: system,
|
||||
User: user,
|
||||
SessionID: routingSessionID,
|
||||
ProjectRoot: "",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
reg := registry.New()
|
||||
reg.Register(review.New(review.Config{
|
||||
SkillPrompt: mustRead("review.md"),
|
||||
DefaultModel: cfg.LocalModel,
|
||||
CompleteFunc: review.CompleteFunc(wrap("review")),
|
||||
}))
|
||||
reg.Register(debug.New(debug.Config{
|
||||
SkillPrompt: mustRead("debug.md"),
|
||||
DefaultModel: cfg.LocalModel,
|
||||
CompleteFunc: debug.CompleteFunc(wrap("debug")),
|
||||
}))
|
||||
reg.Register(retrospective.New(retrospective.Config{
|
||||
SkillPrompt: mustRead("retrospective.md"),
|
||||
DefaultModel: cfg.LocalModel,
|
||||
CompleteFunc: retrospective.CompleteFunc(wrap("retrospective")),
|
||||
}))
|
||||
reg.Register(trainer.New(trainer.Config{
|
||||
ReaderPrompt: mustRead("trainer-reader.md"),
|
||||
WriterPrompt: mustRead("trainer-writer.md"),
|
||||
DefaultModel: cfg.LocalModel,
|
||||
CompleteFunc: trainer.CompleteFunc(wrap("trainer")),
|
||||
}))
|
||||
|
||||
srv := mcp.NewServer(reg, cfg.MCPAuthToken)
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/mcp", srv)
|
||||
mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
addr := ":" + cfg.Port
|
||||
logger.Info("routing pod starting", "addr", addr,
|
||||
"local", cfg.LocalModel, "claude", cfg.ClaudeModel,
|
||||
"floor", cfg.RouteLocalFloor, "ceil", cfg.RouteLocalCeil)
|
||||
if err := http.ListenAndServe(addr, mux); err != nil { //nolint:gosec
|
||||
logger.Error("server stopped", "err", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func envOr(key, def string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return def
|
||||
}
|
||||
123
cmd/routing/main_test.go
Normal file
123
cmd/routing/main_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package main_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRoutingPodEndToEnd boots the binary against fake LiteLLM + brain servers,
|
||||
// calls tools/list and one tools/call, and verifies the brain saw a session_log POST.
|
||||
func TestRoutingPodEndToEnd(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("end-to-end binary boot")
|
||||
}
|
||||
|
||||
var brainHits int
|
||||
llm := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{{"message": map[string]any{"role": "assistant", "content": "stub"}}},
|
||||
})
|
||||
}))
|
||||
defer llm.Close()
|
||||
|
||||
brain := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/pass-rate":
|
||||
brainHits++
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"pass_rate": 0.95})
|
||||
case "/mcp":
|
||||
brainHits++
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"jsonrpc": "2.0", "id": 1, "result": map[string]any{}})
|
||||
}
|
||||
}))
|
||||
defer brain.Close()
|
||||
|
||||
bin := buildRouting(t)
|
||||
cmd := exec.Command(bin)
|
||||
cmd.Env = append(cmd.Env,
|
||||
"ROUTING_PORT=33310",
|
||||
"LITELLM_BASE_URL="+llm.URL,
|
||||
"LITELLM_API_KEY=stub",
|
||||
"BRAIN_URL="+brain.URL,
|
||||
"SUPERVISOR_CONFIG_DIR=../../config/supervisor",
|
||||
"PATH="+osPath(),
|
||||
)
|
||||
require.NoError(t, cmd.Start())
|
||||
t.Cleanup(func() { _ = cmd.Process.Kill() })
|
||||
|
||||
require.NoError(t, waitForPort(t, "127.0.0.1:33310", 5*time.Second))
|
||||
|
||||
resp := mcpCall(t, "http://127.0.0.1:33310/mcp", `{"jsonrpc":"2.0","id":1,"method":"tools/list"}`)
|
||||
assert.Contains(t, resp, `"review"`)
|
||||
assert.Contains(t, resp, `"debug"`)
|
||||
assert.Contains(t, resp, `"retrospective"`)
|
||||
assert.Contains(t, resp, `"trainer"`)
|
||||
|
||||
resp = mcpCall(t, "http://127.0.0.1:33310/mcp", `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"review","arguments":{"project_root":"/tmp","files":["README.md"]}}}`)
|
||||
_ = resp // shape varies by skill; we only need a 200
|
||||
|
||||
// Wait briefly for the async session_log to land.
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) && brainHits < 2 {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
assert.GreaterOrEqual(t, brainHits, 2, "expected at least one /pass-rate hit and one /mcp session_log hit")
|
||||
}
|
||||
|
||||
func buildRouting(t *testing.T) string {
|
||||
t.Helper()
|
||||
bin := t.TempDir() + "/routing"
|
||||
out, err := exec.Command("go", "build", "-o", bin, "github.com/mathiasbq/supervisor/cmd/routing").CombinedOutput()
|
||||
require.NoError(t, err, "build failed: %s", out)
|
||||
return bin
|
||||
}
|
||||
|
||||
func waitForPort(_ *testing.T, addr string, dur time.Duration) error {
|
||||
deadline := time.Now().Add(dur)
|
||||
for time.Now().Before(deadline) {
|
||||
c, err := http.Get("http://" + addr + "/healthz") //nolint:noctx
|
||||
if err == nil {
|
||||
_ = c.Body.Close()
|
||||
return nil
|
||||
}
|
||||
conn, err := http.NewRequest(http.MethodPost, "http://"+addr+"/mcp", strings.NewReader(`{}`))
|
||||
if err == nil {
|
||||
r, err := http.DefaultClient.Do(conn)
|
||||
if err == nil {
|
||||
_ = r.Body.Close()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
return context.DeadlineExceeded
|
||||
}
|
||||
|
||||
func mcpCall(t *testing.T, url, body string) string {
|
||||
t.Helper()
|
||||
r, err := http.Post(url, "application/json", strings.NewReader(body)) //nolint:noctx
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = r.Body.Close() }()
|
||||
raw, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
return string(raw)
|
||||
}
|
||||
|
||||
func osPath() string {
|
||||
for _, e := range append([]string{}, exec.Command("env").Env...) {
|
||||
if strings.HasPrefix(e, "PATH=") {
|
||||
return strings.TrimPrefix(e, "PATH=")
|
||||
}
|
||||
}
|
||||
return "/usr/bin:/bin"
|
||||
}
|
||||
Reference in New Issue
Block a user