refactor: replace orchestrator/verifier chain with direct LiteLLM calls
All checks were successful
cd / Build and deploy (push) Successful in 6s
CI / Lint / Test / Vet (push) Successful in 10s
CI / Mirror to GitHub (push) Successful in 3s

Drop the three-layer Claude subprocess orchestration (local model →
Claude verifier → cloud escalation). Skills now call LiteLLM directly
and return plain text to Claude Code, which decides what to do with it.

- Delete executor, orchestrator, verifier, result, attempts packages
- Simplify LiteLLMExecutor: Run(Request)→Result becomes Complete(model,sys,user)→(string,int64,error)
- Replace ExecutorFn with CompleteFunc in all 6 skill configs
- Rewrite all skill handlers to call Complete and return {"text","model","duration_ms"}
- Simplify config/models: remove Verifier/LlamaSwapURL, add ModelFor
- Bump version to v0.5.0

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Mathias Bergqvist
2026-04-22 16:19:09 +02:00
parent 823de23213
commit ce45592730
34 changed files with 266 additions and 1432 deletions

View File

@@ -37,12 +37,6 @@ func main() {
os.Exit(1) os.Exit(1)
} }
systemPrompt, err := os.ReadFile(cfg.ConfigDir + "/CLAUDE.md")
if err != nil {
logger.Error("read supervisor CLAUDE.md", "path", cfg.ConfigDir+"/CLAUDE.md", "err", err)
os.Exit(1)
}
protocolsPrompt, err := os.ReadFile(cfg.ConfigDir + "/protocols.md") protocolsPrompt, err := os.ReadFile(cfg.ConfigDir + "/protocols.md")
if err != nil { if err != nil {
logger.Error("read protocols.md", "path", cfg.ConfigDir+"/protocols.md", "err", err) logger.Error("read protocols.md", "path", cfg.ConfigDir+"/protocols.md", "err", err)
@@ -95,40 +89,7 @@ func main() {
os.Exit(1) os.Exit(1)
} }
claudeExec := iexec.New(iexec.Config{ litellm := iexec.NewLiteLLM(cfg.LiteLLMBaseURL, cfg.LiteLLMAPIKey, 0)
SystemPrompt: string(systemPrompt),
LiteLLMBaseURL: cfg.LiteLLMBaseURL,
LiteLLMAPIKey: cfg.LiteLLMAPIKey,
})
litellmExec := iexec.NewLiteLLM(cfg.LiteLLMBaseURL, cfg.LiteLLMAPIKey, 0)
verifier := iexec.NewVerifier("", models.Verifier(), 0)
buildOrch := func(skill string) func(ctx context.Context, req iexec.Request) (iexec.Result, error) {
return func(ctx context.Context, req iexec.Request) (iexec.Result, error) {
rawChain := models.ChainFor(skill, req.Model)
chain := make([]iexec.ChainEntry, len(rawChain))
for i, m := range rawChain {
chain[i] = iexec.EntryFor(m)
}
attempts := make([]iexec.AttemptRecord, 0, len(chain))
orch := iexec.NewOrchestrator(chain, litellmExec.Run, claudeExec.Run, verifier, models.LlamaSwapURL(), &attempts)
result, err := orch.Run(ctx, req)
result.Attempts = attempts // attach orchestration metadata before returning
// Log per-attempt verdicts so pass rates are visible in pod logs.
for i, a := range attempts {
logger.Info("chain attempt",
"skill", skill,
"attempt", i+1,
"model", a.Model,
"tier", a.Tier,
"verdict", a.Verdict,
"duration_ms", a.DurationMs,
"warm", a.WarmStart,
)
}
return result, err
}
}
tierFn := func(ctx context.Context) tier.Info { tierFn := func(ctx context.Context) tier.Info {
return tier.Detect(ctx, "https://api.anthropic.com", cfg.LiteLLMBaseURL) return tier.Detect(ctx, "https://api.anthropic.com", cfg.LiteLLMBaseURL)
@@ -136,10 +97,9 @@ func main() {
reg := registry.New() reg := registry.New()
reg.Register(tdd.New(tdd.Config{ reg.Register(tdd.New(tdd.Config{
SystemPrompt: string(systemPrompt),
SkillPrompt: prependProtocols(tddPrompt), SkillPrompt: prependProtocols(tddPrompt),
DefaultModel: models.ChainFor("tdd", "")[0], DefaultModel: models.ModelFor("tdd", ""),
ExecutorFn: buildOrch("tdd"), CompleteFunc: litellm.Complete,
SessionsDir: cfg.SessionsDir, SessionsDir: cfg.SessionsDir,
IngestBaseURL: cfg.IngestBaseURL, IngestBaseURL: cfg.IngestBaseURL,
})) }))
@@ -154,36 +114,36 @@ func main() {
})) }))
reg.Register(retrospective.New(retrospective.Config{ reg.Register(retrospective.New(retrospective.Config{
SkillPrompt: prependProtocols(retroPrompt), SkillPrompt: prependProtocols(retroPrompt),
DefaultModel: models.ChainFor("retrospective", "")[0], DefaultModel: models.ModelFor("retrospective", ""),
SessionsDir: cfg.SessionsDir, SessionsDir: cfg.SessionsDir,
ExecutorFn: buildOrch("retrospective"), CompleteFunc: litellm.Complete,
})) }))
reg.Register(review.New(review.Config{ reg.Register(review.New(review.Config{
SkillPrompt: prependProtocols(reviewPrompt), SkillPrompt: prependProtocols(reviewPrompt),
DefaultModel: models.ChainFor("review", "")[0], DefaultModel: models.ModelFor("review", ""),
ExecutorFn: buildOrch("review"), CompleteFunc: litellm.Complete,
SessionsDir: cfg.SessionsDir, SessionsDir: cfg.SessionsDir,
IngestBaseURL: cfg.IngestBaseURL, IngestBaseURL: cfg.IngestBaseURL,
})) }))
reg.Register(skilldebug.New(skilldebug.Config{ reg.Register(skilldebug.New(skilldebug.Config{
SkillPrompt: prependProtocols(debugPrompt), SkillPrompt: prependProtocols(debugPrompt),
DefaultModel: models.ChainFor("debug", "")[0], DefaultModel: models.ModelFor("debug", ""),
ExecutorFn: buildOrch("debug"), CompleteFunc: litellm.Complete,
SessionsDir: cfg.SessionsDir, SessionsDir: cfg.SessionsDir,
IngestBaseURL: cfg.IngestBaseURL, IngestBaseURL: cfg.IngestBaseURL,
})) }))
reg.Register(spec.New(spec.Config{ reg.Register(spec.New(spec.Config{
SkillPrompt: prependProtocols(specPrompt), SkillPrompt: prependProtocols(specPrompt),
DefaultModel: models.ChainFor("spec", "")[0], DefaultModel: models.ModelFor("spec", ""),
ExecutorFn: buildOrch("spec"), CompleteFunc: litellm.Complete,
SessionsDir: cfg.SessionsDir, SessionsDir: cfg.SessionsDir,
IngestBaseURL: cfg.IngestBaseURL, IngestBaseURL: cfg.IngestBaseURL,
})) }))
reg.Register(trainer.New(trainer.Config{ reg.Register(trainer.New(trainer.Config{
ReaderPrompt: prependProtocols(trainerReaderPrompt), ReaderPrompt: prependProtocols(trainerReaderPrompt),
WriterPrompt: prependProtocols(trainerWriterPrompt), WriterPrompt: prependProtocols(trainerWriterPrompt),
DefaultModel: models.ChainFor("trainer", "")[0], DefaultModel: models.ModelFor("trainer", ""),
ExecutorFn: buildOrch("trainer"), CompleteFunc: litellm.Complete,
SessionsDir: cfg.SessionsDir, SessionsDir: cfg.SessionsDir,
BrainDir: cfg.BrainDir, BrainDir: cfg.BrainDir,
})) }))
@@ -193,7 +153,7 @@ func main() {
mux.Handle("/mcp", srv) mux.Handle("/mcp", srv)
addr := ":" + cfg.Port addr := ":" + cfg.Port
logger.Info("supervisor starting", "addr", addr, "version", "v0.4.0") logger.Info("supervisor starting", "addr", addr, "version", "v0.5.0")
if err := http.ListenAndServe(addr, mux); err != nil { if err := http.ListenAndServe(addr, mux); err != nil {
logger.Error("server stopped", "err", err) logger.Error("server stopped", "err", err)
os.Exit(1) os.Exit(1)

View File

@@ -1,41 +1,25 @@
# Model routing chains — three-layer priority: # Model selection — first entry per skill is used.
# 1. model param in MCP tool call (caller override — collapses to single entry, no escalation) # Override per-call by passing model in the MCP tool args.
# 2. per-skill chain here
# 3. default_chain fallback
verifier: claude-sonnet-4-6 # fixed verifier for all local tiers
llama_swap_url: http://koala:8080 # for warm-state probing
default_chain: default_chain:
- ollama/qwen3-coder-30b-tuned - ollama/qwen3-coder-30b-tuned
- claude-sonnet-4-6
skills: skills:
tdd: tdd:
chain: chain:
- ollama/qwen3-coder-30b-tuned - ollama/qwen3-coder-30b-tuned
- claude-sonnet-4-6
review: review:
chain: chain:
- ollama/devstral-tuned - ollama/devstral-tuned
- ollama/gemma4
- claude-sonnet-4-6
debug: debug:
chain: chain:
- ollama/deepseek-r1-tuned - ollama/deepseek-r1-tuned
- claude-sonnet-4-6
spec: spec:
chain: chain:
- ollama/phi4 - ollama/phi4
- ollama/gemma4
- claude-sonnet-4-6
- claude-opus-4-6
retrospective: retrospective:
chain: chain:
- ollama/qwen3-coder-30b-tuned - ollama/qwen3-coder-30b-tuned
- claude-sonnet-4-6
trainer: trainer:
chain: chain:
- ollama/qwen3-coder-30b-tuned - ollama/qwen3-coder-30b-tuned
- claude-sonnet-4-6

View File

@@ -12,8 +12,6 @@ type skillChain struct {
} }
type modelsFile struct { type modelsFile struct {
Verifier string `yaml:"verifier"`
LlamaSwapURL string `yaml:"llama_swap_url"`
DefaultChain []string `yaml:"default_chain"` DefaultChain []string `yaml:"default_chain"`
Skills map[string]skillChain `yaml:"skills"` Skills map[string]skillChain `yaml:"skills"`
} }
@@ -34,23 +32,18 @@ func LoadModels(path string) (Models, error) {
return Models{data: f}, nil return Models{data: f}, nil
} }
// Verifier returns the model name to use for all local-tier output verification. // ModelFor returns the primary model to use for a skill.
func (m Models) Verifier() string { return m.data.Verifier } // If override is non-empty, it is returned directly.
// Falls back to default_chain[0] when the skill has no explicit entry.
// LlamaSwapURL returns the llama-swap base URL for warm-state probing. func (m Models) ModelFor(skill, override string) string {
func (m Models) LlamaSwapURL() string { return m.data.LlamaSwapURL }
// ChainFor returns the ordered list of model names for a skill.
// If override is non-empty, returns a single-entry chain (no escalation).
// Falls back to default_chain when the skill has no explicit entry.
func (m Models) ChainFor(skill, override string) []string {
if override != "" { if override != "" {
return []string{override} return override
} }
if sc, ok := m.data.Skills[skill]; ok && len(sc.Chain) > 0 { if sc, ok := m.data.Skills[skill]; ok && len(sc.Chain) > 0 {
return sc.Chain return sc.Chain[0]
} }
out := make([]string, len(m.data.DefaultChain)) if len(m.data.DefaultChain) > 0 {
copy(out, m.data.DefaultChain) return m.data.DefaultChain[0]
return out }
return ""
} }

View File

