From ce4559273016943c9f52a156ac8b4f8a9cc10777 Mon Sep 17 00:00:00 2001 From: Mathias Bergqvist Date: Wed, 22 Apr 2026 16:19:09 +0200 Subject: [PATCH] refactor: replace orchestrator/verifier chain with direct LiteLLM calls MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Drop the three-layer Claude subprocess orchestration (local model → Claude verifier → cloud escalation). Skills now call LiteLLM directly and return plain text to Claude Code, which decides what to do with it. - Delete executor, orchestrator, verifier, result, attempts packages - Simplify LiteLLMExecutor: Run(Request)→Result becomes Complete(model,sys,user)→(string,int64,error) - Replace ExecutorFn with CompleteFunc in all 6 skill configs - Rewrite all skill handlers to call Complete and return {"text","model","duration_ms"} - Simplify config/models: remove Verifier/LlamaSwapURL, add ModelFor - Bump version to v0.5.0 Co-Authored-By: Claude Sonnet 4.6 --- cmd/supervisor/main.go | 68 ++---- config/models.yaml | 20 +- internal/config/models.go | 27 +-- internal/config/models_test.go | 39 +--- internal/exec/executor.go | 111 ---------- internal/exec/executor_test.go | 132 ------------ internal/exec/litellm.go | 50 ++--- internal/exec/litellm_test.go | 54 ++--- internal/exec/orchestrator.go | 197 ------------------ internal/exec/orchestrator_test.go | 151 -------------- internal/exec/result.go | 66 ------ internal/exec/result_test.go | 79 ------- internal/exec/verifier.go | 99 --------- internal/exec/verifier_test.go | 74 ------- internal/session/attempts.go | 26 --- internal/session/attempts_test.go | 37 ---- internal/skills/debug/handlers.go | 27 +-- internal/skills/debug/handlers_test.go | 22 +- internal/skills/debug/skill.go | 11 +- internal/skills/retrospective/handlers.go | 32 ++- .../skills/retrospective/handlers_test.go | 24 +-- internal/skills/retrospective/skill.go | 15 +- internal/skills/review/handlers.go | 28 +-- internal/skills/review/handlers_test.go | 22 +- internal/skills/review/skill.go | 11 +- internal/skills/spec/handlers.go | 28 +-- internal/skills/spec/handlers_test.go | 22 +- internal/skills/spec/skill.go | 11 +- internal/skills/tdd/handlers.go | 50 +++-- internal/skills/tdd/handlers_test.go | 54 +++-- internal/skills/tdd/skill.go | 14 +- internal/skills/trainer/handlers.go | 51 ++--- internal/skills/trainer/handlers_test.go | 35 ++-- internal/skills/trainer/skill.go | 11 +- 34 files changed, 266 insertions(+), 1432 deletions(-) delete mode 100644 internal/exec/executor.go delete mode 100644 internal/exec/executor_test.go delete mode 100644 internal/exec/orchestrator.go delete mode 100644 internal/exec/orchestrator_test.go delete mode 100644 internal/exec/result.go delete mode 100644 internal/exec/result_test.go delete mode 100644 internal/exec/verifier.go delete mode 100644 internal/exec/verifier_test.go delete mode 100644 internal/session/attempts.go delete mode 100644 internal/session/attempts_test.go diff --git a/cmd/supervisor/main.go b/cmd/supervisor/main.go index 27c8c06..c32d500 100644 --- a/cmd/supervisor/main.go +++ b/cmd/supervisor/main.go @@ -37,12 +37,6 @@ func main() { os.Exit(1) } - systemPrompt, err := os.ReadFile(cfg.ConfigDir + "/CLAUDE.md") - if err != nil { - logger.Error("read supervisor CLAUDE.md", "path", cfg.ConfigDir+"/CLAUDE.md", "err", err) - os.Exit(1) - } - protocolsPrompt, err := os.ReadFile(cfg.ConfigDir + "/protocols.md") if err != nil { logger.Error("read protocols.md", "path", cfg.ConfigDir+"/protocols.md", "err", err) @@ -95,40 +89,7 @@ func main() { os.Exit(1) } - claudeExec := iexec.New(iexec.Config{ - SystemPrompt: string(systemPrompt), - LiteLLMBaseURL: cfg.LiteLLMBaseURL, - LiteLLMAPIKey: cfg.LiteLLMAPIKey, - }) - litellmExec := iexec.NewLiteLLM(cfg.LiteLLMBaseURL, cfg.LiteLLMAPIKey, 0) - verifier := iexec.NewVerifier("", models.Verifier(), 0) - - buildOrch := func(skill string) func(ctx context.Context, req iexec.Request) (iexec.Result, error) { - return func(ctx context.Context, req iexec.Request) (iexec.Result, error) { - rawChain := models.ChainFor(skill, req.Model) - chain := make([]iexec.ChainEntry, len(rawChain)) - for i, m := range rawChain { - chain[i] = iexec.EntryFor(m) - } - attempts := make([]iexec.AttemptRecord, 0, len(chain)) - orch := iexec.NewOrchestrator(chain, litellmExec.Run, claudeExec.Run, verifier, models.LlamaSwapURL(), &attempts) - result, err := orch.Run(ctx, req) - result.Attempts = attempts // attach orchestration metadata before returning - // Log per-attempt verdicts so pass rates are visible in pod logs. - for i, a := range attempts { - logger.Info("chain attempt", - "skill", skill, - "attempt", i+1, - "model", a.Model, - "tier", a.Tier, - "verdict", a.Verdict, - "duration_ms", a.DurationMs, - "warm", a.WarmStart, - ) - } - return result, err - } - } + litellm := iexec.NewLiteLLM(cfg.LiteLLMBaseURL, cfg.LiteLLMAPIKey, 0) tierFn := func(ctx context.Context) tier.Info { return tier.Detect(ctx, "https://api.anthropic.com", cfg.LiteLLMBaseURL) @@ -136,10 +97,9 @@ func main() { reg := registry.New() reg.Register(tdd.New(tdd.Config{ - SystemPrompt: string(systemPrompt), SkillPrompt: prependProtocols(tddPrompt), - DefaultModel: models.ChainFor("tdd", "")[0], - ExecutorFn: buildOrch("tdd"), + DefaultModel: models.ModelFor("tdd", ""), + CompleteFunc: litellm.Complete, SessionsDir: cfg.SessionsDir, IngestBaseURL: cfg.IngestBaseURL, })) @@ -154,36 +114,36 @@ func main() { })) reg.Register(retrospective.New(retrospective.Config{ SkillPrompt: prependProtocols(retroPrompt), - DefaultModel: models.ChainFor("retrospective", "")[0], + DefaultModel: models.ModelFor("retrospective", ""), SessionsDir: cfg.SessionsDir, - ExecutorFn: buildOrch("retrospective"), + CompleteFunc: litellm.Complete, })) reg.Register(review.New(review.Config{ SkillPrompt: prependProtocols(reviewPrompt), - DefaultModel: models.ChainFor("review", "")[0], - ExecutorFn: buildOrch("review"), + DefaultModel: models.ModelFor("review", ""), + CompleteFunc: litellm.Complete, SessionsDir: cfg.SessionsDir, IngestBaseURL: cfg.IngestBaseURL, })) reg.Register(skilldebug.New(skilldebug.Config{ SkillPrompt: prependProtocols(debugPrompt), - DefaultModel: models.ChainFor("debug", "")[0], - ExecutorFn: buildOrch("debug"), + DefaultModel: models.ModelFor("debug", ""), + CompleteFunc: litellm.Complete, SessionsDir: cfg.SessionsDir, IngestBaseURL: cfg.IngestBaseURL, })) reg.Register(spec.New(spec.Config{ SkillPrompt: prependProtocols(specPrompt), - DefaultModel: models.ChainFor("spec", "")[0], - ExecutorFn: buildOrch("spec"), + DefaultModel: models.ModelFor("spec", ""), + CompleteFunc: litellm.Complete, SessionsDir: cfg.SessionsDir, IngestBaseURL: cfg.IngestBaseURL, })) reg.Register(trainer.New(trainer.Config{ ReaderPrompt: prependProtocols(trainerReaderPrompt), WriterPrompt: prependProtocols(trainerWriterPrompt), - DefaultModel: models.ChainFor("trainer", "")[0], - ExecutorFn: buildOrch("trainer"), + DefaultModel: models.ModelFor("trainer", ""), + CompleteFunc: litellm.Complete, SessionsDir: cfg.SessionsDir, BrainDir: cfg.BrainDir, })) @@ -193,7 +153,7 @@ func main() { mux.Handle("/mcp", srv) addr := ":" + cfg.Port - logger.Info("supervisor starting", "addr", addr, "version", "v0.4.0") + logger.Info("supervisor starting", "addr", addr, "version", "v0.5.0") if err := http.ListenAndServe(addr, mux); err != nil { logger.Error("server stopped", "err", err) os.Exit(1) diff --git a/config/models.yaml b/config/models.yaml index 9a1ea8a..80a3014 100644 --- a/config/models.yaml +++ b/config/models.yaml @@ -1,41 +1,25 @@ -# Model routing chains — three-layer priority: -# 1. model param in MCP tool call (caller override — collapses to single entry, no escalation) -# 2. per-skill chain here -# 3. default_chain fallback - -verifier: claude-sonnet-4-6 # fixed verifier for all local tiers - -llama_swap_url: http://koala:8080 # for warm-state probing +# Model selection — first entry per skill is used. +# Override per-call by passing model in the MCP tool args. default_chain: - ollama/qwen3-coder-30b-tuned - - claude-sonnet-4-6 skills: tdd: chain: - ollama/qwen3-coder-30b-tuned - - claude-sonnet-4-6 review: chain: - ollama/devstral-tuned - - ollama/gemma4 - - claude-sonnet-4-6 debug: chain: - ollama/deepseek-r1-tuned - - claude-sonnet-4-6 spec: chain: - ollama/phi4 - - ollama/gemma4 - - claude-sonnet-4-6 - - claude-opus-4-6 retrospective: chain: - ollama/qwen3-coder-30b-tuned - - claude-sonnet-4-6 trainer: chain: - ollama/qwen3-coder-30b-tuned - - claude-sonnet-4-6 diff --git a/internal/config/models.go b/internal/config/models.go index 8b3503b..fefed6c 100644 --- a/internal/config/models.go +++ b/internal/config/models.go @@ -12,8 +12,6 @@ type skillChain struct { } type modelsFile struct { - Verifier string `yaml:"verifier"` - LlamaSwapURL string `yaml:"llama_swap_url"` DefaultChain []string `yaml:"default_chain"` Skills map[string]skillChain `yaml:"skills"` } @@ -34,23 +32,18 @@ func LoadModels(path string) (Models, error) { return Models{data: f}, nil } -// Verifier returns the model name to use for all local-tier output verification. -func (m Models) Verifier() string { return m.data.Verifier } - -// LlamaSwapURL returns the llama-swap base URL for warm-state probing. -func (m Models) LlamaSwapURL() string { return m.data.LlamaSwapURL } - -// ChainFor returns the ordered list of model names for a skill. -// If override is non-empty, returns a single-entry chain (no escalation). -// Falls back to default_chain when the skill has no explicit entry. -func (m Models) ChainFor(skill, override string) []string { +// ModelFor returns the primary model to use for a skill. +// If override is non-empty, it is returned directly. +// Falls back to default_chain[0] when the skill has no explicit entry. +func (m Models) ModelFor(skill, override string) string { if override != "" { - return []string{override} + return override } if sc, ok := m.data.Skills[skill]; ok && len(sc.Chain) > 0 { - return sc.Chain + return sc.Chain[0] } - out := make([]string, len(m.data.DefaultChain)) - copy(out, m.data.DefaultChain) - return out + if len(m.data.DefaultChain) > 0 { + return m.data.DefaultChain[0] + } + return "" } diff --git a/internal/config/models_test.go b/internal/config/models_test.go index b37a525..face928 100644 --- a/internal/config/models_test.go +++ b/internal/config/models_test.go @@ -11,9 +11,6 @@ import ( ) const testYAML = ` -verifier: claude-sonnet-4-6 -llama_swap_url: http://koala:8080 - default_chain: - ollama/qwen3-coder-30b-tuned - claude-sonnet-4-6 @@ -37,44 +34,20 @@ func writeModels(t *testing.T, content string) string { return f } -func TestModelsVerifier(t *testing.T) { +func TestModelsModelForSkillWithEntry(t *testing.T) { m, err := config.LoadModels(writeModels(t, testYAML)) require.NoError(t, err) - assert.Equal(t, "claude-sonnet-4-6", m.Verifier()) + assert.Equal(t, "ollama/devstral-tuned", m.ModelFor("review", "")) } -func TestModelsLlamaSwapURL(t *testing.T) { +func TestModelsModelForDefaultFallback(t *testing.T) { m, err := config.LoadModels(writeModels(t, testYAML)) require.NoError(t, err) - assert.Equal(t, "http://koala:8080", m.LlamaSwapURL()) + assert.Equal(t, "ollama/qwen3-coder-30b-tuned", m.ModelFor("trainer", "")) } -func TestModelsChainForSkillOverride(t *testing.T) { +func TestModelsModelForCallerOverride(t *testing.T) { m, err := config.LoadModels(writeModels(t, testYAML)) require.NoError(t, err) - - chain := m.ChainFor("review", "") - require.Len(t, chain, 3) - assert.Equal(t, "ollama/devstral-tuned", chain[0]) - assert.Equal(t, "ollama/gemma4", chain[1]) - assert.Equal(t, "claude-sonnet-4-6", chain[2]) -} - -func TestModelsChainForDefaultFallback(t *testing.T) { - m, err := config.LoadModels(writeModels(t, testYAML)) - require.NoError(t, err) - - chain := m.ChainFor("trainer", "") // not in skills map - require.Len(t, chain, 2) - assert.Equal(t, "ollama/qwen3-coder-30b-tuned", chain[0]) - assert.Equal(t, "claude-sonnet-4-6", chain[1]) -} - -func TestModelsChainForCallerOverride(t *testing.T) { - m, err := config.LoadModels(writeModels(t, testYAML)) - require.NoError(t, err) - - chain := m.ChainFor("review", "claude-opus-4-6") - require.Len(t, chain, 1) - assert.Equal(t, "claude-opus-4-6", chain[0]) + assert.Equal(t, "claude-opus-4-6", m.ModelFor("review", "claude-opus-4-6")) } diff --git a/internal/exec/executor.go b/internal/exec/executor.go deleted file mode 100644 index cbc7e7b..0000000 --- a/internal/exec/executor.go +++ /dev/null @@ -1,111 +0,0 @@ -package exec - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "os" - "os/exec" - "strings" - "time" -) - -// Config holds executor configuration. -type Config struct { - ClaudeBinary string // path to claude binary, defaults to "claude" - SystemPrompt string // contents of supervisor CLAUDE.md - Timeout time.Duration // per-invocation timeout, default 120s - LiteLLMBaseURL string // passed to Claude so it can delegate to Ollama - LiteLLMAPIKey string // passed to Claude for LiteLLM auth -} - -// Request is the input to a single supervisor invocation. -type Request struct { - SkillPrompt string // skill-specific discipline (e.g. tdd.md contents) - TaskPrompt string // the specific task (phase, project_root, spec, model) - Model string // resolved model name, passed in task prompt - Tools string // comma-separated allowed tools, default "Bash,Read,Write" -} - -// Executor spawns a claude instance and captures its structured JSON output. -type Executor struct { - cfg Config -} - -func New(cfg Config) *Executor { - if cfg.ClaudeBinary == "" { - cfg.ClaudeBinary = "claude" - } - if cfg.Timeout == 0 { - cfg.Timeout = 120 * time.Second - } - return &Executor{cfg: cfg} -} - -func (e *Executor) Run(ctx context.Context, req Request) (Result, error) { - ctx, cancel := context.WithTimeout(ctx, e.cfg.Timeout) - defer cancel() - - tools := req.Tools - if tools == "" { - tools = "Bash,Read,Write" - } - - // Build the full prompt: system rules + skill rules + infra context + task. - // LITELLM_API_KEY is injected as a subprocess env var, not in the prompt, - // to prevent it appearing in error log output. - litellmCtx := fmt.Sprintf("LITELLM_BASE_URL: %s", e.cfg.LiteLLMBaseURL) - prompt := strings.Join([]string{ - e.cfg.SystemPrompt, - "---", - req.SkillPrompt, - "---", - litellmCtx, - "---", - req.TaskPrompt, - }, "\n\n") - - args := []string{ - "--print", - "--permission-mode", "bypassPermissions", - "--tools", tools, - "--json-schema", Schema, - "--output-format", "json", - } - if strings.HasPrefix(req.Model, "claude-") { - args = append(args, "--model", req.Model) - } - args = append(args, prompt) - - cmd := exec.CommandContext(ctx, e.cfg.ClaudeBinary, args...) - cmd.Env = append(os.Environ(), "LITELLM_API_KEY="+e.cfg.LiteLLMAPIKey) - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - if err := cmd.Run(); err != nil { - if ctx.Err() != nil { - return Result{}, fmt.Errorf("timeout after %s", e.cfg.Timeout) - } - return Result{}, fmt.Errorf("claude exited with error: %w — stderr: %s", err, stderr.String()) - } - - // --output-format json wraps the response in an envelope; structured output - // from --json-schema is in the "structured_output" field. - var envelope struct { - StructuredOutput *Result `json:"structured_output"` - IsError bool `json:"is_error"` - Result string `json:"result"` // fallback text result for error messages - } - if err := json.Unmarshal(stdout.Bytes(), &envelope); err != nil { - return Result{}, fmt.Errorf("parse envelope JSON: %w — raw: %s — stderr: %s", err, stdout.String(), stderr.String()) - } - if envelope.StructuredOutput == nil { - return Result{}, fmt.Errorf("no structured_output in response — result: %s — stderr: %s", envelope.Result, stderr.String()) - } - if err := envelope.StructuredOutput.Validate(); err != nil { - return Result{}, fmt.Errorf("invalid result: %w", err) - } - return *envelope.StructuredOutput, nil -} diff --git a/internal/exec/executor_test.go b/internal/exec/executor_test.go deleted file mode 100644 index 6610573..0000000 --- a/internal/exec/executor_test.go +++ /dev/null @@ -1,132 +0,0 @@ -package exec_test - -import ( - "context" - "os" - "path/filepath" - "testing" - "time" - - iexec "github.com/mathiasbq/supervisor/internal/exec" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// fakeClaudePath writes a shell script that prints fixed output and returns its path. -func fakeClaudePath(t *testing.T, output string, exitCode int) string { - t.Helper() - dir := t.TempDir() - script := filepath.Join(dir, "claude") - var content string - if exitCode != 0 { - content = "#!/bin/sh\necho 'error' >&2\nexit 1\n" - } else { - content = "#!/bin/sh\necho '" + output + "'\n" - } - require.NoError(t, os.WriteFile(script, []byte(content), 0755)) - return script -} - -func TestExecutorParsesValidResult(t *testing.T) { - // Fake claude emits the --output-format json envelope that the real CLI produces. - // The executor extracts the result from the "structured_output" field. - envelope := `{"type":"result","subtype":"success","is_error":false,"structured_output":{"status":"pass","phase":"red","skill":"tdd","file_path":"/tmp/x_test.go","runner_output":"FAIL","verified":true,"model_used":"self","message":"ok"}}` - claude := fakeClaudePath(t, envelope, 0) - - ex := iexec.New(iexec.Config{ - ClaudeBinary: claude, - SystemPrompt: "you are a supervisor", - Timeout: 5 * time.Second, - }) - - result, err := ex.Run(context.Background(), iexec.Request{ - SkillPrompt: "tdd rules", - TaskPrompt: "run red phase", - }) - require.NoError(t, err) - assert.Equal(t, "pass", result.Status) - assert.True(t, result.Verified) -} - -func TestExecutorReturnsErrorOnNonZeroExit(t *testing.T) { - claude := fakeClaudePath(t, "", 1) - - ex := iexec.New(iexec.Config{ - ClaudeBinary: claude, - SystemPrompt: "you are a supervisor", - Timeout: 5 * time.Second, - }) - - _, err := ex.Run(context.Background(), iexec.Request{TaskPrompt: "fail"}) - assert.Error(t, err) -} - -func TestExecutorTimesOut(t *testing.T) { - dir := t.TempDir() - script := filepath.Join(dir, "claude") - require.NoError(t, os.WriteFile(script, []byte("#!/bin/sh\nsleep 60\n"), 0755)) - - ex := iexec.New(iexec.Config{ - ClaudeBinary: script, - SystemPrompt: "you are a supervisor", - Timeout: 100 * time.Millisecond, - }) - - _, err := ex.Run(context.Background(), iexec.Request{TaskPrompt: "slow"}) - assert.ErrorContains(t, err, "timeout") -} - -func TestExecutorPassesModelFlagForCloudModel(t *testing.T) { - // The script captures its args to a temp file so we can assert --model was passed. - argsFile := filepath.Join(t.TempDir(), "args.txt") - envelope := `{"type":"result","subtype":"success","is_error":false,"structured_output":{"status":"pass","phase":"review","skill":"review","file_path":"","runner_output":"","verified":true,"model_used":"claude-sonnet-4-6","message":"ok"}}` - - dir := t.TempDir() - script := filepath.Join(dir, "claude") - content := "#!/bin/sh\necho \"$@\" > " + argsFile + "\necho '" + envelope + "'\n" - require.NoError(t, os.WriteFile(script, []byte(content), 0755)) - - ex := iexec.New(iexec.Config{ - ClaudeBinary: script, - SystemPrompt: "sys", - Timeout: 5 * time.Second, - }) - - _, err := ex.Run(context.Background(), iexec.Request{ - SkillPrompt: "review rules", - TaskPrompt: "do review", - Model: "claude-sonnet-4-6", - }) - require.NoError(t, err) - - argsData, err := os.ReadFile(argsFile) - require.NoError(t, err) - assert.Contains(t, string(argsData), "--model claude-sonnet-4-6") -} - -func TestExecutorSkipsModelFlagForLocalModel(t *testing.T) { - argsFile := filepath.Join(t.TempDir(), "args.txt") - envelope := `{"type":"result","subtype":"success","is_error":false,"structured_output":{"status":"pass","phase":"review","skill":"review","file_path":"","runner_output":"","verified":true,"model_used":"ollama/devstral","message":"ok"}}` - - dir := t.TempDir() - script := filepath.Join(dir, "claude") - content := "#!/bin/sh\necho \"$@\" > " + argsFile + "\necho '" + envelope + "'\n" - require.NoError(t, os.WriteFile(script, []byte(content), 0755)) - - ex := iexec.New(iexec.Config{ - ClaudeBinary: script, - SystemPrompt: "sys", - Timeout: 5 * time.Second, - }) - - _, err := ex.Run(context.Background(), iexec.Request{ - SkillPrompt: "review rules", - TaskPrompt: "do review", - Model: "ollama/devstral", - }) - require.NoError(t, err) - - argsData, err := os.ReadFile(argsFile) - require.NoError(t, err) - assert.NotContains(t, string(argsData), "--model") -} diff --git a/internal/exec/litellm.go b/internal/exec/litellm.go index 27d59d0..278f8eb 100644 --- a/internal/exec/litellm.go +++ b/internal/exec/litellm.go @@ -9,9 +9,8 @@ import ( "time" ) -// LiteLLMExecutor calls a LiteLLM-compatible /v1/chat/completions endpoint. -// Local models are expected to return a JSON object matching the Result schema -// as their response content — no envelope. +// LiteLLMExecutor calls a LiteLLM-compatible /v1/chat/completions endpoint +// and returns the raw assistant message text. type LiteLLMExecutor struct { baseURL string apiKey string @@ -21,9 +20,12 @@ type LiteLLMExecutor struct { // NewLiteLLM creates a LiteLLMExecutor. // timeout applies to the full HTTP round-trip per call. func NewLiteLLM(baseURL, apiKey string, timeout time.Duration) *LiteLLMExecutor { + if timeout == 0 { + timeout = 120 * time.Second + } return &LiteLLMExecutor{ - baseURL: baseURL, - apiKey: apiKey, + baseURL: baseURL, + apiKey: apiKey, httpClient: &http.Client{Timeout: timeout}, } } @@ -46,58 +48,50 @@ type litellmResponse struct { Choices []litellmChoice `json:"choices"` } -// Run dispatches req to the LiteLLM server and parses the Result from the -// assistant message content. Returns an error on network failure, non-200 -// status, or unparseable/invalid JSON — all of which the Orchestrator treats -// as automatic escalation triggers. -func (e *LiteLLMExecutor) Run(ctx context.Context, req Request) (Result, error) { +// Complete sends system+user messages to the given model and returns the raw +// assistant text along with the round-trip duration in milliseconds. +func (e *LiteLLMExecutor) Complete(ctx context.Context, model, system, user string) (string, int64, error) { body := litellmRequest{ - Model: req.Model, + Model: model, Messages: []litellmMessage{ - {Role: "system", Content: req.SkillPrompt}, - {Role: "user", Content: req.TaskPrompt}, + {Role: "system", Content: system}, + {Role: "user", Content: user}, }, } bodyBytes, err := json.Marshal(body) if err != nil { - return Result{}, fmt.Errorf("litellm: marshal request: %w", err) + return "", 0, fmt.Errorf("litellm: marshal request: %w", err) } httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, e.baseURL+"/v1/chat/completions", bytes.NewReader(bodyBytes)) if err != nil { - return Result{}, fmt.Errorf("litellm: create request: %w", err) + return "", 0, fmt.Errorf("litellm: create request: %w", err) } httpReq.Header.Set("Content-Type", "application/json") if e.apiKey != "" { httpReq.Header.Set("Authorization", "Bearer "+e.apiKey) } + t0 := time.Now() resp, err := e.httpClient.Do(httpReq) if err != nil { - return Result{}, fmt.Errorf("litellm: request failed: %w", err) + return "", 0, fmt.Errorf("litellm: request failed: %w", err) } defer resp.Body.Close() //nolint:errcheck + durationMs := time.Since(t0).Milliseconds() if resp.StatusCode != http.StatusOK { - return Result{}, fmt.Errorf("litellm: server returned status %d", resp.StatusCode) + return "", 0, fmt.Errorf("litellm: server returned status %d", resp.StatusCode) } var chatResp litellmResponse if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil { - return Result{}, fmt.Errorf("litellm: decode response: %w", err) + return "", 0, fmt.Errorf("litellm: decode response: %w", err) } if len(chatResp.Choices) == 0 { - return Result{}, fmt.Errorf("litellm: no choices in response") + return "", 0, fmt.Errorf("litellm: no choices in response") } - content := chatResp.Choices[0].Message.Content - var result Result - if err := json.Unmarshal([]byte(content), &result); err != nil { - return Result{}, fmt.Errorf("litellm: parse result JSON: %w — content: %s", err, content) - } - if err := result.Validate(); err != nil { - return Result{}, fmt.Errorf("litellm: invalid result: %w", err) - } - return result, nil + return chatResp.Choices[0].Message.Content, durationMs, nil } diff --git a/internal/exec/litellm_test.go b/internal/exec/litellm_test.go index dd117cd..71afc96 100644 --- a/internal/exec/litellm_test.go +++ b/internal/exec/litellm_test.go @@ -13,23 +13,11 @@ import ( "github.com/stretchr/testify/require" ) -func validLiteLLMResult() iexec.Result { - return iexec.Result{ - Status: "pass", - Phase: "review", - Skill: "review", - ModelUsed: "ollama/devstral", - Message: "looks good", - } -} - -func chatResponseFor(t *testing.T, result iexec.Result) []byte { +func chatResponse(t *testing.T, content string) []byte { t.Helper() - content, err := json.Marshal(result) - require.NoError(t, err) resp := map[string]any{ "choices": []map[string]any{ - {"message": map[string]any{"role": "assistant", "content": string(content)}}, + {"message": map[string]any{"role": "assistant", "content": content}}, }, } data, err := json.Marshal(resp) @@ -37,25 +25,21 @@ func chatResponseFor(t *testing.T, result iexec.Result) []byte { return data } -func TestLiteLLMParsesValidResult(t *testing.T) { +func TestLiteLLMReturnsText(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "/v1/chat/completions", r.URL.Path) assert.Equal(t, "application/json", r.Header.Get("Content-Type")) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - _, _ = w.Write(chatResponseFor(t, validLiteLLMResult())) + _, _ = w.Write(chatResponse(t, "here is my analysis")) })) defer srv.Close() ex := iexec.NewLiteLLM(srv.URL, "", 5*time.Second) - result, err := ex.Run(context.Background(), iexec.Request{ - SkillPrompt: "review rules", - TaskPrompt: "review the code", - Model: "ollama/devstral", - }) + text, dur, err := ex.Complete(context.Background(), "ollama/devstral", "system prompt", "user prompt") require.NoError(t, err) - assert.Equal(t, "pass", result.Status) - assert.Equal(t, "review", result.Skill) + assert.Equal(t, "here is my analysis", text) + assert.GreaterOrEqual(t, dur, int64(0)) } func TestLiteLLMSendsAuthHeader(t *testing.T) { @@ -63,12 +47,12 @@ func TestLiteLLMSendsAuthHeader(t *testing.T) { assert.Equal(t, "Bearer secret", r.Header.Get("Authorization")) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - _, _ = w.Write(chatResponseFor(t, validLiteLLMResult())) + _, _ = w.Write(chatResponse(t, "ok")) })) defer srv.Close() ex := iexec.NewLiteLLM(srv.URL, "secret", 5*time.Second) - _, err := ex.Run(context.Background(), iexec.Request{Model: "x", TaskPrompt: "t", SkillPrompt: "s"}) + _, _, err := ex.Complete(context.Background(), "model", "sys", "user") require.NoError(t, err) } @@ -79,34 +63,28 @@ func TestLiteLLMErrorOnNonOKStatus(t *testing.T) { defer srv.Close() ex := iexec.NewLiteLLM(srv.URL, "", 5*time.Second) - _, err := ex.Run(context.Background(), iexec.Request{Model: "x", TaskPrompt: "t"}) + _, _, err := ex.Complete(context.Background(), "model", "sys", "user") assert.ErrorContains(t, err, "503") } -func TestLiteLLMErrorOnUnparsableJSON(t *testing.T) { +func TestLiteLLMErrorOnEmptyChoices(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - resp := map[string]any{ - "choices": []map[string]any{ - {"message": map[string]any{"role": "assistant", "content": "not json at all"}}, - }, - } - data, _ := json.Marshal(resp) - _, _ = w.Write(data) + _, _ = w.Write([]byte(`{"choices":[]}`)) })) defer srv.Close() ex := iexec.NewLiteLLM(srv.URL, "", 5*time.Second) - _, err := ex.Run(context.Background(), iexec.Request{Model: "x", TaskPrompt: "t"}) - assert.Error(t, err) + _, _, err := ex.Complete(context.Background(), "model", "sys", "user") + assert.ErrorContains(t, err, "no choices") } func TestLiteLLMRespectsContextCancellation(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - cancel() // Cancel immediately + cancel() ex := iexec.NewLiteLLM("http://invalid.example.com", "", 1*time.Second) - _, err := ex.Run(ctx, iexec.Request{Model: "x", TaskPrompt: "t"}) + _, _, err := ex.Complete(ctx, "model", "sys", "user") assert.Error(t, err) } diff --git a/internal/exec/orchestrator.go b/internal/exec/orchestrator.go deleted file mode 100644 index bddda2b..0000000 --- a/internal/exec/orchestrator.go +++ /dev/null @@ -1,197 +0,0 @@ -package exec - -import ( - "context" - "fmt" - "io" - "net/http" - "strings" - "time" -) - -// ChainEntry is one tier in an escalation chain. -type ChainEntry struct { - Model string // e.g. "ollama/phi4", "claude-sonnet-4-6" - Tier string // "local" | "subagent" | "managed" - IsCloud bool // true for claude-* models; skips verifier call -} - -// EntryFor builds a ChainEntry from a model name string. -func EntryFor(model string) ChainEntry { - cloud := strings.HasPrefix(model, "claude-") - tier := "local" - if cloud { - tier = "subagent" - } - return ChainEntry{Model: model, Tier: tier, IsCloud: cloud} -} - -// AttemptRecord captures the outcome of one tier attempt for session logging. -type AttemptRecord struct { - Model string - Tier string - DurationMs int64 - WarmStart bool - Verdict string // "accept" | "escalate" | "error" - Feedback string -} - -// VerifierFn is the interface the orchestrator uses to verify local output. -type VerifierFn interface { - Verify(ctx context.Context, skillPrompt, taskPrompt string, output Result) (Verdict, error) -} - -// ExecutorRunFn is the signature of Executor.Run and LiteLLMExecutor.Run. -type ExecutorRunFn func(ctx context.Context, req Request) (Result, error) - -// Orchestrator walks an escalation chain, delegating generation and verification. -// It implements the ExecutorFn shape expected by skill handlers. -type Orchestrator struct { - chain []ChainEntry - localRun ExecutorRunFn // for local (non-cloud) tiers; may be nil - cloudRun ExecutorRunFn // for cloud tiers; may be nil - verifier VerifierFn - llamaSwapURL string - attempts *[]AttemptRecord -} - -// NewOrchestrator creates an Orchestrator. -// attempts is a pointer to a slice that will be appended to on each tier attempt. -// Pass nil for localRun or cloudRun if no tiers of that type exist in the chain. -func NewOrchestrator( - chain []ChainEntry, - localRun ExecutorRunFn, - cloudRun ExecutorRunFn, - verifier VerifierFn, - llamaSwapURL string, - attempts *[]AttemptRecord, -) *Orchestrator { - return &Orchestrator{ - chain: chain, - localRun: localRun, - cloudRun: cloudRun, - verifier: verifier, - llamaSwapURL: llamaSwapURL, - attempts: attempts, - } -} - -// Run walks the escalation chain and returns the first accepted result. -// Satisfies the ExecutorFn signature: func(context.Context, Request) (Result, error). -func (o *Orchestrator) Run(ctx context.Context, req Request) (Result, error) { - taskPrompt := req.TaskPrompt - - for _, entry := range o.chain { - warm := o.probeWarm(entry.Model) - start := time.Now() - - tierReq := req - tierReq.Model = entry.Model - tierReq.TaskPrompt = taskPrompt - - if entry.IsCloud { - result, genErr := o.cloudRun(ctx, tierReq) - dur := time.Since(start).Milliseconds() - verdict := "accept" - if genErr != nil { - verdict = "error" - } - o.appendAttempt(AttemptRecord{ - Model: entry.Model, - Tier: entry.Tier, - DurationMs: dur, - WarmStart: warm, - Verdict: verdict, - }) - if genErr == nil { - return result, nil - } - continue - } - - // Local tier. - result, genErr := o.localRun(ctx, tierReq) - dur := time.Since(start).Milliseconds() - - if genErr != nil { - o.appendAttempt(AttemptRecord{ - Model: entry.Model, - Tier: entry.Tier, - DurationMs: dur, - WarmStart: warm, - Verdict: "error", - Feedback: genErr.Error(), - }) - continue - } - - verdict, verErr := o.verifier.Verify(ctx, req.SkillPrompt, taskPrompt, result) - if verErr != nil { - // Treat verifier failure as escalate (safe default). - o.appendAttempt(AttemptRecord{ - Model: entry.Model, - Tier: entry.Tier, - DurationMs: dur, - WarmStart: warm, - Verdict: "escalate", - Feedback: "verifier error: " + verErr.Error(), - }) - continue - } - - if verdict.Accept { - o.appendAttempt(AttemptRecord{ - Model: entry.Model, - Tier: entry.Tier, - DurationMs: dur, - WarmStart: warm, - Verdict: "accept", - }) - return result, nil - } - - o.appendAttempt(AttemptRecord{ - Model: entry.Model, - Tier: entry.Tier, - DurationMs: dur, - WarmStart: warm, - Verdict: "escalate", - Feedback: verdict.Feedback, - }) - // Inject verifier feedback into the next tier's task prompt. - taskPrompt = taskPrompt + "\n\nPrior attempt feedback: " + verdict.Feedback - } - - return Result{}, fmt.Errorf("all tiers exhausted after %d attempt(s)", len(o.chain)) -} - -func (o *Orchestrator) appendAttempt(rec AttemptRecord) { - if o.attempts != nil { - *o.attempts = append(*o.attempts, rec) - } -} - -// probeWarm checks whether the model is currently loaded in llama-swap. -// Returns false on any error or if llamaSwapURL is empty. -func (o *Orchestrator) probeWarm(model string) bool { - if o.llamaSwapURL == "" { - return false - } - ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) - defer cancel() - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, o.llamaSwapURL+"/v1/models", nil) - if err != nil { - return false - } - resp, err := http.DefaultClient.Do(req) - if err != nil { - return false - } - defer resp.Body.Close() //nolint:errcheck - body, err := io.ReadAll(resp.Body) - if err != nil { - return false - } - return strings.Contains(string(body), model) -} diff --git a/internal/exec/orchestrator_test.go b/internal/exec/orchestrator_test.go deleted file mode 100644 index c0e4774..0000000 --- a/internal/exec/orchestrator_test.go +++ /dev/null @@ -1,151 +0,0 @@ -package exec_test - -import ( - "context" - "errors" - "testing" - - iexec "github.com/mathiasbq/supervisor/internal/exec" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// stubRunFn returns preset results sequentially. -type stubRunFn struct { - calls []stubCall - callIdx int -} - -type stubCall struct { - result iexec.Result - err error -} - -func (s *stubRunFn) Run(_ context.Context, _ iexec.Request) (iexec.Result, error) { - if s.callIdx >= len(s.calls) { - return iexec.Result{}, errors.New("unexpected call") - } - c := s.calls[s.callIdx] - s.callIdx++ - return c.result, c.err -} - -// stubVerifier returns preset verdicts sequentially. -type stubVerifier struct { - verdicts []iexec.Verdict - idx int -} - -func (s *stubVerifier) Verify(_ context.Context, _, _ string, _ iexec.Result) (iexec.Verdict, error) { - if s.idx >= len(s.verdicts) { - return iexec.Verdict{}, errors.New("unexpected verify call") - } - v := s.verdicts[s.idx] - s.idx++ - return v, nil -} - -func okResult(skill string) iexec.Result { - return iexec.Result{Status: "pass", Phase: "review", Skill: skill, Message: "ok", ModelUsed: "m"} -} - -func TestOrchestratorSingleLocalAccept(t *testing.T) { - local := &stubRunFn{calls: []stubCall{{result: okResult("review")}}} - verifier := &stubVerifier{verdicts: []iexec.Verdict{{Accept: true}}} - - var attempts []iexec.AttemptRecord - orch := iexec.NewOrchestrator( - []iexec.ChainEntry{{Model: "ollama/devstral", Tier: "local", IsCloud: false}}, - local.Run, nil, verifier, "", &attempts, - ) - - result, err := orch.Run(context.Background(), iexec.Request{TaskPrompt: "review"}) - require.NoError(t, err) - assert.Equal(t, "pass", result.Status) - require.Len(t, attempts, 1) - assert.Equal(t, "local", attempts[0].Tier) - assert.Equal(t, "accept", attempts[0].Verdict) -} - -func TestOrchestratorEscalatesOnVerifierReject(t *testing.T) { - local := &stubRunFn{calls: []stubCall{ - {result: iexec.Result{Status: "fail", Phase: "review", Skill: "review", Message: "weak"}}, - {result: okResult("review")}, - }} - verifier := &stubVerifier{verdicts: []iexec.Verdict{ - {Accept: false, Feedback: "missing line refs"}, - {Accept: true}, - }} - - var attempts []iexec.AttemptRecord - orch := iexec.NewOrchestrator( - []iexec.ChainEntry{ - {Model: "ollama/devstral", Tier: "local", IsCloud: false}, - {Model: "ollama/gemma4", Tier: "local", IsCloud: false}, - }, - local.Run, nil, verifier, "", &attempts, - ) - - result, err := orch.Run(context.Background(), iexec.Request{TaskPrompt: "review"}) - require.NoError(t, err) - assert.Equal(t, "pass", result.Status) - require.Len(t, attempts, 2) - assert.Equal(t, "escalate", attempts[0].Verdict) - assert.Equal(t, "missing line refs", attempts[0].Feedback) - assert.Equal(t, "accept", attempts[1].Verdict) -} - -func TestOrchestratorEscalatesOnLocalError(t *testing.T) { - local := &stubRunFn{calls: []stubCall{ - {err: errors.New("network failure")}, - {result: okResult("review")}, - }} - verifier := &stubVerifier{verdicts: []iexec.Verdict{{Accept: true}}} - - var attempts []iexec.AttemptRecord - orch := iexec.NewOrchestrator( - []iexec.ChainEntry{ - {Model: "ollama/devstral", Tier: "local", IsCloud: false}, - {Model: "ollama/gemma4", Tier: "local", IsCloud: false}, - }, - local.Run, nil, verifier, "", &attempts, - ) - - _, err := orch.Run(context.Background(), iexec.Request{TaskPrompt: "review"}) - require.NoError(t, err) - require.Len(t, attempts, 2) - assert.Equal(t, "error", attempts[0].Verdict) - assert.Equal(t, "accept", attempts[1].Verdict) -} - -func TestOrchestratorCloudTierSelfCertifies(t *testing.T) { - cloud := &stubRunFn{calls: []stubCall{{result: okResult("review")}}} - verifier := &stubVerifier{} // no verdicts — must not be called - - var attempts []iexec.AttemptRecord - orch := iexec.NewOrchestrator( - []iexec.ChainEntry{{Model: "claude-sonnet-4-6", Tier: "subagent", IsCloud: true}}, - nil, cloud.Run, verifier, "", &attempts, - ) - - result, err := orch.Run(context.Background(), iexec.Request{TaskPrompt: "review"}) - require.NoError(t, err) - assert.Equal(t, "pass", result.Status) - require.Len(t, attempts, 1) - assert.Equal(t, "subagent", attempts[0].Tier) - assert.Equal(t, "accept", attempts[0].Verdict) - assert.Equal(t, 0, verifier.idx) // verifier never called -} - -func TestOrchestratorAllTiersExhausted(t *testing.T) { - local := &stubRunFn{calls: []stubCall{{err: errors.New("unavailable")}}} - - var attempts []iexec.AttemptRecord - orch := iexec.NewOrchestrator( - []iexec.ChainEntry{{Model: "ollama/devstral", Tier: "local", IsCloud: false}}, - local.Run, nil, &stubVerifier{}, "", &attempts, - ) - - _, err := orch.Run(context.Background(), iexec.Request{TaskPrompt: "review"}) - assert.ErrorContains(t, err, "all tiers exhausted") -} diff --git a/internal/exec/result.go b/internal/exec/result.go deleted file mode 100644 index 025f9fb..0000000 --- a/internal/exec/result.go +++ /dev/null @@ -1,66 +0,0 @@ -package exec - -import ( - "errors" - "strings" -) - -// Result is the structured JSON output from every supervisor invocation. -// The JSON schema constant is passed to claude via --json-schema so Claude -// validates its own output before returning. -type Result struct { - Status string `json:"status"` // pass | fail | error - Phase string `json:"phase"` // red | green | refactor | retrospective | review | debug | spec | trainer - Skill string `json:"skill"` // tdd | review | ... - FilePath string `json:"file_path"` // absolute path to generated file - RunnerOutput string `json:"runner_output"` // raw stdout+stderr from test runner - Verified bool `json:"verified"` // based on exit code, never self-report - ModelUsed string `json:"model_used"` // model name or "self" - Message string `json:"message"` // one sentence summary - Attempts []AttemptRecord `json:"attempts,omitempty"` // populated by orchestrator, not Claude -} - -var validStatuses = map[string]bool{"pass": true, "fail": true, "error": true} -var validPhases = map[string]bool{ - "red": true, - "green": true, - "refactor": true, - "retrospective": true, - "review": true, - "debug": true, - "spec": true, - "trainer": true, -} - -func (r Result) Validate() error { - var errs []string - if !validStatuses[r.Status] { - errs = append(errs, "status must be pass|fail|error, got: "+r.Status) - } - if !validPhases[r.Phase] { - errs = append(errs, "phase must be one of red|green|refactor|retrospective|review|debug|spec|trainer, got: "+r.Phase) - } - if r.Skill == "" { - errs = append(errs, "skill is required") - } - if len(errs) > 0 { - return errors.New(strings.Join(errs, "; ")) - } - return nil -} - -// Schema is passed to claude --json-schema to enforce structured output. -const Schema = `{ - "type": "object", - "required": ["status","phase","skill","file_path","runner_output","verified","model_used","message"], - "properties": { - "status": {"type": "string", "enum": ["pass","fail","error"]}, - "phase": {"type": "string"}, - "skill": {"type": "string"}, - "file_path": {"type": "string"}, - "runner_output": {"type": "string"}, - "verified": {"type": "boolean"}, - "model_used": {"type": "string"}, - "message": {"type": "string"} - } -}` diff --git a/internal/exec/result_test.go b/internal/exec/result_test.go deleted file mode 100644 index 2d94fac..0000000 --- a/internal/exec/result_test.go +++ /dev/null @@ -1,79 +0,0 @@ -package exec_test - -import ( - "encoding/json" - "testing" - - "github.com/mathiasbq/supervisor/internal/exec" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestResultParsesValidJSON(t *testing.T) { - raw := `{ - "status": "pass", - "phase": "red", - "skill": "tdd", - "file_path": "/tmp/foo_test.go", - "runner_output": "--- FAIL: TestFoo", - "verified": true, - "model_used": "self", - "message": "test fails as expected" - }` - var r exec.Result - require.NoError(t, json.Unmarshal([]byte(raw), &r)) - assert.Equal(t, "pass", r.Status) - assert.Equal(t, "red", r.Phase) - assert.True(t, r.Verified) -} - -func TestResultValidation(t *testing.T) { - tests := []struct { - name string - result exec.Result - wantErr bool - }{ - { - name: "valid pass result", - result: exec.Result{ - Status: "pass", Phase: "red", Skill: "tdd", - FilePath: "/tmp/x_test.go", RunnerOutput: "FAIL", - Verified: true, ModelUsed: "self", Message: "ok", - }, - wantErr: false, - }, - { - name: "empty status", - result: exec.Result{Phase: "red", Skill: "tdd"}, - wantErr: true, - }, - { - name: "invalid status", - result: exec.Result{Status: "unknown", Phase: "red", Skill: "tdd"}, - wantErr: true, - }, - { - name: "invalid phase", - result: exec.Result{Status: "pass", Phase: "bad", Skill: "tdd"}, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := tt.result.Validate() - if tt.wantErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestValidateAcceptsAllPhases(t *testing.T) { - phases := []string{"red", "green", "refactor", "retrospective", "review", "debug", "spec", "trainer"} - for _, phase := range phases { - r := exec.Result{Status: "pass", Phase: phase, Skill: "test", ModelUsed: "self", Message: "ok"} - assert.NoError(t, r.Validate(), "phase %q should be valid", phase) - } -} diff --git a/internal/exec/verifier.go b/internal/exec/verifier.go deleted file mode 100644 index c915f80..0000000 --- a/internal/exec/verifier.go +++ /dev/null @@ -1,99 +0,0 @@ -package exec - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "os" - "os/exec" - "time" -) - -// Verdict is the output of a Claude verification call. -type Verdict struct { - Accept bool `json:"accept"` - Feedback string `json:"feedback"` // empty when Accept is true -} - -// Verifier runs a focused Claude call to judge local model output. -type Verifier struct { - claudeBinary string - model string - timeout time.Duration -} - -// NewVerifier creates a Verifier that calls claude with the given binary path and model. -// Empty claudeBinary defaults to "claude". Zero timeout defaults to 30s. -func NewVerifier(claudeBinary, model string, timeout time.Duration) *Verifier { - if claudeBinary == "" { - claudeBinary = "claude" - } - if timeout == 0 { - timeout = 30 * time.Second - } - return &Verifier{ - claudeBinary: claudeBinary, - model: model, - timeout: timeout, - } -} - -// Verify asks Claude whether output satisfies the skill discipline's iron laws. -// Returns Verdict{Accept: true} to accept or Verdict{Accept: false, Feedback: "..."} -// to escalate. Returns an error on subprocess failure or unparseable response. -func (v *Verifier) Verify(ctx context.Context, skillPrompt, taskPrompt string, output Result) (Verdict, error) { - ctx, cancel := context.WithTimeout(ctx, v.timeout) - defer cancel() - - outputJSON, err := json.Marshal(output) - if err != nil { - return Verdict{}, fmt.Errorf("verifier: marshal output: %w", err) - } - - prompt := fmt.Sprintf(`You are a quality verifier for an AI supervisor system. - -Given the skill discipline, the original task, and the generated output, decide whether the output satisfies the discipline's iron laws and output contract. - -Reply with JSON only — no other text: -{"accept": true, "feedback": ""} -or -{"accept": false, "feedback": ""} - -## Skill discipline -%s - -## Original task -%s - -## Generated output -%s`, skillPrompt, taskPrompt, string(outputJSON)) - - args := []string{ - "--print", - "--permission-mode", "bypassPermissions", - } - if v.model != "" { - args = append(args, "--model", v.model) - } - args = append(args, prompt) - - cmd := exec.CommandContext(ctx, v.claudeBinary, args...) - cmd.Env = os.Environ() - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - if err := cmd.Run(); err != nil { - if ctx.Err() != nil { - return Verdict{}, fmt.Errorf("verifier: timeout after %s", v.timeout) - } - return Verdict{}, fmt.Errorf("verifier: claude exited with error: %w — stderr: %s", err, stderr.String()) - } - - var verdict Verdict - if err := json.Unmarshal(bytes.TrimSpace(stdout.Bytes()), &verdict); err != nil { - return Verdict{}, fmt.Errorf("verifier: parse verdict JSON: %w — raw: %s", err, stdout.String()) - } - return verdict, nil -} diff --git a/internal/exec/verifier_test.go b/internal/exec/verifier_test.go deleted file mode 100644 index 83c9cb8..0000000 --- a/internal/exec/verifier_test.go +++ /dev/null @@ -1,74 +0,0 @@ -package exec_test - -import ( - "context" - "encoding/json" - "fmt" - "os" - "path/filepath" - "testing" - "time" - - iexec "github.com/mathiasbq/supervisor/internal/exec" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func fakeVerifierClaude(t *testing.T, verdict iexec.Verdict) string { - t.Helper() - data, err := json.Marshal(verdict) - require.NoError(t, err) - dir := t.TempDir() - script := filepath.Join(dir, "claude") - content := fmt.Sprintf("#!/bin/sh\necho '%s'\n", string(data)) - require.NoError(t, os.WriteFile(script, []byte(content), 0755)) - return script -} - -func TestVerifierAccepts(t *testing.T) { - claude := fakeVerifierClaude(t, iexec.Verdict{Accept: true, Feedback: ""}) - v := iexec.NewVerifier(claude, "claude-sonnet-4-6", 5*time.Second) - - verdict, err := v.Verify(context.Background(), "skill rules", "do the task", iexec.Result{ - Status: "pass", Phase: "review", Skill: "review", Message: "ok", - }) - require.NoError(t, err) - assert.True(t, verdict.Accept) - assert.Empty(t, verdict.Feedback) -} - -func TestVerifierEscalates(t *testing.T) { - claude := fakeVerifierClaude(t, iexec.Verdict{Accept: false, Feedback: "missing line references"}) - v := iexec.NewVerifier(claude, "claude-sonnet-4-6", 5*time.Second) - - verdict, err := v.Verify(context.Background(), "skill rules", "do the task", iexec.Result{ - Status: "pass", Phase: "review", Skill: "review", Message: "incomplete", - }) - require.NoError(t, err) - assert.False(t, verdict.Accept) - assert.Equal(t, "missing line references", verdict.Feedback) -} - -func TestVerifierErrorOnUnparsableOutput(t *testing.T) { - dir := t.TempDir() - script := filepath.Join(dir, "claude") - require.NoError(t, os.WriteFile(script, []byte("#!/bin/sh\necho 'not json'\n"), 0755)) - - v := iexec.NewVerifier(script, "claude-sonnet-4-6", 5*time.Second) - _, err := v.Verify(context.Background(), "rules", "task", iexec.Result{ - Status: "pass", Phase: "review", Skill: "review", Message: "ok", - }) - assert.Error(t, err) -} - -func TestVerifierErrorOnNonZeroExit(t *testing.T) { - dir := t.TempDir() - script := filepath.Join(dir, "claude") - require.NoError(t, os.WriteFile(script, []byte("#!/bin/sh\nexit 1\n"), 0755)) - - v := iexec.NewVerifier(script, "claude-sonnet-4-6", 5*time.Second) - _, err := v.Verify(context.Background(), "rules", "task", iexec.Result{ - Status: "pass", Phase: "review", Skill: "review", Message: "ok", - }) - assert.Error(t, err) -} diff --git a/internal/session/attempts.go b/internal/session/attempts.go deleted file mode 100644 index 7a79cf5..0000000 --- a/internal/session/attempts.go +++ /dev/null @@ -1,26 +0,0 @@ -// internal/session/attempts.go -package session - -import iexec "github.com/mathiasbq/supervisor/internal/exec" - -// AttemptsFrom converts exec.AttemptRecord slice to session.Attempt slice -// for writing into a session JSONL entry. -func AttemptsFrom(records []iexec.AttemptRecord) []Attempt { - if len(records) == 0 { - return nil - } - out := make([]Attempt, len(records)) - for i, r := range records { - out[i] = Attempt{ - Attempt: i + 1, - Model: r.Model, - Tier: r.Tier, - DurationMs: r.DurationMs, - WarmStart: r.WarmStart, - Verdict: r.Verdict, - Feedback: r.Feedback, - Verified: r.Verdict == "accept", - } - } - return out -} diff --git a/internal/session/attempts_test.go b/internal/session/attempts_test.go deleted file mode 100644 index 851a45e..0000000 --- a/internal/session/attempts_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package session_test - -import ( - "testing" - - "github.com/mathiasbq/supervisor/internal/exec" - "github.com/mathiasbq/supervisor/internal/session" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestAttemptsFromEmpty(t *testing.T) { - result := session.AttemptsFrom(nil) - assert.Empty(t, result) -} - -func TestAttemptsFromSetsIndex(t *testing.T) { - records := []exec.AttemptRecord{ - {Model: "ollama/phi4", Tier: "local", DurationMs: 1200, WarmStart: true, Verdict: "escalate", Feedback: "too vague"}, - {Model: "claude-sonnet-4-6", Tier: "subagent", DurationMs: 3400, WarmStart: false, Verdict: "accept"}, - } - result := session.AttemptsFrom(records) - require.Len(t, result, 2) - - assert.Equal(t, 1, result[0].Attempt) - assert.Equal(t, "ollama/phi4", result[0].Model) - assert.Equal(t, "local", result[0].Tier) - assert.Equal(t, int64(1200), result[0].DurationMs) - assert.True(t, result[0].WarmStart) - assert.Equal(t, "escalate", result[0].Verdict) - assert.Equal(t, "too vague", result[0].Feedback) - assert.False(t, result[0].Verified) - - assert.Equal(t, 2, result[1].Attempt) - assert.Equal(t, "claude-sonnet-4-6", result[1].Model) - assert.True(t, result[1].Verified) -} diff --git a/internal/skills/debug/handlers.go b/internal/skills/debug/handlers.go index 344ae20..6ebf89c 100644 --- a/internal/skills/debug/handlers.go +++ b/internal/skills/debug/handlers.go @@ -8,7 +8,6 @@ import ( "time" "github.com/mathiasbq/supervisor/internal/brain" - iexec "github.com/mathiasbq/supervisor/internal/exec" "github.com/mathiasbq/supervisor/internal/session" ) @@ -52,38 +51,32 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) ( task = brainCtx + "\n---\n\n" + task } - if s.cfg.ExecutorFn == nil { + if s.cfg.CompleteFunc == nil { return nil, fmt.Errorf("no executor configured") } t0 := time.Now() - result, err := s.cfg.ExecutorFn(ctx, iexec.Request{ - SkillPrompt: s.cfg.SkillPrompt, - TaskPrompt: task, - Model: model, - Tools: "Read,Bash", - }) + text, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.SkillPrompt, task) if err != nil { return nil, err } if a.SessionID != "" && s.cfg.SessionsDir != "" { + msg := text + if len(msg) > 200 { + msg = msg[:200] + } _ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{ SessionID: a.SessionID, Timestamp: time.Now(), Skill: "debug", Phase: "debug", ProjectRoot: a.ProjectRoot, - Attempts: session.AttemptsFrom(result.Attempts), - FinalStatus: result.Status, - ModelUsed: result.ModelUsed, + FinalStatus: "ok", + ModelUsed: model, DurationMs: time.Since(t0).Milliseconds(), - Message: result.Message, + Message: msg, }) } - b, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("marshal result: %w", err) - } - return b, nil + return json.Marshal(map[string]any{"text": text, "model": model, "duration_ms": dur}) } diff --git a/internal/skills/debug/handlers_test.go b/internal/skills/debug/handlers_test.go index ddf0c4b..f7c4ebb 100644 --- a/internal/skills/debug/handlers_test.go +++ b/internal/skills/debug/handlers_test.go @@ -6,7 +6,6 @@ import ( "encoding/json" "testing" - iexec "github.com/mathiasbq/supervisor/internal/exec" "github.com/mathiasbq/supervisor/internal/skills/debug" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -33,29 +32,22 @@ func TestDebugRequiresError(t *testing.T) { assert.ErrorContains(t, err, "error") } -func TestDebugCallsExecutor(t *testing.T) { - called := false +func TestDebugCallsCompleteFunc(t *testing.T) { var capturedTask string - fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) { - called = true - capturedTask = req.TaskPrompt - return iexec.Result{ - Status: "pass", Phase: "debug", Skill: "debug", - RunnerOutput: "HYPOTHESIS 1 (likelihood: high): nil map access\nVERIFY: go test ./... → expected: panic line reference", - Verified: false, ModelUsed: "self", Message: "3 hypotheses for: panic nil pointer at foo.go:42", - }, nil + fakeFn := func(_ context.Context, _, _, user string) (string, int64, error) { + capturedTask = user + return "HYPOTHESIS 1 (high): nil map access. Verify: go test ./...", 90, nil } - sk := debug.New(debug.Config{SkillPrompt: "debug rules", ExecutorFn: fakeFn, SessionsDir: t.TempDir()}) + sk := debug.New(debug.Config{SkillPrompt: "debug rules", CompleteFunc: fakeFn, SessionsDir: t.TempDir()}) out, err := sk.Handle(context.Background(), "debug", json.RawMessage( `{"project_root":"/tmp/proj","error":"panic: nil pointer dereference at foo.go:42","context":"occurs on startup"}`, )) require.NoError(t, err) - assert.True(t, called) assert.Contains(t, capturedTask, "panic: nil pointer dereference") assert.Contains(t, capturedTask, "occurs on startup") - var result iexec.Result + var result map[string]any require.NoError(t, json.Unmarshal(out, &result)) - assert.Equal(t, "debug", result.Phase) + assert.Contains(t, result["text"], "nil map access") } diff --git a/internal/skills/debug/skill.go b/internal/skills/debug/skill.go index 3f7df03..8a97ccf 100644 --- a/internal/skills/debug/skill.go +++ b/internal/skills/debug/skill.go @@ -5,20 +5,19 @@ import ( "context" "encoding/json" - iexec "github.com/mathiasbq/supervisor/internal/exec" "github.com/mathiasbq/supervisor/internal/registry" ) -// ExecutorFn is the function signature for running a worker subprocess. -type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error) +// CompleteFunc is the function used to call a local model. +type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error) // Config holds dependencies for the debug skill. type Config struct { SkillPrompt string DefaultModel string - ExecutorFn ExecutorFn + CompleteFunc CompleteFunc SessionsDir string - IngestBaseURL string // optional: base URL of ingestion server for brain context + IngestBaseURL string } // Skill implements the debug MCP tool. @@ -40,7 +39,7 @@ func (s *Skill) Tools() []registry.ToolDef { return []registry.ToolDef{ { Name: "debug", - Description: "Analyse an error and return 3-5 hypotheses ordered by likelihood, each with a concrete verification step.", + Description: "Consult a local model to analyse an error and return hypotheses ordered by likelihood, each with a concrete verification step.", InputSchema: schema( []string{"project_root", "error"}, map[string]any{ diff --git a/internal/skills/retrospective/handlers.go b/internal/skills/retrospective/handlers.go index bf84e2a..913c8fb 100644 --- a/internal/skills/retrospective/handlers.go +++ b/internal/skills/retrospective/handlers.go @@ -7,7 +7,6 @@ import ( "fmt" "time" - iexec "github.com/mathiasbq/supervisor/internal/exec" "github.com/mathiasbq/supervisor/internal/session" ) @@ -34,7 +33,6 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) ( model = s.cfg.DefaultModel } - // Read session log entries (empty slice if no log exists yet). entries, err := session.Read(s.cfg.SessionsDir, a.SessionID) if err != nil { return nil, fmt.Errorf("read session log: %w", err) @@ -46,39 +44,33 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) ( } taskPrompt := fmt.Sprintf( - "SESSION_ID: %s\n\nSESSION_LOG:\n%s\n\nReview this session log. Identify what is novel or worth preserving as organizational knowledge. Write structured entries to brain/raw/ via brain_write. Return JSON result when done.", + "SESSION_ID: %s\n\nSESSION_LOG:\n%s\n\nReview this session log. Identify what is novel or worth preserving as organizational knowledge. Provide structured insights.", a.SessionID, string(logJSON), ) - if s.cfg.ExecutorFn == nil { + if s.cfg.CompleteFunc == nil { return nil, fmt.Errorf("no executor configured") } t0 := time.Now() - result, err := s.cfg.ExecutorFn(ctx, iexec.Request{ - SkillPrompt: s.cfg.SkillPrompt, - TaskPrompt: taskPrompt, - Model: model, - Tools: "Bash,Read,Write", - }) + text, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.SkillPrompt, taskPrompt) if err != nil { - return nil, fmt.Errorf("retrospective worker: %w", err) + return nil, fmt.Errorf("retrospective model: %w", err) } + msg := text + if len(msg) > 200 { + msg = msg[:200] + } _ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{ SessionID: a.SessionID, Timestamp: time.Now(), Skill: "retrospective", Phase: "retrospective", - Attempts: session.AttemptsFrom(result.Attempts), - FinalStatus: result.Status, - ModelUsed: result.ModelUsed, + FinalStatus: "ok", + ModelUsed: model, DurationMs: time.Since(t0).Milliseconds(), - Message: result.Message, + Message: msg, }) - b, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("marshal result: %w", err) - } - return b, nil + return json.Marshal(map[string]any{"text": text, "model": model, "duration_ms": dur}) } diff --git a/internal/skills/retrospective/handlers_test.go b/internal/skills/retrospective/handlers_test.go index f7bca54..4842ba2 100644 --- a/internal/skills/retrospective/handlers_test.go +++ b/internal/skills/retrospective/handlers_test.go @@ -6,7 +6,6 @@ import ( "encoding/json" "testing" - iexec "github.com/mathiasbq/supervisor/internal/exec" "github.com/mathiasbq/supervisor/internal/skills/retrospective" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -20,20 +19,14 @@ func TestHandle_Retrospective_RequiresSessionID(t *testing.T) { } func TestHandle_Retrospective_BuildsPromptWithSessionLog(t *testing.T) { - var capturedReq iexec.Request + var capturedTask string s := retrospective.New(retrospective.Config{ SkillPrompt: "retrospective discipline", DefaultModel: "ollama/test", - SessionsDir: t.TempDir(), // empty dir, no session file — that's OK, session.Read returns nil - ExecutorFn: func(_ context.Context, req iexec.Request) (iexec.Result, error) { - capturedReq = req - return iexec.Result{ - Status: "pass", - Phase: "retrospective", - Skill: "retrospective", - Verified: true, - Message: "wrote 2 entries to brain", - }, nil + SessionsDir: t.TempDir(), + CompleteFunc: func(_ context.Context, _, _, user string) (string, int64, error) { + capturedTask = user + return "Key insight: the team resolved a tricky nil pointer issue via careful logging.", 75, nil }, }) @@ -41,9 +34,8 @@ func TestHandle_Retrospective_BuildsPromptWithSessionLog(t *testing.T) { out, err := s.Handle(context.Background(), "retrospective", args) require.NoError(t, err) - var result iexec.Result + var result map[string]any require.NoError(t, json.Unmarshal(out, &result)) - assert.Equal(t, "pass", result.Status) - assert.Contains(t, capturedReq.SkillPrompt, "retrospective discipline") - assert.Contains(t, capturedReq.TaskPrompt, "empty-session") + assert.Contains(t, result["text"], "nil pointer") + assert.Contains(t, capturedTask, "empty-session") } diff --git a/internal/skills/retrospective/skill.go b/internal/skills/retrospective/skill.go index 6c8b582..5106742 100644 --- a/internal/skills/retrospective/skill.go +++ b/internal/skills/retrospective/skill.go @@ -5,19 +5,18 @@ import ( "context" "encoding/json" - iexec "github.com/mathiasbq/supervisor/internal/exec" "github.com/mathiasbq/supervisor/internal/registry" ) -// ExecutorFn allows injecting a test double for the subprocess executor. -type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error) +// CompleteFunc is the function used to call a local model. +type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error) // Config holds retrospective skill configuration. type Config struct { - SkillPrompt string // content of retrospective.md - DefaultModel string // model to use when not specified in args - SessionsDir string // path to brain/sessions/ - ExecutorFn ExecutorFn // injected executor + SkillPrompt string + DefaultModel string + SessionsDir string + CompleteFunc CompleteFunc } // Skill implements registry.Skill for the retrospective tool. @@ -36,7 +35,7 @@ func (s *Skill) Tools() []registry.ToolDef { return []registry.ToolDef{ { Name: "retrospective", - Description: "Run a retrospective on a completed session. Reads the session log, identifies novel learnings, and writes structured entries to the brain for ingestion. Call at the end of each coding session.", + Description: "Consult a local model to analyse a completed session and identify what is novel or worth preserving as organizational knowledge.", InputSchema: json.RawMessage(`{ "type": "object", "required": ["session_id"], diff --git a/internal/skills/review/handlers.go b/internal/skills/review/handlers.go index ecabbc7..2e0d701 100644 --- a/internal/skills/review/handlers.go +++ b/internal/skills/review/handlers.go @@ -9,7 +9,6 @@ import ( "time" "github.com/mathiasbq/supervisor/internal/brain" - iexec "github.com/mathiasbq/supervisor/internal/exec" "github.com/mathiasbq/supervisor/internal/session" ) @@ -53,39 +52,32 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) ( task = brainCtx + "\n---\n\n" + task } - if s.cfg.ExecutorFn == nil { + if s.cfg.CompleteFunc == nil { return nil, fmt.Errorf("no executor configured") } t0 := time.Now() - result, err := s.cfg.ExecutorFn(ctx, iexec.Request{ - SkillPrompt: s.cfg.SkillPrompt, - TaskPrompt: task, - Model: model, - Tools: "Read,Bash", - }) + text, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.SkillPrompt, task) if err != nil { return nil, err } if a.SessionID != "" && s.cfg.SessionsDir != "" { + msg := text + if len(msg) > 200 { + msg = msg[:200] + } _ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{ SessionID: a.SessionID, Timestamp: time.Now(), Skill: "review", Phase: "review", ProjectRoot: a.ProjectRoot, - Attempts: session.AttemptsFrom(result.Attempts), - FinalStatus: result.Status, - FilePath: result.FilePath, - ModelUsed: result.ModelUsed, + FinalStatus: "ok", + ModelUsed: model, DurationMs: time.Since(t0).Milliseconds(), - Message: result.Message, + Message: msg, }) } - b, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("marshal result: %w", err) - } - return b, nil + return json.Marshal(map[string]any{"text": text, "model": model, "duration_ms": dur}) } diff --git a/internal/skills/review/handlers_test.go b/internal/skills/review/handlers_test.go index 2d32397..67ffeb7 100644 --- a/internal/skills/review/handlers_test.go +++ b/internal/skills/review/handlers_test.go @@ -6,7 +6,6 @@ import ( "encoding/json" "testing" - iexec "github.com/mathiasbq/supervisor/internal/exec" "github.com/mathiasbq/supervisor/internal/skills/review" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -33,29 +32,22 @@ func TestReviewRequiresFiles(t *testing.T) { assert.ErrorContains(t, err, "files") } -func TestReviewCallsExecutor(t *testing.T) { - called := false +func TestReviewCallsCompleteFunc(t *testing.T) { var capturedTask string - fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) { - called = true - capturedTask = req.TaskPrompt - return iexec.Result{ - Status: "pass", Phase: "review", Skill: "review", - Verified: true, ModelUsed: "self", Message: "2 warnings found", - }, nil + fakeFn := func(_ context.Context, _, _, user string) (string, int64, error) { + capturedTask = user + return "2 warnings found: missing error handling at line 42", 80, nil } - sk := review.New(review.Config{SkillPrompt: "review rules", ExecutorFn: fakeFn, SessionsDir: t.TempDir()}) + sk := review.New(review.Config{SkillPrompt: "review rules", CompleteFunc: fakeFn, SessionsDir: t.TempDir()}) out, err := sk.Handle(context.Background(), "review", json.RawMessage( `{"project_root":"/tmp/proj","files":["internal/foo/foo.go"],"context":"PR: add Foo helper"}`, )) require.NoError(t, err) - assert.True(t, called) assert.Contains(t, capturedTask, "internal/foo/foo.go") assert.Contains(t, capturedTask, "PR: add Foo helper") - var result iexec.Result + var result map[string]any require.NoError(t, json.Unmarshal(out, &result)) - assert.Equal(t, "pass", result.Status) - assert.Equal(t, "review", result.Phase) + assert.Contains(t, result["text"], "2 warnings found") } diff --git a/internal/skills/review/skill.go b/internal/skills/review/skill.go index 8c309f0..361c666 100644 --- a/internal/skills/review/skill.go +++ b/internal/skills/review/skill.go @@ -5,20 +5,19 @@ import ( "context" "encoding/json" - iexec "github.com/mathiasbq/supervisor/internal/exec" "github.com/mathiasbq/supervisor/internal/registry" ) -// ExecutorFn is the function signature for running a worker subprocess. -type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error) +// CompleteFunc is the function used to call a local model. +type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error) // Config holds dependencies for the review skill. type Config struct { SkillPrompt string DefaultModel string - ExecutorFn ExecutorFn + CompleteFunc CompleteFunc SessionsDir string - IngestBaseURL string // optional: base URL of ingestion server for brain context + IngestBaseURL string } // Skill implements the review MCP tool. @@ -40,7 +39,7 @@ func (s *Skill) Tools() []registry.ToolDef { return []registry.ToolDef{ { Name: "review", - Description: "Perform a structured code review of the specified files. Returns findings with severity levels.", + Description: "Consult a local model for a structured code review of the specified files. Returns findings with severity levels.", InputSchema: schema( []string{"project_root", "files"}, map[string]any{ diff --git a/internal/skills/spec/handlers.go b/internal/skills/spec/handlers.go index 0514917..471afd8 100644 --- a/internal/skills/spec/handlers.go +++ b/internal/skills/spec/handlers.go @@ -8,7 +8,6 @@ import ( "time" "github.com/mathiasbq/supervisor/internal/brain" - iexec "github.com/mathiasbq/supervisor/internal/exec" "github.com/mathiasbq/supervisor/internal/session" ) @@ -57,39 +56,32 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) ( task = brainCtx + "\n---\n\n" + task } - if s.cfg.ExecutorFn == nil { + if s.cfg.CompleteFunc == nil { return nil, fmt.Errorf("no executor configured") } t0 := time.Now() - result, err := s.cfg.ExecutorFn(ctx, iexec.Request{ - SkillPrompt: s.cfg.SkillPrompt, - TaskPrompt: task, - Model: model, - Tools: "Read,Write", - }) + text, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.SkillPrompt, task) if err != nil { return nil, err } if a.SessionID != "" && s.cfg.SessionsDir != "" { + msg := text + if len(msg) > 200 { + msg = msg[:200] + } _ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{ SessionID: a.SessionID, Timestamp: time.Now(), Skill: "spec", Phase: "spec", ProjectRoot: a.ProjectRoot, - Attempts: session.AttemptsFrom(result.Attempts), - FinalStatus: result.Status, - FilePath: result.FilePath, - ModelUsed: result.ModelUsed, + FinalStatus: "ok", + ModelUsed: model, DurationMs: time.Since(t0).Milliseconds(), - Message: result.Message, + Message: msg, }) } - b, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("marshal result: %w", err) - } - return b, nil + return json.Marshal(map[string]any{"text": text, "model": model, "duration_ms": dur}) } diff --git a/internal/skills/spec/handlers_test.go b/internal/skills/spec/handlers_test.go index 6ccf6c4..3e864d2 100644 --- a/internal/skills/spec/handlers_test.go +++ b/internal/skills/spec/handlers_test.go @@ -6,7 +6,6 @@ import ( "encoding/json" "testing" - iexec "github.com/mathiasbq/supervisor/internal/exec" "github.com/mathiasbq/supervisor/internal/skills/spec" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -33,29 +32,22 @@ func TestSpecRequiresRequirements(t *testing.T) { assert.ErrorContains(t, err, "requirements") } -func TestSpecCallsExecutor(t *testing.T) { - called := false +func TestSpecCallsCompleteFunc(t *testing.T) { var capturedTask string - fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) { - called = true - capturedTask = req.TaskPrompt - return iexec.Result{ - Status: "pass", Phase: "spec", Skill: "spec", - FilePath: "/tmp/proj/docs/login-spec.md", - Verified: true, ModelUsed: "self", Message: "spec written: login feature", - }, nil + fakeFn := func(_ context.Context, _, _, user string) (string, int64, error) { + capturedTask = user + return "# OAuth2 Login Spec\n\n## Overview\nImplement OAuth2 login flow.", 110, nil } - sk := spec.New(spec.Config{SkillPrompt: "spec rules", ExecutorFn: fakeFn, SessionsDir: t.TempDir()}) + sk := spec.New(spec.Config{SkillPrompt: "spec rules", CompleteFunc: fakeFn, SessionsDir: t.TempDir()}) out, err := sk.Handle(context.Background(), "spec", json.RawMessage( `{"project_root":"/tmp/proj","requirements":"add OAuth2 login","output_path":"docs/login-spec.md"}`, )) require.NoError(t, err) - assert.True(t, called) assert.Contains(t, capturedTask, "OAuth2 login") assert.Contains(t, capturedTask, "docs/login-spec.md") - var result iexec.Result + var result map[string]any require.NoError(t, json.Unmarshal(out, &result)) - assert.Equal(t, "spec", result.Phase) + assert.Contains(t, result["text"], "OAuth2 Login Spec") } diff --git a/internal/skills/spec/skill.go b/internal/skills/spec/skill.go index 75a749e..461b886 100644 --- a/internal/skills/spec/skill.go +++ b/internal/skills/spec/skill.go @@ -5,20 +5,19 @@ import ( "context" "encoding/json" - iexec "github.com/mathiasbq/supervisor/internal/exec" "github.com/mathiasbq/supervisor/internal/registry" ) -// ExecutorFn is the function signature for running a worker subprocess. -type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error) +// CompleteFunc is the function used to call a local model. +type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error) // Config holds dependencies for the spec skill. type Config struct { SkillPrompt string DefaultModel string - ExecutorFn ExecutorFn + CompleteFunc CompleteFunc SessionsDir string - IngestBaseURL string // optional: base URL of ingestion server for brain context + IngestBaseURL string } // Skill implements the spec MCP tool. @@ -40,7 +39,7 @@ func (s *Skill) Tools() []registry.ToolDef { return []registry.ToolDef{ { Name: "spec", - Description: "Generate a structured implementation spec from requirements. Writes the spec to output_path in the project.", + Description: "Consult a local model to draft a structured implementation spec from requirements. Returns the spec text.", InputSchema: schema( []string{"project_root", "requirements"}, map[string]any{ diff --git a/internal/skills/tdd/handlers.go b/internal/skills/tdd/handlers.go index 3ed1437..f897fc6 100644 --- a/internal/skills/tdd/handlers.go +++ b/internal/skills/tdd/handlers.go @@ -7,7 +7,6 @@ import ( "time" "github.com/mathiasbq/supervisor/internal/brain" - iexec "github.com/mathiasbq/supervisor/internal/exec" "github.com/mathiasbq/supervisor/internal/session" ) @@ -51,7 +50,7 @@ func (s *Skill) handleRed(ctx context.Context, raw json.RawMessage) (json.RawMes if brainCtx != "" { task = brainCtx + "\n---\n\n" + task } - return s.execute(ctx, task) + return s.complete(ctx, s.resolveModel(args.Model), task) } type greenArgs struct { @@ -80,11 +79,11 @@ func (s *Skill) handleGreen(ctx context.Context, raw json.RawMessage) (json.RawM task = session.PrependHistory(s.cfg.SessionsDir, args.SessionID, "green", task) t0 := time.Now() - result, err := s.execute(ctx, task) + result, err := s.complete(ctx, s.resolveModel(args.Model), task) if err != nil { return nil, err } - s.logAttempt(args.SessionID, args.ProjectRoot, "tdd", "green", t0, result) + s.logEntry(args.SessionID, args.ProjectRoot, "tdd", "green", s.resolveModel(args.Model), t0, result) return result, nil } @@ -118,11 +117,11 @@ func (s *Skill) handleRefactor(ctx context.Context, raw json.RawMessage) (json.R task = session.PrependHistory(s.cfg.SessionsDir, args.SessionID, "refactor", task) t0 := time.Now() - result, err := s.execute(ctx, task) + result, err := s.complete(ctx, s.resolveModel(args.Model), task) if err != nil { return nil, err } - s.logAttempt(args.SessionID, args.ProjectRoot, "tdd", "refactor", t0, result) + s.logEntry(args.SessionID, args.ProjectRoot, "tdd", "refactor", s.resolveModel(args.Model), t0, result) return result, nil } @@ -133,31 +132,32 @@ func (s *Skill) resolveModel(override string) string { return s.cfg.DefaultModel } -// execute calls ExecutorFn and returns the marshaled result. -func (s *Skill) execute(ctx context.Context, task string) (json.RawMessage, error) { - if s.cfg.ExecutorFn == nil { +// complete calls CompleteFunc and returns the text as JSON. +func (s *Skill) complete(ctx context.Context, model, task string) (json.RawMessage, error) { + if s.cfg.CompleteFunc == nil { return nil, fmt.Errorf("no executor configured") } - req := iexec.Request{ - SkillPrompt: s.cfg.SkillPrompt, - TaskPrompt: task, - } - result, err := s.cfg.ExecutorFn(ctx, req) + text, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.SkillPrompt, task) if err != nil { return nil, err } - return json.Marshal(result) + return json.Marshal(map[string]any{"text": text, "model": model, "duration_ms": dur}) } -// logAttempt writes a session.Entry for a completed phase if session_id is set. -// raw is the marshaled Result returned by execute; we unmarshal to extract fields. -func (s *Skill) logAttempt(sessionID, projectRoot, skill, phase string, t0 time.Time, raw json.RawMessage) { +// logEntry writes a session.Entry for a completed phase if session_id is set. +func (s *Skill) logEntry(sessionID, projectRoot, skill, phase, model string, t0 time.Time, raw json.RawMessage) { if sessionID == "" || s.cfg.SessionsDir == "" { return } - var result iexec.Result - if err := json.Unmarshal(raw, &result); err != nil { - return + var msg string + var result struct { + Text string `json:"text"` + } + if err := json.Unmarshal(raw, &result); err == nil && len(result.Text) > 0 { + msg = result.Text + if len(msg) > 200 { + msg = msg[:200] + } } _ = session.Append(s.cfg.SessionsDir, sessionID, session.Entry{ SessionID: sessionID, @@ -165,11 +165,9 @@ func (s *Skill) logAttempt(sessionID, projectRoot, skill, phase string, t0 time. Skill: skill, Phase: phase, ProjectRoot: projectRoot, - Attempts: session.AttemptsFrom(result.Attempts), - FinalStatus: result.Status, - FilePath: result.FilePath, - ModelUsed: result.ModelUsed, + FinalStatus: "ok", + ModelUsed: model, DurationMs: time.Since(t0).Milliseconds(), - Message: result.Message, + Message: msg, }) } diff --git a/internal/skills/tdd/handlers_test.go b/internal/skills/tdd/handlers_test.go index e299cbb..ab0f1d5 100644 --- a/internal/skills/tdd/handlers_test.go +++ b/internal/skills/tdd/handlers_test.go @@ -5,7 +5,6 @@ import ( "encoding/json" "testing" - iexec "github.com/mathiasbq/supervisor/internal/exec" "github.com/mathiasbq/supervisor/internal/session" "github.com/mathiasbq/supervisor/internal/skills/tdd" "github.com/stretchr/testify/assert" @@ -14,8 +13,7 @@ import ( func TestTDDSkillTools(t *testing.T) { skill := tdd.New(tdd.Config{ - SystemPrompt: "supervisor rules", - SkillPrompt: "tdd rules", + SkillPrompt: "tdd rules", }) tools := skill.Tools() names := make([]string, len(tools)) @@ -26,19 +24,19 @@ func TestTDDSkillTools(t *testing.T) { } func TestTDDSkillHandleUnknown(t *testing.T) { - skill := tdd.New(tdd.Config{SystemPrompt: "s", SkillPrompt: "t"}) + skill := tdd.New(tdd.Config{SkillPrompt: "t"}) _, err := skill.Handle(context.Background(), "tdd_unknown", json.RawMessage(`{}`)) assert.ErrorContains(t, err, "unknown tool") } func TestTDDRedRequiresProjectRoot(t *testing.T) { - skill := tdd.New(tdd.Config{SystemPrompt: "s", SkillPrompt: "t"}) + skill := tdd.New(tdd.Config{SkillPrompt: "t"}) _, err := skill.Handle(context.Background(), "tdd_red", json.RawMessage(`{"spec":"add two numbers"}`)) assert.ErrorContains(t, err, "project_root") } func TestTDDRedRequiresSpec(t *testing.T) { - skill := tdd.New(tdd.Config{SystemPrompt: "s", SkillPrompt: "t"}) + skill := tdd.New(tdd.Config{SkillPrompt: "t"}) _, err := skill.Handle(context.Background(), "tdd_red", json.RawMessage(`{"project_root":"/tmp/proj"}`)) assert.ErrorContains(t, err, "spec") } @@ -51,35 +49,49 @@ func TestTDDGreenInjectsSessionHistory(t *testing.T) { Message: "wrote failing test for Foo", })) - var capturedPrompt string - fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) { - capturedPrompt = req.TaskPrompt - return iexec.Result{Status: "pass", Phase: "green", Skill: "tdd", Verified: true, ModelUsed: "self", Message: "ok"}, nil + var capturedTask string + fakeFn := func(_ context.Context, _, _, user string) (string, int64, error) { + capturedTask = user + return "here is my suggestion", 100, nil } - sk := tdd.New(tdd.Config{SkillPrompt: "tdd", ExecutorFn: fakeFn, SessionsDir: sessDir}) + sk := tdd.New(tdd.Config{SkillPrompt: "tdd", CompleteFunc: fakeFn, SessionsDir: sessDir}) _, err := sk.Handle(context.Background(), "tdd_green", json.RawMessage( `{"project_root":"/tmp","test_path":"internal/foo/foo_test.go","test_cmd":"go test ./...","session_id":"sess-1"}`, )) require.NoError(t, err) - assert.Contains(t, capturedPrompt, "## Session history") - assert.Contains(t, capturedPrompt, "wrote failing test for Foo") + assert.Contains(t, capturedTask, "## Session history") + assert.Contains(t, capturedTask, "wrote failing test for Foo") } func TestTDDGreenNoHistoryWhenSessionIDEmpty(t *testing.T) { - var capturedPrompt string - fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) { - capturedPrompt = req.TaskPrompt - return iexec.Result{Status: "pass", Phase: "green", Skill: "tdd", Verified: true, ModelUsed: "self", Message: "ok"}, nil + var capturedTask string + fakeFn := func(_ context.Context, _, _, user string) (string, int64, error) { + capturedTask = user + return "suggestion", 50, nil } - sk := tdd.New(tdd.Config{SkillPrompt: "tdd", ExecutorFn: fakeFn, SessionsDir: t.TempDir()}) + sk := tdd.New(tdd.Config{SkillPrompt: "tdd", CompleteFunc: fakeFn, SessionsDir: t.TempDir()}) _, err := sk.Handle(context.Background(), "tdd_green", json.RawMessage( `{"project_root":"/tmp","test_path":"internal/foo/foo_test.go"}`, )) require.NoError(t, err) - assert.NotContains(t, capturedPrompt, "## Session history") + assert.NotContains(t, capturedTask, "## Session history") } -// Ensure require is used (avoids import error). -var _ = require.New +func TestTDDGreenReturnsTextJSON(t *testing.T) { + fakeFn := func(_ context.Context, _, _, _ string) (string, int64, error) { + return "write a func that adds two ints", 42, nil + } + + sk := tdd.New(tdd.Config{SkillPrompt: "tdd", CompleteFunc: fakeFn}) + raw, err := sk.Handle(context.Background(), "tdd_green", json.RawMessage( + `{"project_root":"/tmp","test_path":"foo_test.go"}`, + )) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(raw, &result)) + assert.Equal(t, "write a func that adds two ints", result["text"]) + assert.Equal(t, float64(42), result["duration_ms"]) +} diff --git a/internal/skills/tdd/skill.go b/internal/skills/tdd/skill.go index a7f5f96..5d99caf 100644 --- a/internal/skills/tdd/skill.go +++ b/internal/skills/tdd/skill.go @@ -4,17 +4,15 @@ import ( "context" "encoding/json" - iexec "github.com/mathiasbq/supervisor/internal/exec" "github.com/mathiasbq/supervisor/internal/registry" ) -// ExecutorFn allows injecting a test double for the executor. -type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error) +// CompleteFunc is the function used to call a local model. +type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error) type Config struct { - SystemPrompt string SkillPrompt string - ExecutorFn ExecutorFn // nil = no executor (tests that don't reach execute()) + CompleteFunc CompleteFunc // nil = no executor (tests that don't reach execute()) DefaultModel string SessionsDir string // optional: path to brain/sessions/ for history injection IngestBaseURL string // optional: base URL of ingestion server for brain context @@ -44,7 +42,7 @@ func (s *Skill) Tools() []registry.ToolDef { return []registry.ToolDef{ { Name: "tdd_red", - Description: "Write a failing test for the described behavior. Verifies the test fails before returning.", + Description: "Consult a local model for help writing a failing test for the described behavior.", InputSchema: schema( []string{"project_root", "spec"}, map[string]any{ @@ -57,7 +55,7 @@ func (s *Skill) Tools() []registry.ToolDef { }, { Name: "tdd_green", - Description: "Write minimal implementation to make the test at test_path pass.", + Description: "Consult a local model for implementation ideas to make the test at test_path pass.", InputSchema: schema( []string{"project_root", "test_path"}, map[string]any{ @@ -71,7 +69,7 @@ func (s *Skill) Tools() []registry.ToolDef { }, { Name: "tdd_refactor", - Description: "Refactor the implementation at impl_path while keeping tests green.", + Description: "Consult a local model for refactoring suggestions for impl_path while keeping tests green.", InputSchema: schema( []string{"project_root", "test_path", "impl_path"}, map[string]any{ diff --git a/internal/skills/trainer/handlers.go b/internal/skills/trainer/handlers.go index 50921fe..71c85a9 100644 --- a/internal/skills/trainer/handlers.go +++ b/internal/skills/trainer/handlers.go @@ -7,7 +7,6 @@ import ( "fmt" "time" - iexec "github.com/mathiasbq/supervisor/internal/exec" "github.com/mathiasbq/supervisor/internal/session" ) @@ -28,7 +27,7 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) ( if a.SessionID == "" { return nil, fmt.Errorf("session_id is required") } - if s.cfg.ExecutorFn == nil { + if s.cfg.CompleteFunc == nil { return nil, fmt.Errorf("no executor configured") } @@ -42,53 +41,47 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) ( return nil, fmt.Errorf("read session log: %w", err) } - // ── Step 1: Reader agent ───────────────────────────────────────────────── + // ── Step 1: Reader ──────────────────────────────────────────────────────── history := session.FormatHistory(entries, "") readerTask := fmt.Sprintf( "role: reader\nsession_id: %s\nbrain_dir: %s\n\n%s", a.SessionID, s.cfg.BrainDir, history, ) - readerResult, err := s.cfg.ExecutorFn(ctx, iexec.Request{ - SkillPrompt: s.cfg.ReaderPrompt, - TaskPrompt: readerTask, - Model: model, - Tools: "Read", - }) + readerText, _, err := s.cfg.CompleteFunc(ctx, model, s.cfg.ReaderPrompt, readerTask) if err != nil { - return nil, fmt.Errorf("reader agent: %w", err) + return nil, fmt.Errorf("reader: %w", err) } - // ── Step 2: Writer agent (receives reader candidates) ──────────────────── + // ── Step 2: Writer (receives reader output) ─────────────────────────────── t0 := time.Now() writerTask := fmt.Sprintf( - "role: writer\nsession_id: %s\nbrain_dir: %s\n\nreader_summary: %s\nreader_candidates:\n%s", - a.SessionID, s.cfg.BrainDir, readerResult.Message, readerResult.RunnerOutput, + "role: writer\nsession_id: %s\nbrain_dir: %s\n\nreader_analysis:\n%s", + a.SessionID, s.cfg.BrainDir, readerText, ) - writerResult, err := s.cfg.ExecutorFn(ctx, iexec.Request{ - SkillPrompt: s.cfg.WriterPrompt, - TaskPrompt: writerTask, - Model: model, - Tools: "Read,Write", - }) + writerText, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.WriterPrompt, writerTask) if err != nil { - return nil, fmt.Errorf("writer agent: %w", err) + return nil, fmt.Errorf("writer: %w", err) } + msg := writerText + if len(msg) > 200 { + msg = msg[:200] + } _ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{ SessionID: a.SessionID, Timestamp: time.Now(), Skill: "trainer", Phase: "trainer", - Attempts: session.AttemptsFrom(writerResult.Attempts), - FinalStatus: writerResult.Status, - ModelUsed: writerResult.ModelUsed, + FinalStatus: "ok", + ModelUsed: model, DurationMs: time.Since(t0).Milliseconds(), - Message: writerResult.Message, + Message: msg, }) - b, err := json.Marshal(writerResult) - if err != nil { - return nil, fmt.Errorf("marshal result: %w", err) - } - return b, nil + return json.Marshal(map[string]any{ + "reader_analysis": readerText, + "writer_output": writerText, + "model": model, + "duration_ms": dur, + }) } diff --git a/internal/skills/trainer/handlers_test.go b/internal/skills/trainer/handlers_test.go index e20704b..a9370aa 100644 --- a/internal/skills/trainer/handlers_test.go +++ b/internal/skills/trainer/handlers_test.go @@ -6,7 +6,6 @@ import ( "encoding/json" "testing" - iexec "github.com/mathiasbq/supervisor/internal/exec" "github.com/mathiasbq/supervisor/internal/session" "github.com/mathiasbq/supervisor/internal/skills/trainer" "github.com/stretchr/testify/assert" @@ -31,52 +30,44 @@ func TestTrainerRequiresSessionID(t *testing.T) { func TestTrainerCallsReaderThenWriter(t *testing.T) { sessDir := t.TempDir() require.NoError(t, session.Append(sessDir, "sess-1", session.Entry{ - SessionID: "sess-1", Skill: "tdd", Phase: "red", FinalStatus: "pass", + SessionID: "sess-1", Skill: "tdd", Phase: "red", FinalStatus: "ok", Message: "wrote failing test", FilePath: "internal/foo/foo_test.go", })) callCount := 0 var readerTask, writerTask string - fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) { + fakeFn := func(_ context.Context, _, sys, user string) (string, int64, error) { callCount++ if callCount == 1 { // reader call - readerTask = req.TaskPrompt - return iexec.Result{ - Status: "pass", Phase: "trainer", Skill: "trainer", - RunnerOutput: `[{"type":"sft","moment":"first-pass clean TDD","score":4}]`, - Verified: true, ModelUsed: "self", Message: "1 sft candidate found", - }, nil + readerTask = user + return "1 sft candidate found: first-pass clean TDD", 60, nil } // writer call - writerTask = req.TaskPrompt - return iexec.Result{ - Status: "pass", Phase: "trainer", Skill: "trainer", - FilePath: sessDir + "/training-data/sft/sess-1.jsonl", - Verified: true, ModelUsed: "self", Message: "1 sft pair written", - }, nil + writerTask = user + return "written 1 knowledge entry to brain/knowledge/tdd-patterns.md", 70, nil } sk := trainer.New(trainer.Config{ ReaderPrompt: "reader rules", WriterPrompt: "writer rules", - ExecutorFn: fakeFn, + CompleteFunc: fakeFn, SessionsDir: sessDir, BrainDir: t.TempDir(), }) out, err := sk.Handle(context.Background(), "trainer", json.RawMessage(`{"session_id":"sess-1"}`)) require.NoError(t, err) - assert.Equal(t, 2, callCount, "executor must be called exactly twice: reader then writer") + assert.Equal(t, 2, callCount, "complete must be called exactly twice: reader then writer") assert.Contains(t, readerTask, "role: reader") assert.Contains(t, readerTask, "sess-1") - assert.Contains(t, readerTask, "wrote failing test") // session history in reader prompt + assert.Contains(t, readerTask, "wrote failing test") assert.Contains(t, writerTask, "role: writer") - assert.Contains(t, writerTask, "sft candidate") // reader output passed to writer + assert.Contains(t, writerTask, "sft candidate") - var result iexec.Result + var result map[string]any require.NoError(t, json.Unmarshal(out, &result)) - assert.Equal(t, "trainer", result.Phase) - assert.Equal(t, "pass", result.Status) + assert.Contains(t, result["reader_analysis"], "sft candidate") + assert.Contains(t, result["writer_output"], "knowledge entry") } diff --git a/internal/skills/trainer/skill.go b/internal/skills/trainer/skill.go index d5ecf86..f37164e 100644 --- a/internal/skills/trainer/skill.go +++ b/internal/skills/trainer/skill.go @@ -5,21 +5,20 @@ import ( "context" "encoding/json" - iexec "github.com/mathiasbq/supervisor/internal/exec" "github.com/mathiasbq/supervisor/internal/registry" ) -// ExecutorFn is the function signature for running a worker subprocess. -type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error) +// CompleteFunc is the function used to call a local model. +type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error) // Config holds dependencies for the trainer skill. type Config struct { ReaderPrompt string WriterPrompt string DefaultModel string - ExecutorFn ExecutorFn + CompleteFunc CompleteFunc SessionsDir string - BrainDir string // root of brain/ directory; writer writes to BrainDir/training-data/ + BrainDir string // root of brain/ directory } // Skill implements the trainer MCP tool. @@ -40,7 +39,7 @@ func (s *Skill) Tools() []registry.ToolDef { return []registry.ToolDef{ { Name: "trainer", - Description: "Extract SFT and DPO training pairs from a session log. Runs a reader→writer chain: reader identifies learning moments, writer formats and writes pairs to brain/training-data/.", + Description: "Consult a local model to identify learning moments from a session log and suggest knowledge to preserve in the brain.", InputSchema: schema( []string{"session_id"}, map[string]any{