@@ -11,9 +11,6 @@ import (
) )
const testYAML = ` const testYAML = `
verifier: claude-sonnet-4-6
llama_swap_url: http://koala:8080
default_chain: default_chain:
- ollama/qwen3-coder-30b-tuned - ollama/qwen3-coder-30b-tuned
- claude-sonnet-4-6 - claude-sonnet-4-6
@@ -37,44 +34,20 @@ func writeModels(t *testing.T, content string) string {
return f return f
} }
func TestModelsVerifier(t *testing.T) { func TestModelsModelForSkillWithEntry(t *testing.T) {
m, err := config.LoadModels(writeModels(t, testYAML)) m, err := config.LoadModels(writeModels(t, testYAML))
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "claude-sonnet-4-6", m.Verifier()) assert.Equal(t, "ollama/devstral-tuned", m.ModelFor("review", ""))
} }
func TestModelsLlamaSwapURL(t *testing.T) { func TestModelsModelForDefaultFallback(t *testing.T) {
m, err := config.LoadModels(writeModels(t, testYAML)) m, err := config.LoadModels(writeModels(t, testYAML))
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "http://koala:8080", m.LlamaSwapURL()) assert.Equal(t, "ollama/qwen3-coder-30b-tuned", m.ModelFor("trainer", ""))
} }
func TestModelsChainForSkillOverride(t *testing.T) { func TestModelsModelForCallerOverride(t *testing.T) {
m, err := config.LoadModels(writeModels(t, testYAML)) m, err := config.LoadModels(writeModels(t, testYAML))
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "claude-opus-4-6", m.ModelFor("review", "claude-opus-4-6"))
chain := m.ChainFor("review", "")
require.Len(t, chain, 3)
assert.Equal(t, "ollama/devstral-tuned", chain[0])
assert.Equal(t, "ollama/gemma4", chain[1])
assert.Equal(t, "claude-sonnet-4-6", chain[2])
}
func TestModelsChainForDefaultFallback(t *testing.T) {
m, err := config.LoadModels(writeModels(t, testYAML))
require.NoError(t, err)
chain := m.ChainFor("trainer", "") // not in skills map
require.Len(t, chain, 2)
assert.Equal(t, "ollama/qwen3-coder-30b-tuned", chain[0])
assert.Equal(t, "claude-sonnet-4-6", chain[1])
}
func TestModelsChainForCallerOverride(t *testing.T) {
m, err := config.LoadModels(writeModels(t, testYAML))
require.NoError(t, err)
chain := m.ChainFor("review", "claude-opus-4-6")
require.Len(t, chain, 1)
assert.Equal(t, "claude-opus-4-6", chain[0])
} }

View File

@@ -1,111 +0,0 @@
package exec
import (
"bytes"
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"strings"
"time"
)
// Config holds executor configuration.
type Config struct {
ClaudeBinary string // path to claude binary, defaults to "claude"
SystemPrompt string // contents of supervisor CLAUDE.md
Timeout time.Duration // per-invocation timeout, default 120s
LiteLLMBaseURL string // passed to Claude so it can delegate to Ollama
LiteLLMAPIKey string // passed to Claude for LiteLLM auth
}
// Request is the input to a single supervisor invocation.
type Request struct {
SkillPrompt string // skill-specific discipline (e.g. tdd.md contents)
TaskPrompt string // the specific task (phase, project_root, spec, model)
Model string // resolved model name, passed in task prompt
Tools string // comma-separated allowed tools, default "Bash,Read,Write"
}
// Executor spawns a claude instance and captures its structured JSON output.
type Executor struct {
cfg Config
}
func New(cfg Config) *Executor {
if cfg.ClaudeBinary == "" {
cfg.ClaudeBinary = "claude"
}
if cfg.Timeout == 0 {
cfg.Timeout = 120 * time.Second
}
return &Executor{cfg: cfg}
}
func (e *Executor) Run(ctx context.Context, req Request) (Result, error) {
ctx, cancel := context.WithTimeout(ctx, e.cfg.Timeout)
defer cancel()
tools := req.Tools
if tools == "" {
tools = "Bash,Read,Write"
}
// Build the full prompt: system rules + skill rules + infra context + task.
// LITELLM_API_KEY is injected as a subprocess env var, not in the prompt,
// to prevent it appearing in error log output.
litellmCtx := fmt.Sprintf("LITELLM_BASE_URL: %s", e.cfg.LiteLLMBaseURL)
prompt := strings.Join([]string{
e.cfg.SystemPrompt,
"---",
req.SkillPrompt,
"---",
litellmCtx,
"---",
req.TaskPrompt,
}, "\n\n")
args := []string{
"--print",
"--permission-mode", "bypassPermissions",
"--tools", tools,
"--json-schema", Schema,
"--output-format", "json",
}
if strings.HasPrefix(req.Model, "claude-") {
args = append(args, "--model", req.Model)
}
args = append(args, prompt)
cmd := exec.CommandContext(ctx, e.cfg.ClaudeBinary, args...)
cmd.Env = append(os.Environ(), "LITELLM_API_KEY="+e.cfg.LiteLLMAPIKey)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
if ctx.Err() != nil {
return Result{}, fmt.Errorf("timeout after %s", e.cfg.Timeout)
}
return Result{}, fmt.Errorf("claude exited with error: %w — stderr: %s", err, stderr.String())
}
// --output-format json wraps the response in an envelope; structured output
// from --json-schema is in the "structured_output" field.
var envelope struct {
StructuredOutput *Result `json:"structured_output"`
IsError bool `json:"is_error"`
Result string `json:"result"` // fallback text result for error messages
}
if err := json.Unmarshal(stdout.Bytes(), &envelope); err != nil {
return Result{}, fmt.Errorf("parse envelope JSON: %w — raw: %s — stderr: %s", err, stdout.String(), stderr.String())
}
if envelope.StructuredOutput == nil {
return Result{}, fmt.Errorf("no structured_output in response — result: %s — stderr: %s", envelope.Result, stderr.String())
}
if err := envelope.StructuredOutput.Validate(); err != nil {
return Result{}, fmt.Errorf("invalid result: %w", err)
}
return *envelope.StructuredOutput, nil
}

View File

@@ -1,132 +0,0 @@
package exec_test
import (
"context"
"os"
"path/filepath"
"testing"
"time"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// fakeClaudePath writes a shell script that prints fixed output and returns its path.
func fakeClaudePath(t *testing.T, output string, exitCode int) string {
t.Helper()
dir := t.TempDir()
script := filepath.Join(dir, "claude")
var content string
if exitCode != 0 {
content = "#!/bin/sh\necho 'error' >&2\nexit 1\n"
} else {
content = "#!/bin/sh\necho '" + output + "'\n"
}
require.NoError(t, os.WriteFile(script, []byte(content), 0755))
return script
}
func TestExecutorParsesValidResult(t *testing.T) {
// Fake claude emits the --output-format json envelope that the real CLI produces.
// The executor extracts the result from the "structured_output" field.
envelope := `{"type":"result","subtype":"success","is_error":false,"structured_output":{"status":"pass","phase":"red","skill":"tdd","file_path":"/tmp/x_test.go","runner_output":"FAIL","verified":true,"model_used":"self","message":"ok"}}`
claude := fakeClaudePath(t, envelope, 0)
ex := iexec.New(iexec.Config{
ClaudeBinary: claude,
SystemPrompt: "you are a supervisor",
Timeout: 5 * time.Second,
})
result, err := ex.Run(context.Background(), iexec.Request{
SkillPrompt: "tdd rules",
TaskPrompt: "run red phase",
})
require.NoError(t, err)
assert.Equal(t, "pass", result.Status)
assert.True(t, result.Verified)
}
func TestExecutorReturnsErrorOnNonZeroExit(t *testing.T) {
claude := fakeClaudePath(t, "", 1)
ex := iexec.New(iexec.Config{
ClaudeBinary: claude,
SystemPrompt: "you are a supervisor",
Timeout: 5 * time.Second,
})
_, err := ex.Run(context.Background(), iexec.Request{TaskPrompt: "fail"})
assert.Error(t, err)
}
func TestExecutorTimesOut(t *testing.T) {
dir := t.TempDir()
script := filepath.Join(dir, "claude")
require.NoError(t, os.WriteFile(script, []byte("#!/bin/sh\nsleep 60\n"), 0755))
ex := iexec.New(iexec.Config{
ClaudeBinary: script,
SystemPrompt: "you are a supervisor",
Timeout: 100 * time.Millisecond,
})
_, err := ex.Run(context.Background(), iexec.Request{TaskPrompt: "slow"})
assert.ErrorContains(t, err, "timeout")
}
func TestExecutorPassesModelFlagForCloudModel(t *testing.T) {
// The script captures its args to a temp file so we can assert --model was passed.
argsFile := filepath.Join(t.TempDir(), "args.txt")
envelope := `{"type":"result","subtype":"success","is_error":false,"structured_output":{"status":"pass","phase":"review","skill":"review","file_path":"","runner_output":"","verified":true,"model_used":"claude-sonnet-4-6","message":"ok"}}`
dir := t.TempDir()
script := filepath.Join(dir, "claude")
content := "#!/bin/sh\necho \"$@\" > " + argsFile + "\necho '" + envelope + "'\n"
require.NoError(t, os.WriteFile(script, []byte(content), 0755))
ex := iexec.New(iexec.Config{
ClaudeBinary: script,
SystemPrompt: "sys",
Timeout: 5 * time.Second,
})
_, err := ex.Run(context.Background(), iexec.Request{
SkillPrompt: "review rules",
TaskPrompt: "do review",
Model: "claude-sonnet-4-6",
})
require.NoError(t, err)
argsData, err := os.ReadFile(argsFile)
require.NoError(t, err)
assert.Contains(t, string(argsData), "--model claude-sonnet-4-6")
}
func TestExecutorSkipsModelFlagForLocalModel(t *testing.T) {
argsFile := filepath.Join(t.TempDir(), "args.txt")
envelope := `{"type":"result","subtype":"success","is_error":false,"structured_output":{"status":"pass","phase":"review","skill":"review","file_path":"","runner_output":"","verified":true,"model_used":"ollama/devstral","message":"ok"}}`
dir := t.TempDir()
script := filepath.Join(dir, "claude")
content := "#!/bin/sh\necho \"$@\" > " + argsFile + "\necho '" + envelope + "'\n"
require.NoError(t, os.WriteFile(script, []byte(content), 0755))
ex := iexec.New(iexec.Config{
ClaudeBinary: script,
SystemPrompt: "sys",
Timeout: 5 * time.Second,
})
_, err := ex.Run(context.Background(), iexec.Request{
SkillPrompt: "review rules",
TaskPrompt: "do review",
Model: "ollama/devstral",
})
require.NoError(t, err)
argsData, err := os.ReadFile(argsFile)
require.NoError(t, err)
assert.NotContains(t, string(argsData), "--model")
}

View File

@@ -9,9 +9,8 @@ import (
"time" "time"
) )
// LiteLLMExecutor calls a LiteLLM-compatible /v1/chat/completions endpoint. // LiteLLMExecutor calls a LiteLLM-compatible /v1/chat/completions endpoint
// Local models are expected to return a JSON object matching the Result schema // and returns the raw assistant message text.
// as their response content — no envelope.
type LiteLLMExecutor struct { type LiteLLMExecutor struct {
baseURL string baseURL string
apiKey string apiKey string
@@ -21,6 +20,9 @@ type LiteLLMExecutor struct {
// NewLiteLLM creates a LiteLLMExecutor. // NewLiteLLM creates a LiteLLMExecutor.
// timeout applies to the full HTTP round-trip per call. // timeout applies to the full HTTP round-trip per call.
func NewLiteLLM(baseURL, apiKey string, timeout time.Duration) *LiteLLMExecutor { func NewLiteLLM(baseURL, apiKey string, timeout time.Duration) *LiteLLMExecutor {
if timeout == 0 {
timeout = 120 * time.Second
}
return &LiteLLMExecutor{ return &LiteLLMExecutor{
baseURL: baseURL, baseURL: baseURL,
apiKey: apiKey, apiKey: apiKey,
@@ -46,58 +48,50 @@ type litellmResponse struct {
Choices []litellmChoice `json:"choices"` Choices []litellmChoice `json:"choices"`
} }
// Run dispatches req to the LiteLLM server and parses the Result from the // Complete sends system+user messages to the given model and returns the raw
// assistant message content. Returns an error on network failure, non-200 // assistant text along with the round-trip duration in milliseconds.
// status, or unparseable/invalid JSON — all of which the Orchestrator treats func (e *LiteLLMExecutor) Complete(ctx context.Context, model, system, user string) (string, int64, error) {
// as automatic escalation triggers.
func (e *LiteLLMExecutor) Run(ctx context.Context, req Request) (Result, error) {
body := litellmRequest{ body := litellmRequest{
Model: req.Model, Model: model,
Messages: []litellmMessage{ Messages: []litellmMessage{
{Role: "system", Content: req.SkillPrompt}, {Role: "system", Content: system},
{Role: "user", Content: req.TaskPrompt}, {Role: "user", Content: user},
}, },
} }
bodyBytes, err := json.Marshal(body) bodyBytes, err := json.Marshal(body)
if err != nil { if err != nil {
return Result{}, fmt.Errorf("litellm: marshal request: %w", err) return "", 0, fmt.Errorf("litellm: marshal request: %w", err)
} }
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, e.baseURL+"/v1/chat/completions", bytes.NewReader(bodyBytes)) httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, e.baseURL+"/v1/chat/completions", bytes.NewReader(bodyBytes))
if err != nil { if err != nil {
return Result{}, fmt.Errorf("litellm: create request: %w", err) return "", 0, fmt.Errorf("litellm: create request: %w", err)
} }
httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Content-Type", "application/json")
if e.apiKey != "" { if e.apiKey != "" {
httpReq.Header.Set("Authorization", "Bearer "+e.apiKey) httpReq.Header.Set("Authorization", "Bearer "+e.apiKey)
} }
t0 := time.Now()
resp, err := e.httpClient.Do(httpReq) resp, err := e.httpClient.Do(httpReq)
if err != nil { if err != nil {
return Result{}, fmt.Errorf("litellm: request failed: %w", err) return "", 0, fmt.Errorf("litellm: request failed: %w", err)
} }
defer resp.Body.Close() //nolint:errcheck defer resp.Body.Close() //nolint:errcheck
durationMs := time.Since(t0).Milliseconds()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return Result{}, fmt.Errorf("litellm: server returned status %d", resp.StatusCode) return "", 0, fmt.Errorf("litellm: server returned status %d", resp.StatusCode)
} }
var chatResp litellmResponse var chatResp litellmResponse
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil { if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
return Result{}, fmt.Errorf("litellm: decode response: %w", err) return "", 0, fmt.Errorf("litellm: decode response: %w", err)
} }
if len(chatResp.Choices) == 0 { if len(chatResp.Choices) == 0 {
return Result{}, fmt.Errorf("litellm: no choices in response") return "", 0, fmt.Errorf("litellm: no choices in response")
} }
content := chatResp.Choices[0].Message.Content return chatResp.Choices[0].Message.Content, durationMs, nil
var result Result
if err := json.Unmarshal([]byte(content), &result); err != nil {
return Result{}, fmt.Errorf("litellm: parse result JSON: %w — content: %s", err, content)
}
if err := result.Validate(); err != nil {
return Result{}, fmt.Errorf("litellm: invalid result: %w", err)
}
return result, nil
} }

View File

@@ -13,23 +13,11 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func validLiteLLMResult() iexec.Result { func chatResponse(t *testing.T, content string) []byte {
return iexec.Result{
Status: "pass",
Phase: "review",
Skill: "review",
ModelUsed: "ollama/devstral",
Message: "looks good",
}
}
func chatResponseFor(t *testing.T, result iexec.Result) []byte {
t.Helper() t.Helper()
content, err := json.Marshal(result)
require.NoError(t, err)
resp := map[string]any{ resp := map[string]any{
"choices": []map[string]any{ "choices": []map[string]any{
{"message": map[string]any{"role": "assistant", "content": string(content)}}, {"message": map[string]any{"role": "assistant", "content": content}},
}, },
} }
data, err := json.Marshal(resp) data, err := json.Marshal(resp)
@@ -37,25 +25,21 @@ func chatResponseFor(t *testing.T, result iexec.Result) []byte {
return data return data
} }
func TestLiteLLMParsesValidResult(t *testing.T) { func TestLiteLLMReturnsText(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/v1/chat/completions", r.URL.Path) assert.Equal(t, "/v1/chat/completions", r.URL.Path)
assert.Equal(t, "application/json", r.Header.Get("Content-Type")) assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
_, _ = w.Write(chatResponseFor(t, validLiteLLMResult())) _, _ = w.Write(chatResponse(t, "here is my analysis"))
})) }))
defer srv.Close() defer srv.Close()
ex := iexec.NewLiteLLM(srv.URL, "", 5*time.Second) ex := iexec.NewLiteLLM(srv.URL, "", 5*time.Second)
result, err := ex.Run(context.Background(), iexec.Request{ text, dur, err := ex.Complete(context.Background(), "ollama/devstral", "system prompt", "user prompt")
SkillPrompt: "review rules",
TaskPrompt: "review the code",
Model: "ollama/devstral",
})
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "pass", result.Status) assert.Equal(t, "here is my analysis", text)
assert.Equal(t, "review", result.Skill) assert.GreaterOrEqual(t, dur, int64(0))
} }
func TestLiteLLMSendsAuthHeader(t *testing.T) { func TestLiteLLMSendsAuthHeader(t *testing.T) {
@@ -63,12 +47,12 @@ func TestLiteLLMSendsAuthHeader(t *testing.T) {
assert.Equal(t, "Bearer secret", r.Header.Get("Authorization")) assert.Equal(t, "Bearer secret", r.Header.Get("Authorization"))
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
_, _ = w.Write(chatResponseFor(t, validLiteLLMResult())) _, _ = w.Write(chatResponse(t, "ok"))
})) }))
defer srv.Close() defer srv.Close()
ex := iexec.NewLiteLLM(srv.URL, "secret", 5*time.Second) ex := iexec.NewLiteLLM(srv.URL, "secret", 5*time.Second)
_, err := ex.Run(context.Background(), iexec.Request{Model: "x", TaskPrompt: "t", SkillPrompt: "s"}) _, _, err := ex.Complete(context.Background(), "model", "sys", "user")
require.NoError(t, err) require.NoError(t, err)
} }
@@ -79,34 +63,28 @@ func TestLiteLLMErrorOnNonOKStatus(t *testing.T) {
defer srv.Close() defer srv.Close()
ex := iexec.NewLiteLLM(srv.URL, "", 5*time.Second) ex := iexec.NewLiteLLM(srv.URL, "", 5*time.Second)
_, err := ex.Run(context.Background(), iexec.Request{Model: "x", TaskPrompt: "t"}) _, _, err := ex.Complete(context.Background(), "model", "sys", "user")
assert.ErrorContains(t, err, "503") assert.ErrorContains(t, err, "503")
} }
func TestLiteLLMErrorOnUnparsableJSON(t *testing.T) { func TestLiteLLMErrorOnEmptyChoices(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
resp := map[string]any{ _, _ = w.Write([]byte(`{"choices":[]}`))
"choices": []map[string]any{
{"message": map[string]any{"role": "assistant", "content": "not json at all"}},
},
}
data, _ := json.Marshal(resp)
_, _ = w.Write(data)
})) }))
defer srv.Close() defer srv.Close()
ex := iexec.NewLiteLLM(srv.URL, "", 5*time.Second) ex := iexec.NewLiteLLM(srv.URL, "", 5*time.Second)
_, err := ex.Run(context.Background(), iexec.Request{Model: "x", TaskPrompt: "t"}) _, _, err := ex.Complete(context.Background(), "model", "sys", "user")
assert.Error(t, err) assert.ErrorContains(t, err, "no choices")
} }
func TestLiteLLMRespectsContextCancellation(t *testing.T) { func TestLiteLLMRespectsContextCancellation(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately cancel()
ex := iexec.NewLiteLLM("http://invalid.example.com", "", 1*time.Second) ex := iexec.NewLiteLLM("http://invalid.example.com", "", 1*time.Second)
_, err := ex.Run(ctx, iexec.Request{Model: "x", TaskPrompt: "t"}) _, _, err := ex.Complete(ctx, "model", "sys", "user")
assert.Error(t, err) assert.Error(t, err)
} }

View File

@@ -1,197 +0,0 @@
package exec
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// ChainEntry is one tier in an escalation chain.
type ChainEntry struct {
Model string // e.g. "ollama/phi4", "claude-sonnet-4-6"
Tier string // "local" | "subagent" | "managed"
IsCloud bool // true for claude-* models; skips verifier call
}
// EntryFor builds a ChainEntry from a model name string.
func EntryFor(model string) ChainEntry {
cloud := strings.HasPrefix(model, "claude-")
tier := "local"
if cloud {
tier = "subagent"
}
return ChainEntry{Model: model, Tier: tier, IsCloud: cloud}
}
// AttemptRecord captures the outcome of one tier attempt for session logging.
type AttemptRecord struct {
Model string
Tier string
DurationMs int64
WarmStart bool
Verdict string // "accept" | "escalate" | "error"
Feedback string
}
// VerifierFn is the interface the orchestrator uses to verify local output.
type VerifierFn interface {
Verify(ctx context.Context, skillPrompt, taskPrompt string, output Result) (Verdict, error)
}
// ExecutorRunFn is the signature of Executor.Run and LiteLLMExecutor.Run.
type ExecutorRunFn func(ctx context.Context, req Request) (Result, error)
// Orchestrator walks an escalation chain, delegating generation and verification.
// It implements the ExecutorFn shape expected by skill handlers.
type Orchestrator struct {
chain []ChainEntry
localRun ExecutorRunFn // for local (non-cloud) tiers; may be nil
cloudRun ExecutorRunFn // for cloud tiers; may be nil
verifier VerifierFn
llamaSwapURL string
attempts *[]AttemptRecord
}
// NewOrchestrator creates an Orchestrator.
// attempts is a pointer to a slice that will be appended to on each tier attempt.
// Pass nil for localRun or cloudRun if no tiers of that type exist in the chain.
func NewOrchestrator(
chain []ChainEntry,
localRun ExecutorRunFn,
cloudRun ExecutorRunFn,
verifier VerifierFn,
llamaSwapURL string,
attempts *[]AttemptRecord,
) *Orchestrator {
return &Orchestrator{
chain: chain,
localRun: localRun,
cloudRun: cloudRun,
verifier: verifier,
llamaSwapURL: llamaSwapURL,
attempts: attempts,
}
}
// Run walks the escalation chain and returns the first accepted result.
// Satisfies the ExecutorFn signature: func(context.Context, Request) (Result, error).
func (o *Orchestrator) Run(ctx context.Context, req Request) (Result, error) {
taskPrompt := req.TaskPrompt
for _, entry := range o.chain {
warm := o.probeWarm(entry.Model)
start := time.Now()
tierReq := req
tierReq.Model = entry.Model
tierReq.TaskPrompt = taskPrompt
if entry.IsCloud {
result, genErr := o.cloudRun(ctx, tierReq)
dur := time.Since(start).Milliseconds()
verdict := "accept"
if genErr != nil {
verdict = "error"
}
o.appendAttempt(AttemptRecord{
Model: entry.Model,
Tier: entry.Tier,
DurationMs: dur,
WarmStart: warm,
Verdict: verdict,
})
if genErr == nil {
return result, nil
}
continue
}
// Local tier.
result, genErr := o.localRun(ctx, tierReq)
dur := time.Since(start).Milliseconds()
if genErr != nil {
o.appendAttempt(AttemptRecord{
Model: entry.Model,
Tier: entry.Tier,
DurationMs: dur,
WarmStart: warm,
Verdict: "error",
Feedback: genErr.Error(),
})
continue
}
verdict, verErr := o.verifier.Verify(ctx, req.SkillPrompt, taskPrompt, result)
if verErr != nil {
// Treat verifier failure as escalate (safe default).
o.appendAttempt(AttemptRecord{
Model: entry.Model,
Tier: entry.Tier,
DurationMs: dur,
WarmStart: warm,
Verdict: "escalate",
Feedback: "verifier error: " + verErr.Error(),
})
continue
}
if verdict.Accept {
o.appendAttempt(AttemptRecord{
Model: entry.Model,
Tier: entry.Tier,
DurationMs: dur,
WarmStart: warm,
Verdict: "accept",
})
return result, nil
}
o.appendAttempt(AttemptRecord{
Model: entry.Model,
Tier: entry.Tier,
DurationMs: dur,
WarmStart: warm,
Verdict: "escalate",
Feedback: verdict.Feedback,
})
// Inject verifier feedback into the next tier's task prompt.
taskPrompt = taskPrompt + "\n\nPrior attempt feedback: " + verdict.Feedback
}
return Result{}, fmt.Errorf("all tiers exhausted after %d attempt(s)", len(o.chain))
}
func (o *Orchestrator) appendAttempt(rec AttemptRecord) {
if o.attempts != nil {
*o.attempts = append(*o.attempts, rec)
}
}
// probeWarm checks whether the model is currently loaded in llama-swap.
// Returns false on any error or if llamaSwapURL is empty.
func (o *Orchestrator) probeWarm(model string) bool {
if o.llamaSwapURL == "" {
return false
}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, o.llamaSwapURL+"/v1/models", nil)
if err != nil {
return false
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return false
}
defer resp.Body.Close() //nolint:errcheck
body, err := io.ReadAll(resp.Body)
if err != nil {
return false
}
return strings.Contains(string(body), model)
}

View File

@@ -1,151 +0,0 @@
package exec_test
import (
"context"
"errors"
"testing"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// stubRunFn returns preset results sequentially.
type stubRunFn struct {
calls []stubCall
callIdx int
}
type stubCall struct {
result iexec.Result
err error
}
func (s *stubRunFn) Run(_ context.Context, _ iexec.Request) (iexec.Result, error) {
if s.callIdx >= len(s.calls) {
return iexec.Result{}, errors.New("unexpected call")
}
c := s.calls[s.callIdx]
s.callIdx++
return c.result, c.err
}
// stubVerifier returns preset verdicts sequentially.
type stubVerifier struct {
verdicts []iexec.Verdict
idx int
}
func (s *stubVerifier) Verify(_ context.Context, _, _ string, _ iexec.Result) (iexec.Verdict, error) {
if s.idx >= len(s.verdicts) {
return iexec.Verdict{}, errors.New("unexpected verify call")
}
v := s.verdicts[s.idx]
s.idx++
return v, nil
}
func okResult(skill string) iexec.Result {
return iexec.Result{Status: "pass", Phase: "review", Skill: skill, Message: "ok", ModelUsed: "m"}
}
func TestOrchestratorSingleLocalAccept(t *testing.T) {
local := &stubRunFn{calls: []stubCall{{result: okResult("review")}}}
verifier := &stubVerifier{verdicts: []iexec.Verdict{{Accept: true}}}
var attempts []iexec.AttemptRecord
orch := iexec.NewOrchestrator(
[]iexec.ChainEntry{{Model: "ollama/devstral", Tier: "local", IsCloud: false}},
local.Run, nil, verifier, "", &attempts,
)
result, err := orch.Run(context.Background(), iexec.Request{TaskPrompt: "review"})
require.NoError(t, err)
assert.Equal(t, "pass", result.Status)
require.Len(t, attempts, 1)
assert.Equal(t, "local", attempts[0].Tier)
assert.Equal(t, "accept", attempts[0].Verdict)
}
func TestOrchestratorEscalatesOnVerifierReject(t *testing.T) {
local := &stubRunFn{calls: []stubCall{
{result: iexec.Result{Status: "fail", Phase: "review", Skill: "review", Message: "weak"}},
{result: okResult("review")},
}}
verifier := &stubVerifier{verdicts: []iexec.Verdict{
{Accept: false, Feedback: "missing line refs"},
{Accept: true},
}}
var attempts []iexec.AttemptRecord
orch := iexec.NewOrchestrator(
[]iexec.ChainEntry{
{Model: "ollama/devstral", Tier: "local", IsCloud: false},
{Model: "ollama/gemma4", Tier: "local", IsCloud: false},
},
local.Run, nil, verifier, "", &attempts,
)
result, err := orch.Run(context.Background(), iexec.Request{TaskPrompt: "review"})
require.NoError(t, err)
assert.Equal(t, "pass", result.Status)
require.Len(t, attempts, 2)
assert.Equal(t, "escalate", attempts[0].Verdict)
assert.Equal(t, "missing line refs", attempts[0].Feedback)
assert.Equal(t, "accept", attempts[1].Verdict)
}
func TestOrchestratorEscalatesOnLocalError(t *testing.T) {
local := &stubRunFn{calls: []stubCall{
{err: errors.New("network failure")},
{result: okResult("review")},
}}
verifier := &stubVerifier{verdicts: []iexec.Verdict{{Accept: true}}}
var attempts []iexec.AttemptRecord
orch := iexec.NewOrchestrator(
[]iexec.ChainEntry{
{Model: "ollama/devstral", Tier: "local", IsCloud: false},
{Model: "ollama/gemma4", Tier: "local", IsCloud: false},
},
local.Run, nil, verifier, "", &attempts,
)
_, err := orch.Run(context.Background(), iexec.Request{TaskPrompt: "review"})
require.NoError(t, err)
require.Len(t, attempts, 2)
assert.Equal(t, "error", attempts[0].Verdict)
assert.Equal(t, "accept", attempts[1].Verdict)
}
func TestOrchestratorCloudTierSelfCertifies(t *testing.T) {
cloud := &stubRunFn{calls: []stubCall{{result: okResult("review")}}}
verifier := &stubVerifier{} // no verdicts — must not be called
var attempts []iexec.AttemptRecord
orch := iexec.NewOrchestrator(
[]iexec.ChainEntry{{Model: "claude-sonnet-4-6", Tier: "subagent", IsCloud: true}},
nil, cloud.Run, verifier, "", &attempts,
)
result, err := orch.Run(context.Background(), iexec.Request{TaskPrompt: "review"})
require.NoError(t, err)
assert.Equal(t, "pass", result.Status)
require.Len(t, attempts, 1)
assert.Equal(t, "subagent", attempts[0].Tier)
assert.Equal(t, "accept", attempts[0].Verdict)
assert.Equal(t, 0, verifier.idx) // verifier never called
}
func TestOrchestratorAllTiersExhausted(t *testing.T) {
local := &stubRunFn{calls: []stubCall{{err: errors.New("unavailable")}}}
var attempts []iexec.AttemptRecord
orch := iexec.NewOrchestrator(
[]iexec.ChainEntry{{Model: "ollama/devstral", Tier: "local", IsCloud: false}},
local.Run, nil, &stubVerifier{}, "", &attempts,
)
_, err := orch.Run(context.Background(), iexec.Request{TaskPrompt: "review"})
assert.ErrorContains(t, err, "all tiers exhausted")
}

View File

@@ -1,66 +0,0 @@
package exec
import (
"errors"
"strings"
)
// Result is the structured JSON output from every supervisor invocation.
// The JSON schema constant is passed to claude via --json-schema so Claude
// validates its own output before returning.
type Result struct {
Status string `json:"status"` // pass | fail | error
Phase string `json:"phase"` // red | green | refactor | retrospective | review | debug | spec | trainer
Skill string `json:"skill"` // tdd | review | ...
FilePath string `json:"file_path"` // absolute path to generated file
RunnerOutput string `json:"runner_output"` // raw stdout+stderr from test runner
Verified bool `json:"verified"` // based on exit code, never self-report
ModelUsed string `json:"model_used"` // model name or "self"
Message string `json:"message"` // one sentence summary
Attempts []AttemptRecord `json:"attempts,omitempty"` // populated by orchestrator, not Claude
}
var validStatuses = map[string]bool{"pass": true, "fail": true, "error": true}
var validPhases = map[string]bool{
"red": true,
"green": true,
"refactor": true,
"retrospective": true,
"review": true,
"debug": true,
"spec": true,
"trainer": true,
}
func (r Result) Validate() error {
var errs []string
if !validStatuses[r.Status] {
errs = append(errs, "status must be pass|fail|error, got: "+r.Status)
}
if !validPhases[r.Phase] {
errs = append(errs, "phase must be one of red|green|refactor|retrospective|review|debug|spec|trainer, got: "+r.Phase)
}
if r.Skill == "" {
errs = append(errs, "skill is required")
}
if len(errs) > 0 {
return errors.New(strings.Join(errs, "; "))
}
return nil
}
// Schema is passed to claude --json-schema to enforce structured output.
const Schema = `{
"type": "object",
"required": ["status","phase","skill","file_path","runner_output","verified","model_used","message"],
"properties": {
"status": {"type": "string", "enum": ["pass","fail","error"]},
"phase": {"type": "string"},
"skill": {"type": "string"},
"file_path": {"type": "string"},
"runner_output": {"type": "string"},
"verified": {"type": "boolean"},
"model_used": {"type": "string"},
"message": {"type": "string"}
}
}`

View File

@@ -1,79 +0,0 @@
package exec_test
import (
"encoding/json"
"testing"
"github.com/mathiasbq/supervisor/internal/exec"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestResultParsesValidJSON(t *testing.T) {
raw := `{
"status": "pass",
"phase": "red",
"skill": "tdd",
"file_path": "/tmp/foo_test.go",
"runner_output": "--- FAIL: TestFoo",
"verified": true,
"model_used": "self",
"message": "test fails as expected"
}`
var r exec.Result
require.NoError(t, json.Unmarshal([]byte(raw), &r))
assert.Equal(t, "pass", r.Status)
assert.Equal(t, "red", r.Phase)
assert.True(t, r.Verified)
}
func TestResultValidation(t *testing.T) {
tests := []struct {
name string
result exec.Result
wantErr bool
}{
{
name: "valid pass result",
result: exec.Result{
Status: "pass", Phase: "red", Skill: "tdd",
FilePath: "/tmp/x_test.go", RunnerOutput: "FAIL",
Verified: true, ModelUsed: "self", Message: "ok",
},
wantErr: false,
},
{
name: "empty status",
result: exec.Result{Phase: "red", Skill: "tdd"},
wantErr: true,
},
{
name: "invalid status",
result: exec.Result{Status: "unknown", Phase: "red", Skill: "tdd"},
wantErr: true,
},
{
name: "invalid phase",
result: exec.Result{Status: "pass", Phase: "bad", Skill: "tdd"},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.result.Validate()
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValidateAcceptsAllPhases(t *testing.T) {
phases := []string{"red", "green", "refactor", "retrospective", "review", "debug", "spec", "trainer"}
for _, phase := range phases {
r := exec.Result{Status: "pass", Phase: phase, Skill: "test", ModelUsed: "self", Message: "ok"}
assert.NoError(t, r.Validate(), "phase %q should be valid", phase)
}
}

View File

@@ -1,99 +0,0 @@
package exec
import (
"bytes"
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"time"
)
// Verdict is the output of a Claude verification call.
type Verdict struct {
Accept bool `json:"accept"`
Feedback string `json:"feedback"` // empty when Accept is true
}
// Verifier runs a focused Claude call to judge local model output.
type Verifier struct {
claudeBinary string
model string
timeout time.Duration
}
// NewVerifier creates a Verifier that calls claude with the given binary path and model.
// Empty claudeBinary defaults to "claude". Zero timeout defaults to 30s.
func NewVerifier(claudeBinary, model string, timeout time.Duration) *Verifier {
if claudeBinary == "" {
claudeBinary = "claude"
}
if timeout == 0 {
timeout = 30 * time.Second
}
return &Verifier{
claudeBinary: claudeBinary,
model: model,
timeout: timeout,
}
}
// Verify asks Claude whether output satisfies the skill discipline's iron laws.
// Returns Verdict{Accept: true} to accept or Verdict{Accept: false, Feedback: "..."}
// to escalate. Returns an error on subprocess failure or unparseable response.
func (v *Verifier) Verify(ctx context.Context, skillPrompt, taskPrompt string, output Result) (Verdict, error) {
ctx, cancel := context.WithTimeout(ctx, v.timeout)
defer cancel()
outputJSON, err := json.Marshal(output)
if err != nil {
return Verdict{}, fmt.Errorf("verifier: marshal output: %w", err)
}
prompt := fmt.Sprintf(`You are a quality verifier for an AI supervisor system.
Given the skill discipline, the original task, and the generated output, decide whether the output satisfies the discipline's iron laws and output contract.
Reply with JSON only — no other text:
{"accept": true, "feedback": ""}
or
{"accept": false, "feedback": "<one sentence reason>"}
## Skill discipline
%s
## Original task
%s
## Generated output
%s`, skillPrompt, taskPrompt, string(outputJSON))
args := []string{
"--print",
"--permission-mode", "bypassPermissions",
}
if v.model != "" {
args = append(args, "--model", v.model)
}
args = append(args, prompt)
cmd := exec.CommandContext(ctx, v.claudeBinary, args...)
cmd.Env = os.Environ()
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
if ctx.Err() != nil {
return Verdict{}, fmt.Errorf("verifier: timeout after %s", v.timeout)
}
return Verdict{}, fmt.Errorf("verifier: claude exited with error: %w — stderr: %s", err, stderr.String())
}
var verdict Verdict
if err := json.Unmarshal(bytes.TrimSpace(stdout.Bytes()), &verdict); err != nil {
return Verdict{}, fmt.Errorf("verifier: parse verdict JSON: %w — raw: %s", err, stdout.String())
}
return verdict, nil
}

View File

@@ -1,74 +0,0 @@
package exec_test
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
"time"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func fakeVerifierClaude(t *testing.T, verdict iexec.Verdict) string {
t.Helper()
data, err := json.Marshal(verdict)
require.NoError(t, err)
dir := t.TempDir()
script := filepath.Join(dir, "claude")
content := fmt.Sprintf("#!/bin/sh\necho '%s'\n", string(data))
require.NoError(t, os.WriteFile(script, []byte(content), 0755))
return script
}
func TestVerifierAccepts(t *testing.T) {
claude := fakeVerifierClaude(t, iexec.Verdict{Accept: true, Feedback: ""})
v := iexec.NewVerifier(claude, "claude-sonnet-4-6", 5*time.Second)
verdict, err := v.Verify(context.Background(), "skill rules", "do the task", iexec.Result{
Status: "pass", Phase: "review", Skill: "review", Message: "ok",
})
require.NoError(t, err)
assert.True(t, verdict.Accept)
assert.Empty(t, verdict.Feedback)
}
func TestVerifierEscalates(t *testing.T) {
claude := fakeVerifierClaude(t, iexec.Verdict{Accept: false, Feedback: "missing line references"})
v := iexec.NewVerifier(claude, "claude-sonnet-4-6", 5*time.Second)
verdict, err := v.Verify(context.Background(), "skill rules", "do the task", iexec.Result{
Status: "pass", Phase: "review", Skill: "review", Message: "incomplete",
})
require.NoError(t, err)
assert.False(t, verdict.Accept)
assert.Equal(t, "missing line references", verdict.Feedback)
}
func TestVerifierErrorOnUnparsableOutput(t *testing.T) {
dir := t.TempDir()
script := filepath.Join(dir, "claude")
require.NoError(t, os.WriteFile(script, []byte("#!/bin/sh\necho 'not json'\n"), 0755))
v := iexec.NewVerifier(script, "claude-sonnet-4-6", 5*time.Second)
_, err := v.Verify(context.Background(), "rules", "task", iexec.Result{
Status: "pass", Phase: "review", Skill: "review", Message: "ok",
})
assert.Error(t, err)
}
func TestVerifierErrorOnNonZeroExit(t *testing.T) {
dir := t.TempDir()
script := filepath.Join(dir, "claude")
require.NoError(t, os.WriteFile(script, []byte("#!/bin/sh\nexit 1\n"), 0755))
v := iexec.NewVerifier(script, "claude-sonnet-4-6", 5*time.Second)
_, err := v.Verify(context.Background(), "rules", "task", iexec.Result{
Status: "pass", Phase: "review", Skill: "review", Message: "ok",
})
assert.Error(t, err)
}

View File

@@ -1,26 +0,0 @@
// internal/session/attempts.go
package session
import iexec "github.com/mathiasbq/supervisor/internal/exec"
// AttemptsFrom converts exec.AttemptRecord slice to session.Attempt slice
// for writing into a session JSONL entry.
func AttemptsFrom(records []iexec.AttemptRecord) []Attempt {
if len(records) == 0 {
return nil
}
out := make([]Attempt, len(records))
for i, r := range records {
out[i] = Attempt{
Attempt: i + 1,
Model: r.Model,
Tier: r.Tier,
DurationMs: r.DurationMs,
WarmStart: r.WarmStart,
Verdict: r.Verdict,
Feedback: r.Feedback,
Verified: r.Verdict == "accept",
}
}
return out
}

View File

@@ -1,37 +0,0 @@
package session_test
import (
"testing"
"github.com/mathiasbq/supervisor/internal/exec"
"github.com/mathiasbq/supervisor/internal/session"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAttemptsFromEmpty(t *testing.T) {
result := session.AttemptsFrom(nil)
assert.Empty(t, result)
}
func TestAttemptsFromSetsIndex(t *testing.T) {
records := []exec.AttemptRecord{
{Model: "ollama/phi4", Tier: "local", DurationMs: 1200, WarmStart: true, Verdict: "escalate", Feedback: "too vague"},
{Model: "claude-sonnet-4-6", Tier: "subagent", DurationMs: 3400, WarmStart: false, Verdict: "accept"},
}
result := session.AttemptsFrom(records)
require.Len(t, result, 2)
assert.Equal(t, 1, result[0].Attempt)
assert.Equal(t, "ollama/phi4", result[0].Model)
assert.Equal(t, "local", result[0].Tier)
assert.Equal(t, int64(1200), result[0].DurationMs)
assert.True(t, result[0].WarmStart)
assert.Equal(t, "escalate", result[0].Verdict)
assert.Equal(t, "too vague", result[0].Feedback)
assert.False(t, result[0].Verified)
assert.Equal(t, 2, result[1].Attempt)
assert.Equal(t, "claude-sonnet-4-6", result[1].Model)
assert.True(t, result[1].Verified)
}

View File

@@ -8,7 +8,6 @@ import (
"time" "time"
"github.com/mathiasbq/supervisor/internal/brain" "github.com/mathiasbq/supervisor/internal/brain"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/mathiasbq/supervisor/internal/session" "github.com/mathiasbq/supervisor/internal/session"
) )
@@ -52,38 +51,32 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (
task = brainCtx + "\n---\n\n" + task task = brainCtx + "\n---\n\n" + task
} }
if s.cfg.ExecutorFn == nil { if s.cfg.CompleteFunc == nil {
return nil, fmt.Errorf("no executor configured") return nil, fmt.Errorf("no executor configured")
} }
t0 := time.Now() t0 := time.Now()
result, err := s.cfg.ExecutorFn(ctx, iexec.Request{ text, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.SkillPrompt, task)
SkillPrompt: s.cfg.SkillPrompt,
TaskPrompt: task,
Model: model,
Tools: "Read,Bash",
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if a.SessionID != "" && s.cfg.SessionsDir != "" { if a.SessionID != "" && s.cfg.SessionsDir != "" {
msg := text
if len(msg) > 200 {
msg = msg[:200]
}
_ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{ _ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{
SessionID: a.SessionID, SessionID: a.SessionID,
Timestamp: time.Now(), Timestamp: time.Now(),
Skill: "debug", Skill: "debug",
Phase: "debug", Phase: "debug",
ProjectRoot: a.ProjectRoot, ProjectRoot: a.ProjectRoot,
Attempts: session.AttemptsFrom(result.Attempts), FinalStatus: "ok",
FinalStatus: result.Status, ModelUsed: model,
ModelUsed: result.ModelUsed,
DurationMs: time.Since(t0).Milliseconds(), DurationMs: time.Since(t0).Milliseconds(),
Message: result.Message, Message: msg,
}) })
} }
b, err := json.Marshal(result) return json.Marshal(map[string]any{"text": text, "model": model, "duration_ms": dur})
if err != nil {
return nil, fmt.Errorf("marshal result: %w", err)
}
return b, nil
} }

View File

@@ -6,7 +6,6 @@ import (
"encoding/json" "encoding/json"
"testing" "testing"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/mathiasbq/supervisor/internal/skills/debug" "github.com/mathiasbq/supervisor/internal/skills/debug"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -33,29 +32,22 @@ func TestDebugRequiresError(t *testing.T) {
assert.ErrorContains(t, err, "error") assert.ErrorContains(t, err, "error")
} }
func TestDebugCallsExecutor(t *testing.T) { func TestDebugCallsCompleteFunc(t *testing.T) {
called := false
var capturedTask string var capturedTask string
fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) { fakeFn := func(_ context.Context, _, _, user string) (string, int64, error) {
called = true capturedTask = user
capturedTask = req.TaskPrompt return "HYPOTHESIS 1 (high): nil map access. Verify: go test ./...", 90, nil
return iexec.Result{
Status: "pass", Phase: "debug", Skill: "debug",
RunnerOutput: "HYPOTHESIS 1 (likelihood: high): nil map access\nVERIFY: go test ./... → expected: panic line reference",
Verified: false, ModelUsed: "self", Message: "3 hypotheses for: panic nil pointer at foo.go:42",
}, nil
} }
sk := debug.New(debug.Config{SkillPrompt: "debug rules", ExecutorFn: fakeFn, SessionsDir: t.TempDir()}) sk := debug.New(debug.Config{SkillPrompt: "debug rules", CompleteFunc: fakeFn, SessionsDir: t.TempDir()})
out, err := sk.Handle(context.Background(), "debug", json.RawMessage( out, err := sk.Handle(context.Background(), "debug", json.RawMessage(
`{"project_root":"/tmp/proj","error":"panic: nil pointer dereference at foo.go:42","context":"occurs on startup"}`, `{"project_root":"/tmp/proj","error":"panic: nil pointer dereference at foo.go:42","context":"occurs on startup"}`,
)) ))
require.NoError(t, err) require.NoError(t, err)
assert.True(t, called)
assert.Contains(t, capturedTask, "panic: nil pointer dereference") assert.Contains(t, capturedTask, "panic: nil pointer dereference")
assert.Contains(t, capturedTask, "occurs on startup") assert.Contains(t, capturedTask, "occurs on startup")
var result iexec.Result var result map[string]any
require.NoError(t, json.Unmarshal(out, &result)) require.NoError(t, json.Unmarshal(out, &result))
assert.Equal(t, "debug", result.Phase) assert.Contains(t, result["text"], "nil map access")
} }

View File

@@ -5,20 +5,19 @@ import (
"context" "context"
"encoding/json" "encoding/json"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/mathiasbq/supervisor/internal/registry" "github.com/mathiasbq/supervisor/internal/registry"
) )
// ExecutorFn is the function signature for running a worker subprocess. // CompleteFunc is the function used to call a local model.
type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error) type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error)
// Config holds dependencies for the debug skill. // Config holds dependencies for the debug skill.
type Config struct { type Config struct {
SkillPrompt string SkillPrompt string
DefaultModel string DefaultModel string
ExecutorFn ExecutorFn CompleteFunc CompleteFunc
SessionsDir string SessionsDir string
IngestBaseURL string // optional: base URL of ingestion server for brain context IngestBaseURL string
} }
// Skill implements the debug MCP tool. // Skill implements the debug MCP tool.
@@ -40,7 +39,7 @@ func (s *Skill) Tools() []registry.ToolDef {
return []registry.ToolDef{ return []registry.ToolDef{
{ {
Name: "debug", Name: "debug",
Description: "Analyse an error and return 3-5 hypotheses ordered by likelihood, each with a concrete verification step.", Description: "Consult a local model to analyse an error and return hypotheses ordered by likelihood, each with a concrete verification step.",
InputSchema: schema( InputSchema: schema(
[]string{"project_root", "error"}, []string{"project_root", "error"},
map[string]any{ map[string]any{

View File

@@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"time" "time"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/mathiasbq/supervisor/internal/session" "github.com/mathiasbq/supervisor/internal/session"
) )
@@ -34,7 +33,6 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (
model = s.cfg.DefaultModel model = s.cfg.DefaultModel
} }
// Read session log entries (empty slice if no log exists yet).
entries, err := session.Read(s.cfg.SessionsDir, a.SessionID) entries, err := session.Read(s.cfg.SessionsDir, a.SessionID)
if err != nil { if err != nil {
return nil, fmt.Errorf("read session log: %w", err) return nil, fmt.Errorf("read session log: %w", err)
@@ -46,39 +44,33 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (
} }
taskPrompt := fmt.Sprintf( taskPrompt := fmt.Sprintf(
"SESSION_ID: %s\n\nSESSION_LOG:\n%s\n\nReview this session log. Identify what is novel or worth preserving as organizational knowledge. Write structured entries to brain/raw/ via brain_write. Return JSON result when done.", "SESSION_ID: %s\n\nSESSION_LOG:\n%s\n\nReview this session log. Identify what is novel or worth preserving as organizational knowledge. Provide structured insights.",
a.SessionID, string(logJSON), a.SessionID, string(logJSON),
) )
if s.cfg.ExecutorFn == nil { if s.cfg.CompleteFunc == nil {
return nil, fmt.Errorf("no executor configured") return nil, fmt.Errorf("no executor configured")
} }
t0 := time.Now() t0 := time.Now()
result, err := s.cfg.ExecutorFn(ctx, iexec.Request{ text, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.SkillPrompt, taskPrompt)
SkillPrompt: s.cfg.SkillPrompt,
TaskPrompt: taskPrompt,
Model: model,
Tools: "Bash,Read,Write",
})
if err != nil { if err != nil {
return nil, fmt.Errorf("retrospective worker: %w", err) return nil, fmt.Errorf("retrospective model: %w", err)
} }
msg := text
if len(msg) > 200 {
msg = msg[:200]
}
_ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{ _ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{
SessionID: a.SessionID, SessionID: a.SessionID,
Timestamp: time.Now(), Timestamp: time.Now(),
Skill: "retrospective", Skill: "retrospective",
Phase: "retrospective", Phase: "retrospective",
Attempts: session.AttemptsFrom(result.Attempts), FinalStatus: "ok",
FinalStatus: result.Status, ModelUsed: model,
ModelUsed: result.ModelUsed,
DurationMs: time.Since(t0).Milliseconds(), DurationMs: time.Since(t0).Milliseconds(),
Message: result.Message, Message: msg,
}) })
b, err := json.Marshal(result) return json.Marshal(map[string]any{"text": text, "model": model, "duration_ms": dur})
if err != nil {
return nil, fmt.Errorf("marshal result: %w", err)
}
return b, nil
} }

View File

@@ -6,7 +6,6 @@ import (
"encoding/json" "encoding/json"
"testing" "testing"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/mathiasbq/supervisor/internal/skills/retrospective" "github.com/mathiasbq/supervisor/internal/skills/retrospective"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -20,20 +19,14 @@ func TestHandle_Retrospective_RequiresSessionID(t *testing.T) {
} }
func TestHandle_Retrospective_BuildsPromptWithSessionLog(t *testing.T) { func TestHandle_Retrospective_BuildsPromptWithSessionLog(t *testing.T) {
var capturedReq iexec.Request var capturedTask string
s := retrospective.New(retrospective.Config{ s := retrospective.New(retrospective.Config{
SkillPrompt: "retrospective discipline", SkillPrompt: "retrospective discipline",
DefaultModel: "ollama/test", DefaultModel: "ollama/test",
SessionsDir: t.TempDir(), // empty dir, no session file — that's OK, session.Read returns nil SessionsDir: t.TempDir(),
ExecutorFn: func(_ context.Context, req iexec.Request) (iexec.Result, error) { CompleteFunc: func(_ context.Context, _, _, user string) (string, int64, error) {
capturedReq = req capturedTask = user
return iexec.Result{ return "Key insight: the team resolved a tricky nil pointer issue via careful logging.", 75, nil
Status: "pass",
Phase: "retrospective",
Skill: "retrospective",
Verified: true,
Message: "wrote 2 entries to brain",
}, nil
}, },
}) })
@@ -41,9 +34,8 @@ func TestHandle_Retrospective_BuildsPromptWithSessionLog(t *testing.T) {
out, err := s.Handle(context.Background(), "retrospective", args) out, err := s.Handle(context.Background(), "retrospective", args)
require.NoError(t, err) require.NoError(t, err)
var result iexec.Result var result map[string]any
require.NoError(t, json.Unmarshal(out, &result)) require.NoError(t, json.Unmarshal(out, &result))
assert.Equal(t, "pass", result.Status) assert.Contains(t, result["text"], "nil pointer")
assert.Contains(t, capturedReq.SkillPrompt, "retrospective discipline") assert.Contains(t, capturedTask, "empty-session")
assert.Contains(t, capturedReq.TaskPrompt, "empty-session")
} }

View File

@@ -5,19 +5,18 @@ import (
"context" "context"
"encoding/json" "encoding/json"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/mathiasbq/supervisor/internal/registry" "github.com/mathiasbq/supervisor/internal/registry"
) )
// ExecutorFn allows injecting a test double for the subprocess executor. // CompleteFunc is the function used to call a local model.
type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error) type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error)
// Config holds retrospective skill configuration. // Config holds retrospective skill configuration.
type Config struct { type Config struct {
SkillPrompt string // content of retrospective.md SkillPrompt string
DefaultModel string // model to use when not specified in args DefaultModel string
SessionsDir string // path to brain/sessions/ SessionsDir string
ExecutorFn ExecutorFn // injected executor CompleteFunc CompleteFunc
} }
// Skill implements registry.Skill for the retrospective tool. // Skill implements registry.Skill for the retrospective tool.
@@ -36,7 +35,7 @@ func (s *Skill) Tools() []registry.ToolDef {
return []registry.ToolDef{ return []registry.ToolDef{
{ {
Name: "retrospective", Name: "retrospective",
Description: "Run a retrospective on a completed session. Reads the session log, identifies novel learnings, and writes structured entries to the brain for ingestion. Call at the end of each coding session.", Description: "Consult a local model to analyse a completed session and identify what is novel or worth preserving as organizational knowledge.",
InputSchema: json.RawMessage(`{ InputSchema: json.RawMessage(`{
"type": "object", "type": "object",
"required": ["session_id"], "required": ["session_id"],

View File

@@ -9,7 +9,6 @@ import (
"time" "time"
"github.com/mathiasbq/supervisor/internal/brain" "github.com/mathiasbq/supervisor/internal/brain"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/mathiasbq/supervisor/internal/session" "github.com/mathiasbq/supervisor/internal/session"
) )
@@ -53,39 +52,32 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (
task = brainCtx + "\n---\n\n" + task task = brainCtx + "\n---\n\n" + task
} }
if s.cfg.ExecutorFn == nil { if s.cfg.CompleteFunc == nil {
return nil, fmt.Errorf("no executor configured") return nil, fmt.Errorf("no executor configured")
} }
t0 := time.Now() t0 := time.Now()
result, err := s.cfg.ExecutorFn(ctx, iexec.Request{ text, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.SkillPrompt, task)
SkillPrompt: s.cfg.SkillPrompt,
TaskPrompt: task,
Model: model,
Tools: "Read,Bash",
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if a.SessionID != "" && s.cfg.SessionsDir != "" { if a.SessionID != "" && s.cfg.SessionsDir != "" {
msg := text
if len(msg) > 200 {
msg = msg[:200]
}
_ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{ _ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{
SessionID: a.SessionID, SessionID: a.SessionID,
Timestamp: time.Now(), Timestamp: time.Now(),
Skill: "review", Skill: "review",
Phase: "review", Phase: "review",
ProjectRoot: a.ProjectRoot, ProjectRoot: a.ProjectRoot,
Attempts: session.AttemptsFrom(result.Attempts), FinalStatus: "ok",
FinalStatus: result.Status, ModelUsed: model,
FilePath: result.FilePath,
ModelUsed: result.ModelUsed,
DurationMs: time.Since(t0).Milliseconds(), DurationMs: time.Since(t0).Milliseconds(),
Message: result.Message, Message: msg,
}) })
} }
b, err := json.Marshal(result) return json.Marshal(map[string]any{"text": text, "model": model, "duration_ms": dur})
if err != nil {
return nil, fmt.Errorf("marshal result: %w", err)
}
return b, nil
} }

View File

@@ -6,7 +6,6 @@ import (
"encoding/json" "encoding/json"
"testing" "testing"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/mathiasbq/supervisor/internal/skills/review" "github.com/mathiasbq/supervisor/internal/skills/review"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -33,29 +32,22 @@ func TestReviewRequiresFiles(t *testing.T) {
assert.ErrorContains(t, err, "files") assert.ErrorContains(t, err, "files")
} }
func TestReviewCallsExecutor(t *testing.T) { func TestReviewCallsCompleteFunc(t *testing.T) {
called := false
var capturedTask string var capturedTask string
fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) { fakeFn := func(_ context.Context, _, _, user string) (string, int64, error) {
called = true capturedTask = user
capturedTask = req.TaskPrompt return "2 warnings found: missing error handling at line 42", 80, nil
return iexec.Result{
Status: "pass", Phase: "review", Skill: "review",
Verified: true, ModelUsed: "self", Message: "2 warnings found",
}, nil
} }
sk := review.New(review.Config{SkillPrompt: "review rules", ExecutorFn: fakeFn, SessionsDir: t.TempDir()}) sk := review.New(review.Config{SkillPrompt: "review rules", CompleteFunc: fakeFn, SessionsDir: t.TempDir()})
out, err := sk.Handle(context.Background(), "review", json.RawMessage( out, err := sk.Handle(context.Background(), "review", json.RawMessage(
`{"project_root":"/tmp/proj","files":["internal/foo/foo.go"],"context":"PR: add Foo helper"}`, `{"project_root":"/tmp/proj","files":["internal/foo/foo.go"],"context":"PR: add Foo helper"}`,
)) ))
require.NoError(t, err) require.NoError(t, err)
assert.True(t, called)
assert.Contains(t, capturedTask, "internal/foo/foo.go") assert.Contains(t, capturedTask, "internal/foo/foo.go")
assert.Contains(t, capturedTask, "PR: add Foo helper") assert.Contains(t, capturedTask, "PR: add Foo helper")
var result iexec.Result var result map[string]any
require.NoError(t, json.Unmarshal(out, &result)) require.NoError(t, json.Unmarshal(out, &result))
assert.Equal(t, "pass", result.Status) assert.Contains(t, result["text"], "2 warnings found")
assert.Equal(t, "review", result.Phase)
} }

View File

@@ -5,20 +5,19 @@ import (
"context" "context"
"encoding/json" "encoding/json"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/mathiasbq/supervisor/internal/registry" "github.com/mathiasbq/supervisor/internal/registry"
) )
// ExecutorFn is the function signature for running a worker subprocess. // CompleteFunc is the function used to call a local model.
type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error) type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error)
// Config holds dependencies for the review skill. // Config holds dependencies for the review skill.
type Config struct { type Config struct {
SkillPrompt string SkillPrompt string
DefaultModel string DefaultModel string
ExecutorFn ExecutorFn CompleteFunc CompleteFunc
SessionsDir string SessionsDir string
IngestBaseURL string // optional: base URL of ingestion server for brain context IngestBaseURL string
} }
// Skill implements the review MCP tool. // Skill implements the review MCP tool.
@@ -40,7 +39,7 @@ func (s *Skill) Tools() []registry.ToolDef {
return []registry.ToolDef{ return []registry.ToolDef{
{ {
Name: "review", Name: "review",
Description: "Perform a structured code review of the specified files. Returns findings with severity levels.", Description: "Consult a local model for a structured code review of the specified files. Returns findings with severity levels.",
InputSchema: schema( InputSchema: schema(
[]string{"project_root", "files"}, []string{"project_root", "files"},
map[string]any{ map[string]any{

View File

@@ -8,7 +8,6 @@ import (
"time" "time"
"github.com/mathiasbq/supervisor/internal/brain" "github.com/mathiasbq/supervisor/internal/brain"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/mathiasbq/supervisor/internal/session" "github.com/mathiasbq/supervisor/internal/session"
) )
@@ -57,39 +56,32 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (
task = brainCtx + "\n---\n\n" + task task = brainCtx + "\n---\n\n" + task
} }
if s.cfg.ExecutorFn == nil { if s.cfg.CompleteFunc == nil {
return nil, fmt.Errorf("no executor configured") return nil, fmt.Errorf("no executor configured")
} }
t0 := time.Now() t0 := time.Now()
result, err := s.cfg.ExecutorFn(ctx, iexec.Request{ text, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.SkillPrompt, task)
SkillPrompt: s.cfg.SkillPrompt,
TaskPrompt: task,
Model: model,
Tools: "Read,Write",
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if a.SessionID != "" && s.cfg.SessionsDir != "" { if a.SessionID != "" && s.cfg.SessionsDir != "" {
msg := text
if len(msg) > 200 {
msg = msg[:200]
}
_ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{ _ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{
SessionID: a.SessionID, SessionID: a.SessionID,
Timestamp: time.Now(), Timestamp: time.Now(),
Skill: "spec", Skill: "spec",
Phase: "spec", Phase: "spec",
ProjectRoot: a.ProjectRoot, ProjectRoot: a.ProjectRoot,
Attempts: session.AttemptsFrom(result.Attempts), FinalStatus: "ok",
FinalStatus: result.Status, ModelUsed: model,
FilePath: result.FilePath,
ModelUsed: result.ModelUsed,
DurationMs: time.Since(t0).Milliseconds(), DurationMs: time.Since(t0).Milliseconds(),
Message: result.Message, Message: msg,
}) })
} }
b, err := json.Marshal(result) return json.Marshal(map[string]any{"text": text, "model": model, "duration_ms": dur})
if err != nil {
return nil, fmt.Errorf("marshal result: %w", err)
}
return b, nil
} }

View File

@@ -6,7 +6,6 @@ import (
"encoding/json" "encoding/json"
"testing" "testing"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/mathiasbq/supervisor/internal/skills/spec" "github.com/mathiasbq/supervisor/internal/skills/spec"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -33,29 +32,22 @@ func TestSpecRequiresRequirements(t *testing.T) {
assert.ErrorContains(t, err, "requirements") assert.ErrorContains(t, err, "requirements")
} }
func TestSpecCallsExecutor(t *testing.T) { func TestSpecCallsCompleteFunc(t *testing.T) {
called := false
var capturedTask string var capturedTask string
fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) { fakeFn := func(_ context.Context, _, _, user string) (string, int64, error) {
called = true capturedTask = user
capturedTask = req.TaskPrompt return "# OAuth2 Login Spec\n\n## Overview\nImplement OAuth2 login flow.", 110, nil
return iexec.Result{
Status: "pass", Phase: "spec", Skill: "spec",
FilePath: "/tmp/proj/docs/login-spec.md",
Verified: true, ModelUsed: "self", Message: "spec written: login feature",
}, nil
} }
sk := spec.New(spec.Config{SkillPrompt: "spec rules", ExecutorFn: fakeFn, SessionsDir: t.TempDir()}) sk := spec.New(spec.Config{SkillPrompt: "spec rules", CompleteFunc: fakeFn, SessionsDir: t.TempDir()})
out, err := sk.Handle(context.Background(), "spec", json.RawMessage( out, err := sk.Handle(context.Background(), "spec", json.RawMessage(
`{"project_root":"/tmp/proj","requirements":"add OAuth2 login","output_path":"docs/login-spec.md"}`, `{"project_root":"/tmp/proj","requirements":"add OAuth2 login","output_path":"docs/login-spec.md"}`,
)) ))
require.NoError(t, err) require.NoError(t, err)
assert.True(t, called)
assert.Contains(t, capturedTask, "OAuth2 login") assert.Contains(t, capturedTask, "OAuth2 login")
assert.Contains(t, capturedTask, "docs/login-spec.md") assert.Contains(t, capturedTask, "docs/login-spec.md")
var result iexec.Result var result map[string]any
require.NoError(t, json.Unmarshal(out, &result)) require.NoError(t, json.Unmarshal(out, &result))
assert.Equal(t, "spec", result.Phase) assert.Contains(t, result["text"], "OAuth2 Login Spec")
} }

View File

@@ -5,20 +5,19 @@ import (
"context" "context"
"encoding/json" "encoding/json"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/mathiasbq/supervisor/internal/registry" "github.com/mathiasbq/supervisor/internal/registry"
) )
// ExecutorFn is the function signature for running a worker subprocess. // CompleteFunc is the function used to call a local model.
type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error) type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error)
// Config holds dependencies for the spec skill. // Config holds dependencies for the spec skill.
type Config struct { type Config struct {
SkillPrompt string SkillPrompt string
DefaultModel string DefaultModel string
ExecutorFn ExecutorFn CompleteFunc CompleteFunc
SessionsDir string SessionsDir string
IngestBaseURL string // optional: base URL of ingestion server for brain context IngestBaseURL string
} }
// Skill implements the spec MCP tool. // Skill implements the spec MCP tool.
@@ -40,7 +39,7 @@ func (s *Skill) Tools() []registry.ToolDef {
return []registry.ToolDef{ return []registry.ToolDef{
{ {
Name: "spec", Name: "spec",
Description: "Generate a structured implementation spec from requirements. Writes the spec to output_path in the project.", Description: "Consult a local model to draft a structured implementation spec from requirements. Returns the spec text.",
InputSchema: schema( InputSchema: schema(
[]string{"project_root", "requirements"}, []string{"project_root", "requirements"},
map[string]any{ map[string]any{

View File

@@ -7,7 +7,6 @@ import (
"time" "time"
"github.com/mathiasbq/supervisor/internal/brain" "github.com/mathiasbq/supervisor/internal/brain"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/mathiasbq/supervisor/internal/session" "github.com/mathiasbq/supervisor/internal/session"
) )
@@ -51,7 +50,7 @@ func (s *Skill) handleRed(ctx context.Context, raw json.RawMessage) (json.RawMes
if brainCtx != "" { if brainCtx != "" {
task = brainCtx + "\n---\n\n" + task task = brainCtx + "\n---\n\n" + task
} }
return s.execute(ctx, task) return s.complete(ctx, s.resolveModel(args.Model), task)
} }
type greenArgs struct { type greenArgs struct {
@@ -80,11 +79,11 @@ func (s *Skill) handleGreen(ctx context.Context, raw json.RawMessage) (json.RawM
task = session.PrependHistory(s.cfg.SessionsDir, args.SessionID, "green", task) task = session.PrependHistory(s.cfg.SessionsDir, args.SessionID, "green", task)
t0 := time.Now() t0 := time.Now()
result, err := s.execute(ctx, task) result, err := s.complete(ctx, s.resolveModel(args.Model), task)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.logAttempt(args.SessionID, args.ProjectRoot, "tdd", "green", t0, result) s.logEntry(args.SessionID, args.ProjectRoot, "tdd", "green", s.resolveModel(args.Model), t0, result)
return result, nil return result, nil
} }
@@ -118,11 +117,11 @@ func (s *Skill) handleRefactor(ctx context.Context, raw json.RawMessage) (json.R
task = session.PrependHistory(s.cfg.SessionsDir, args.SessionID, "refactor", task) task = session.PrependHistory(s.cfg.SessionsDir, args.SessionID, "refactor", task)
t0 := time.Now() t0 := time.Now()
result, err := s.execute(ctx, task) result, err := s.complete(ctx, s.resolveModel(args.Model), task)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.logAttempt(args.SessionID, args.ProjectRoot, "tdd", "refactor", t0, result) s.logEntry(args.SessionID, args.ProjectRoot, "tdd", "refactor", s.resolveModel(args.Model), t0, result)
return result, nil return result, nil
} }
@@ -133,31 +132,32 @@ func (s *Skill) resolveModel(override string) string {
return s.cfg.DefaultModel return s.cfg.DefaultModel
} }
// execute calls ExecutorFn and returns the marshaled result. // complete calls CompleteFunc and returns the text as JSON.
func (s *Skill) execute(ctx context.Context, task string) (json.RawMessage, error) { func (s *Skill) complete(ctx context.Context, model, task string) (json.RawMessage, error) {
if s.cfg.ExecutorFn == nil { if s.cfg.CompleteFunc == nil {
return nil, fmt.Errorf("no executor configured") return nil, fmt.Errorf("no executor configured")
} }
req := iexec.Request{ text, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.SkillPrompt, task)
SkillPrompt: s.cfg.SkillPrompt,
TaskPrompt: task,
}
result, err := s.cfg.ExecutorFn(ctx, req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return json.Marshal(result) return json.Marshal(map[string]any{"text": text, "model": model, "duration_ms": dur})
} }
// logAttempt writes a session.Entry for a completed phase if session_id is set. // logEntry writes a session.Entry for a completed phase if session_id is set.
// raw is the marshaled Result returned by execute; we unmarshal to extract fields. func (s *Skill) logEntry(sessionID, projectRoot, skill, phase, model string, t0 time.Time, raw json.RawMessage) {
func (s *Skill) logAttempt(sessionID, projectRoot, skill, phase string, t0 time.Time, raw json.RawMessage) {
if sessionID == "" || s.cfg.SessionsDir == "" { if sessionID == "" || s.cfg.SessionsDir == "" {
return return
} }
var result iexec.Result var msg string
if err := json.Unmarshal(raw, &result); err != nil { var result struct {
return Text string `json:"text"`
}
if err := json.Unmarshal(raw, &result); err == nil && len(result.Text) > 0 {
msg = result.Text
if len(msg) > 200 {
msg = msg[:200]
}
} }
_ = session.Append(s.cfg.SessionsDir, sessionID, session.Entry{ _ = session.Append(s.cfg.SessionsDir, sessionID, session.Entry{
SessionID: sessionID, SessionID: sessionID,
@@ -165,11 +165,9 @@ func (s *Skill) logAttempt(sessionID, projectRoot, skill, phase string, t0 time.
Skill: skill, Skill: skill,
Phase: phase, Phase: phase,
ProjectRoot: projectRoot, ProjectRoot: projectRoot,
Attempts: session.AttemptsFrom(result.Attempts), FinalStatus: "ok",
FinalStatus: result.Status, ModelUsed: model,
FilePath: result.FilePath,
ModelUsed: result.ModelUsed,
DurationMs: time.Since(t0).Milliseconds(), DurationMs: time.Since(t0).Milliseconds(),
Message: result.Message, Message: msg,
}) })
} }

View File

@@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"testing" "testing"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/mathiasbq/supervisor/internal/session" "github.com/mathiasbq/supervisor/internal/session"
"github.com/mathiasbq/supervisor/internal/skills/tdd" "github.com/mathiasbq/supervisor/internal/skills/tdd"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -14,7 +13,6 @@ import (
func TestTDDSkillTools(t *testing.T) { func TestTDDSkillTools(t *testing.T) {
skill := tdd.New(tdd.Config{ skill := tdd.New(tdd.Config{
SystemPrompt: "supervisor rules",
SkillPrompt: "tdd rules", SkillPrompt: "tdd rules",
}) })
tools := skill.Tools() tools := skill.Tools()
@@ -26,19 +24,19 @@ func TestTDDSkillTools(t *testing.T) {
} }
func TestTDDSkillHandleUnknown(t *testing.T) { func TestTDDSkillHandleUnknown(t *testing.T) {
skill := tdd.New(tdd.Config{SystemPrompt: "s", SkillPrompt: "t"}) skill := tdd.New(tdd.Config{SkillPrompt: "t"})
_, err := skill.Handle(context.Background(), "tdd_unknown", json.RawMessage(`{}`)) _, err := skill.Handle(context.Background(), "tdd_unknown", json.RawMessage(`{}`))
assert.ErrorContains(t, err, "unknown tool") assert.ErrorContains(t, err, "unknown tool")
} }
func TestTDDRedRequiresProjectRoot(t *testing.T) { func TestTDDRedRequiresProjectRoot(t *testing.T) {
skill := tdd.New(tdd.Config{SystemPrompt: "s", SkillPrompt: "t"}) skill := tdd.New(tdd.Config{SkillPrompt: "t"})
_, err := skill.Handle(context.Background(), "tdd_red", json.RawMessage(`{"spec":"add two numbers"}`)) _, err := skill.Handle(context.Background(), "tdd_red", json.RawMessage(`{"spec":"add two numbers"}`))
assert.ErrorContains(t, err, "project_root") assert.ErrorContains(t, err, "project_root")
} }
func TestTDDRedRequiresSpec(t *testing.T) { func TestTDDRedRequiresSpec(t *testing.T) {
skill := tdd.New(tdd.Config{SystemPrompt: "s", SkillPrompt: "t"}) skill := tdd.New(tdd.Config{SkillPrompt: "t"})
_, err := skill.Handle(context.Background(), "tdd_red", json.RawMessage(`{"project_root":"/tmp/proj"}`)) _, err := skill.Handle(context.Background(), "tdd_red", json.RawMessage(`{"project_root":"/tmp/proj"}`))
assert.ErrorContains(t, err, "spec") assert.ErrorContains(t, err, "spec")
} }
@@ -51,35 +49,49 @@ func TestTDDGreenInjectsSessionHistory(t *testing.T) {
Message: "wrote failing test for Foo", Message: "wrote failing test for Foo",
})) }))
var capturedPrompt string var capturedTask string
fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) { fakeFn := func(_ context.Context, _, _, user string) (string, int64, error) {
capturedPrompt = req.TaskPrompt capturedTask = user
return iexec.Result{Status: "pass", Phase: "green", Skill: "tdd", Verified: true, ModelUsed: "self", Message: "ok"}, nil return "here is my suggestion", 100, nil
} }
sk := tdd.New(tdd.Config{SkillPrompt: "tdd", ExecutorFn: fakeFn, SessionsDir: sessDir}) sk := tdd.New(tdd.Config{SkillPrompt: "tdd", CompleteFunc: fakeFn, SessionsDir: sessDir})
_, err := sk.Handle(context.Background(), "tdd_green", json.RawMessage( _, err := sk.Handle(context.Background(), "tdd_green", json.RawMessage(
`{"project_root":"/tmp","test_path":"internal/foo/foo_test.go","test_cmd":"go test ./...","session_id":"sess-1"}`, `{"project_root":"/tmp","test_path":"internal/foo/foo_test.go","test_cmd":"go test ./...","session_id":"sess-1"}`,
)) ))
require.NoError(t, err) require.NoError(t, err)
assert.Contains(t, capturedPrompt, "## Session history") assert.Contains(t, capturedTask, "## Session history")
assert.Contains(t, capturedPrompt, "wrote failing test for Foo") assert.Contains(t, capturedTask, "wrote failing test for Foo")
} }
func TestTDDGreenNoHistoryWhenSessionIDEmpty(t *testing.T) { func TestTDDGreenNoHistoryWhenSessionIDEmpty(t *testing.T) {
var capturedPrompt string var capturedTask string
fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) { fakeFn := func(_ context.Context, _, _, user string) (string, int64, error) {
capturedPrompt = req.TaskPrompt capturedTask = user
return iexec.Result{Status: "pass", Phase: "green", Skill: "tdd", Verified: true, ModelUsed: "self", Message: "ok"}, nil return "suggestion", 50, nil
} }
sk := tdd.New(tdd.Config{SkillPrompt: "tdd", ExecutorFn: fakeFn, SessionsDir: t.TempDir()}) sk := tdd.New(tdd.Config{SkillPrompt: "tdd", CompleteFunc: fakeFn, SessionsDir: t.TempDir()})
_, err := sk.Handle(context.Background(), "tdd_green", json.RawMessage( _, err := sk.Handle(context.Background(), "tdd_green", json.RawMessage(
`{"project_root":"/tmp","test_path":"internal/foo/foo_test.go"}`, `{"project_root":"/tmp","test_path":"internal/foo/foo_test.go"}`,
)) ))
require.NoError(t, err) require.NoError(t, err)
assert.NotContains(t, capturedPrompt, "## Session history") assert.NotContains(t, capturedTask, "## Session history")
} }
// Ensure require is used (avoids import error). func TestTDDGreenReturnsTextJSON(t *testing.T) {
var _ = require.New fakeFn := func(_ context.Context, _, _, _ string) (string, int64, error) {
return "write a func that adds two ints", 42, nil
}
sk := tdd.New(tdd.Config{SkillPrompt: "tdd", CompleteFunc: fakeFn})
raw, err := sk.Handle(context.Background(), "tdd_green", json.RawMessage(
`{"project_root":"/tmp","test_path":"foo_test.go"}`,
))
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal(raw, &result))
assert.Equal(t, "write a func that adds two ints", result["text"])
assert.Equal(t, float64(42), result["duration_ms"])
}

View File

@@ -4,17 +4,15 @@ import (
"context" "context"
"encoding/json" "encoding/json"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/mathiasbq/supervisor/internal/registry" "github.com/mathiasbq/supervisor/internal/registry"
) )
// ExecutorFn allows injecting a test double for the executor. // CompleteFunc is the function used to call a local model.
type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error) type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error)
type Config struct { type Config struct {
SystemPrompt string
SkillPrompt string SkillPrompt string
ExecutorFn ExecutorFn // nil = no executor (tests that don't reach execute()) CompleteFunc CompleteFunc // nil = no executor (tests that don't reach execute())
DefaultModel string DefaultModel string
SessionsDir string // optional: path to brain/sessions/ for history injection SessionsDir string // optional: path to brain/sessions/ for history injection
IngestBaseURL string // optional: base URL of ingestion server for brain context IngestBaseURL string // optional: base URL of ingestion server for brain context
@@ -44,7 +42,7 @@ func (s *Skill) Tools() []registry.ToolDef {
return []registry.ToolDef{ return []registry.ToolDef{
{ {
Name: "tdd_red", Name: "tdd_red",
Description: "Write a failing test for the described behavior. Verifies the test fails before returning.", Description: "Consult a local model for help writing a failing test for the described behavior.",
InputSchema: schema( InputSchema: schema(
[]string{"project_root", "spec"}, []string{"project_root", "spec"},
map[string]any{ map[string]any{
@@ -57,7 +55,7 @@ func (s *Skill) Tools() []registry.ToolDef {
}, },
{ {
Name: "tdd_green", Name: "tdd_green",
Description: "Write minimal implementation to make the test at test_path pass.", Description: "Consult a local model for implementation ideas to make the test at test_path pass.",
InputSchema: schema( InputSchema: schema(
[]string{"project_root", "test_path"}, []string{"project_root", "test_path"},
map[string]any{ map[string]any{
@@ -71,7 +69,7 @@ func (s *Skill) Tools() []registry.ToolDef {
}, },
{ {
Name: "tdd_refactor", Name: "tdd_refactor",
Description: "Refactor the implementation at impl_path while keeping tests green.", Description: "Consult a local model for refactoring suggestions for impl_path while keeping tests green.",
InputSchema: schema( InputSchema: schema(
[]string{"project_root", "test_path", "impl_path"}, []string{"project_root", "test_path", "impl_path"},
map[string]any{ map[string]any{

View File

@@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"time" "time"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/mathiasbq/supervisor/internal/session" "github.com/mathiasbq/supervisor/internal/session"
) )
@@ -28,7 +27,7 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (
if a.SessionID == "" { if a.SessionID == "" {
return nil, fmt.Errorf("session_id is required") return nil, fmt.Errorf("session_id is required")
} }
if s.cfg.ExecutorFn == nil { if s.cfg.CompleteFunc == nil {
return nil, fmt.Errorf("no executor configured") return nil, fmt.Errorf("no executor configured")
} }
@@ -42,53 +41,47 @@ func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (
return nil, fmt.Errorf("read session log: %w", err) return nil, fmt.Errorf("read session log: %w", err)
} }
// ── Step 1: Reader agent ───────────────────────────────────────────────── // ── Step 1: Reader ────────────────────────────────────────────────────────
history := session.FormatHistory(entries, "") history := session.FormatHistory(entries, "")
readerTask := fmt.Sprintf( readerTask := fmt.Sprintf(
"role: reader\nsession_id: %s\nbrain_dir: %s\n\n%s", "role: reader\nsession_id: %s\nbrain_dir: %s\n\n%s",
a.SessionID, s.cfg.BrainDir, history, a.SessionID, s.cfg.BrainDir, history,
) )
readerResult, err := s.cfg.ExecutorFn(ctx, iexec.Request{ readerText, _, err := s.cfg.CompleteFunc(ctx, model, s.cfg.ReaderPrompt, readerTask)
SkillPrompt: s.cfg.ReaderPrompt,
TaskPrompt: readerTask,
Model: model,
Tools: "Read",
})
if err != nil { if err != nil {
return nil, fmt.Errorf("reader agent: %w", err) return nil, fmt.Errorf("reader: %w", err)
} }
// ── Step 2: Writer agent (receives reader candidates) ──────────────────── // ── Step 2: Writer (receives reader output) ───────────────────────────────
t0 := time.Now() t0 := time.Now()
writerTask := fmt.Sprintf( writerTask := fmt.Sprintf(
"role: writer\nsession_id: %s\nbrain_dir: %s\n\nreader_summary: %s\nreader_candidates:\n%s", "role: writer\nsession_id: %s\nbrain_dir: %s\n\nreader_analysis:\n%s",
a.SessionID, s.cfg.BrainDir, readerResult.Message, readerResult.RunnerOutput, a.SessionID, s.cfg.BrainDir, readerText,
) )
writerResult, err := s.cfg.ExecutorFn(ctx, iexec.Request{ writerText, dur, err := s.cfg.CompleteFunc(ctx, model, s.cfg.WriterPrompt, writerTask)
SkillPrompt: s.cfg.WriterPrompt,
TaskPrompt: writerTask,
Model: model,
Tools: "Read,Write",
})
if err != nil { if err != nil {
return nil, fmt.Errorf("writer agent: %w", err) return nil, fmt.Errorf("writer: %w", err)
} }
msg := writerText
if len(msg) > 200 {
msg = msg[:200]
}
_ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{ _ = session.Append(s.cfg.SessionsDir, a.SessionID, session.Entry{
SessionID: a.SessionID, SessionID: a.SessionID,
Timestamp: time.Now(), Timestamp: time.Now(),
Skill: "trainer", Skill: "trainer",
Phase: "trainer", Phase: "trainer",
Attempts: session.AttemptsFrom(writerResult.Attempts), FinalStatus: "ok",
FinalStatus: writerResult.Status, ModelUsed: model,
ModelUsed: writerResult.ModelUsed,
DurationMs: time.Since(t0).Milliseconds(), DurationMs: time.Since(t0).Milliseconds(),
Message: writerResult.Message, Message: msg,
}) })
b, err := json.Marshal(writerResult) return json.Marshal(map[string]any{
if err != nil { "reader_analysis": readerText,
return nil, fmt.Errorf("marshal result: %w", err) "writer_output": writerText,
} "model": model,
return b, nil "duration_ms": dur,
})
} }

View File

@@ -6,7 +6,6 @@ import (
"encoding/json" "encoding/json"
"testing" "testing"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/mathiasbq/supervisor/internal/session" "github.com/mathiasbq/supervisor/internal/session"
"github.com/mathiasbq/supervisor/internal/skills/trainer" "github.com/mathiasbq/supervisor/internal/skills/trainer"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -31,52 +30,44 @@ func TestTrainerRequiresSessionID(t *testing.T) {
func TestTrainerCallsReaderThenWriter(t *testing.T) { func TestTrainerCallsReaderThenWriter(t *testing.T) {
sessDir := t.TempDir() sessDir := t.TempDir()
require.NoError(t, session.Append(sessDir, "sess-1", session.Entry{ require.NoError(t, session.Append(sessDir, "sess-1", session.Entry{
SessionID: "sess-1", Skill: "tdd", Phase: "red", FinalStatus: "pass", SessionID: "sess-1", Skill: "tdd", Phase: "red", FinalStatus: "ok",
Message: "wrote failing test", FilePath: "internal/foo/foo_test.go", Message: "wrote failing test", FilePath: "internal/foo/foo_test.go",
})) }))
callCount := 0 callCount := 0
var readerTask, writerTask string var readerTask, writerTask string
fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) { fakeFn := func(_ context.Context, _, sys, user string) (string, int64, error) {
callCount++ callCount++
if callCount == 1 { if callCount == 1 {
// reader call // reader call
readerTask = req.TaskPrompt readerTask = user
return iexec.Result{ return "1 sft candidate found: first-pass clean TDD", 60, nil
Status: "pass", Phase: "trainer", Skill: "trainer",
RunnerOutput: `[{"type":"sft","moment":"first-pass clean TDD","score":4}]`,
Verified: true, ModelUsed: "self", Message: "1 sft candidate found",
}, nil
} }
// writer call // writer call
writerTask = req.TaskPrompt writerTask = user
return iexec.Result{ return "written 1 knowledge entry to brain/knowledge/tdd-patterns.md", 70, nil
Status: "pass", Phase: "trainer", Skill: "trainer",
FilePath: sessDir + "/training-data/sft/sess-1.jsonl",
Verified: true, ModelUsed: "self", Message: "1 sft pair written",
}, nil
} }
sk := trainer.New(trainer.Config{ sk := trainer.New(trainer.Config{
ReaderPrompt: "reader rules", ReaderPrompt: "reader rules",
WriterPrompt: "writer rules", WriterPrompt: "writer rules",
ExecutorFn: fakeFn, CompleteFunc: fakeFn,
SessionsDir: sessDir, SessionsDir: sessDir,
BrainDir: t.TempDir(), BrainDir: t.TempDir(),
}) })
out, err := sk.Handle(context.Background(), "trainer", json.RawMessage(`{"session_id":"sess-1"}`)) out, err := sk.Handle(context.Background(), "trainer", json.RawMessage(`{"session_id":"sess-1"}`))
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 2, callCount, "executor must be called exactly twice: reader then writer") assert.Equal(t, 2, callCount, "complete must be called exactly twice: reader then writer")
assert.Contains(t, readerTask, "role: reader") assert.Contains(t, readerTask, "role: reader")
assert.Contains(t, readerTask, "sess-1") assert.Contains(t, readerTask, "sess-1")
assert.Contains(t, readerTask, "wrote failing test") // session history in reader prompt assert.Contains(t, readerTask, "wrote failing test")
assert.Contains(t, writerTask, "role: writer") assert.Contains(t, writerTask, "role: writer")
assert.Contains(t, writerTask, "sft candidate") // reader output passed to writer assert.Contains(t, writerTask, "sft candidate")
var result iexec.Result var result map[string]any
require.NoError(t, json.Unmarshal(out, &result)) require.NoError(t, json.Unmarshal(out, &result))
assert.Equal(t, "trainer", result.Phase) assert.Contains(t, result["reader_analysis"], "sft candidate")
assert.Equal(t, "pass", result.Status) assert.Contains(t, result["writer_output"], "knowledge entry")
} }

View File

@@ -5,21 +5,20 @@ import (
"context" "context"
"encoding/json" "encoding/json"
iexec "github.com/mathiasbq/supervisor/internal/exec"
"github.com/mathiasbq/supervisor/internal/registry" "github.com/mathiasbq/supervisor/internal/registry"
) )
// ExecutorFn is the function signature for running a worker subprocess. // CompleteFunc is the function used to call a local model.
type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error) type CompleteFunc func(ctx context.Context, model, system, user string) (string, int64, error)
// Config holds dependencies for the trainer skill. // Config holds dependencies for the trainer skill.
type Config struct { type Config struct {
ReaderPrompt string ReaderPrompt string
WriterPrompt string WriterPrompt string
DefaultModel string DefaultModel string
ExecutorFn ExecutorFn CompleteFunc CompleteFunc
SessionsDir string SessionsDir string
BrainDir string // root of brain/ directory; writer writes to BrainDir/training-data/ BrainDir string // root of brain/ directory
} }
// Skill implements the trainer MCP tool. // Skill implements the trainer MCP tool.
@@ -40,7 +39,7 @@ func (s *Skill) Tools() []registry.ToolDef {
return []registry.ToolDef{ return []registry.ToolDef{
{ {
Name: "trainer", Name: "trainer",
Description: "Extract SFT and DPO training pairs from a session log. Runs a reader→writer chain: reader identifies learning moments, writer formats and writes pairs to brain/training-data/.", Description: "Consult a local model to identify learning moments from a session log and suggest knowledge to preserve in the brain.",
InputSchema: schema( InputSchema: schema(
[]string{"session_id"}, []string{"session_id"},
map[string]any{ map[string]any{