From 38fcac4cba7435483bd37b8380f0b079cdf61a24 Mon Sep 17 00:00:00 2001 From: Mathias Bergqvist Date: Sun, 19 Apr 2026 14:06:00 +0200 Subject: [PATCH] =?UTF-8?q?feat(trainer):=20add=20trainer=20MCP=20skill=20?= =?UTF-8?q?with=20reader=E2=86=92writer=20sub-agent=20chain?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reader agent scans session logs for SFT/DPO candidates; writer receives reader output and formats+writes training pairs to brain/training-data/. Adds trainer-reader.md and trainer-writer.md discipline prompts. Co-Authored-By: Claude Sonnet 4.6 --- cmd/supervisor/main.go | 20 ++++++ config/models.yaml | 2 + config/supervisor/trainer-reader.md | 31 +++++++++ config/supervisor/trainer-writer.md | 35 ++++++++++ internal/skills/trainer/handlers.go | 80 +++++++++++++++++++++++ internal/skills/trainer/handlers_test.go | 82 ++++++++++++++++++++++++ internal/skills/trainer/skill.go | 53 +++++++++++++++ 7 files changed, 303 insertions(+) create mode 100644 config/supervisor/trainer-reader.md create mode 100644 config/supervisor/trainer-writer.md create mode 100644 internal/skills/trainer/handlers.go create mode 100644 internal/skills/trainer/handlers_test.go create mode 100644 internal/skills/trainer/skill.go diff --git a/cmd/supervisor/main.go b/cmd/supervisor/main.go index de704cf..14672d0 100644 --- a/cmd/supervisor/main.go +++ b/cmd/supervisor/main.go @@ -16,6 +16,7 @@ import ( skilldebug "github.com/mathiasbq/supervisor/internal/skills/debug" "github.com/mathiasbq/supervisor/internal/skills/review" "github.com/mathiasbq/supervisor/internal/skills/spec" + "github.com/mathiasbq/supervisor/internal/skills/trainer" "github.com/mathiasbq/supervisor/internal/skills/sessionlog" "github.com/mathiasbq/supervisor/internal/skills/tdd" "github.com/mathiasbq/supervisor/internal/tier" @@ -72,6 +73,17 @@ func main() { os.Exit(1) } + trainerReaderPrompt, err := os.ReadFile(cfg.ConfigDir + "/trainer-reader.md") + if err != nil { + logger.Error("read trainer-reader.md", "path", cfg.ConfigDir+"/trainer-reader.md", "err", err) + os.Exit(1) + } + trainerWriterPrompt, err := os.ReadFile(cfg.ConfigDir + "/trainer-writer.md") + if err != nil { + logger.Error("read trainer-writer.md", "path", cfg.ConfigDir+"/trainer-writer.md", "err", err) + os.Exit(1) + } + executor := iexec.New(iexec.Config{ SystemPrompt: string(systemPrompt), LiteLLMBaseURL: cfg.LiteLLMBaseURL, @@ -123,6 +135,14 @@ func main() { ExecutorFn: executor.Run, SessionsDir: cfg.SessionsDir, })) + reg.Register(trainer.New(trainer.Config{ + ReaderPrompt: string(trainerReaderPrompt), + WriterPrompt: string(trainerWriterPrompt), + DefaultModel: models.Resolve("trainer", ""), + ExecutorFn: executor.Run, + SessionsDir: cfg.SessionsDir, + BrainDir: cfg.BrainDir, + })) srv := mcp.NewServer(reg) mux := http.NewServeMux() diff --git a/config/models.yaml b/config/models.yaml index f26fb9b..bc612b5 100644 --- a/config/models.yaml +++ b/config/models.yaml @@ -9,3 +9,5 @@ skills: review: ollama/devstral-tuned debug: ollama/deepseek-r1-tuned retrospective: ollama/qwen3-coder-30b-tuned + spec: ollama/qwen3-coder-30b-tuned + trainer: ollama/qwen3-coder-30b-tuned diff --git a/config/supervisor/trainer-reader.md b/config/supervisor/trainer-reader.md new file mode 100644 index 0000000..c1bab09 --- /dev/null +++ b/config/supervisor/trainer-reader.md @@ -0,0 +1,31 @@ +# Trainer Reader Discipline + +You scan session logs and identify candidate learning moments worth converting to training data. + +## What to look for +- **SFT candidates**: the worker did exactly the right thing — a clean pattern worth reinforcing +- **DPO candidates**: the worker first produced a wrong or suboptimal response, then corrected — you have both rejected and chosen + +## Scoring (1–5) +- 5: novel pattern, clearly correct, generalises across projects +- 4: good pattern, correct, somewhat project-specific but still useful +- 3: correct but obvious — include only if especially clean +- 2 or below: skip — too ambiguous or too context-specific + +## Output contract +Return JSON result with: +- `status`: "pass" or "error" +- `phase`: "trainer" +- `skill`: "trainer" +- `file_path`: "" +- `runner_output`: JSON array of candidates (valid JSON, not markdown): + [{"type":"sft","moment":"","prompt":"","completion":"","score":4}, + {"type":"dpo","moment":"","prompt":"","chosen":"","rejected":"","score":3}] +- `verified`: true +- `message`: "N sft candidates, M dpo candidates found" + +## Rules +1. Read all session entries in the task prompt +2. Score each entry — only include entries scoring >= 3 +3. Prompt/completion fields must be phrased to generalise: no project-specific paths or names +4. If no candidates score >= 3, return an empty array `[]` — never force low-quality candidates diff --git a/config/supervisor/trainer-writer.md b/config/supervisor/trainer-writer.md new file mode 100644 index 0000000..1947671 --- /dev/null +++ b/config/supervisor/trainer-writer.md @@ -0,0 +1,35 @@ +# Trainer Writer Discipline + +You receive candidate learning moments from the reader and write clean SFT/DPO training pairs. + +## Quality gate (apply before writing) +- SFT: prompt must be phrased so it could come from any project, not just this one +- DPO: chosen and rejected must be clearly distinguishable — skip if a reader can't tell which is better +- Never include project-specific paths, variable names, or identifiers in any pair + +## Output contract +Return JSON result with: +- `status`: "pass" (pairs written or skipped due to quality) or "error" (candidates JSON was malformed) +- `phase`: "trainer" +- `skill`: "trainer" +- `file_path`: path of the last file written (empty if nothing passed quality gate) +- `runner_output`: "N SFT pairs written to brain/training-data/sft/, M DPO pairs to brain/training-data/dpo/" or "0 pairs passed quality gate" +- `verified`: true if files were written; false if nothing passed +- `message`: "N sft + M dpo pairs for session " or "no pairs passed quality gate" + +## File format +JSONL — one JSON object per line. + +SFT: `{"prompt": "...", "completion": "..."}` +DPO: `{"prompt": "...", "chosen": "...", "rejected": "..."}` + +Write SFT to: `/training-data/sft/.jsonl` +Write DPO to: `/training-data/dpo/.jsonl` + +Append to existing files if they exist (don't overwrite). + +## Rules +1. Parse the `reader_candidates` JSON from the task prompt +2. For each candidate: apply quality gate +3. Write passing SFT candidates to sft JSONL, DPO candidates to dpo JSONL +4. If nothing passes, return status "pass" with verified: false and message "no pairs passed quality gate" diff --git a/internal/skills/trainer/handlers.go b/internal/skills/trainer/handlers.go new file mode 100644 index 0000000..7c21e8d --- /dev/null +++ b/internal/skills/trainer/handlers.go @@ -0,0 +1,80 @@ +// internal/skills/trainer/handlers.go +package trainer + +import ( + "context" + "encoding/json" + "fmt" + + iexec "github.com/mathiasbq/supervisor/internal/exec" + "github.com/mathiasbq/supervisor/internal/session" +) + +type trainArgs struct { + SessionID string `json:"session_id"` + Model string `json:"model"` +} + +// Handle dispatches the MCP tool call to the trainer handler. +func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (json.RawMessage, error) { + if tool != "trainer" { + return nil, fmt.Errorf("unknown tool: %s", tool) + } + var a trainArgs + if err := json.Unmarshal(args, &a); err != nil { + return nil, fmt.Errorf("parse args: %w", err) + } + if a.SessionID == "" { + return nil, fmt.Errorf("session_id is required") + } + if s.cfg.ExecutorFn == nil { + return nil, fmt.Errorf("no executor configured") + } + + model := a.Model + if model == "" { + model = s.cfg.DefaultModel + } + + entries, err := session.Read(s.cfg.SessionsDir, a.SessionID) + if err != nil { + return nil, fmt.Errorf("read session log: %w", err) + } + + // ── Step 1: Reader agent ───────────────────────────────────────────────── + history := session.FormatHistory(entries, "") + readerTask := fmt.Sprintf( + "role: reader\nsession_id: %s\nbrain_dir: %s\n\n%s", + a.SessionID, s.cfg.BrainDir, history, + ) + readerResult, err := s.cfg.ExecutorFn(ctx, iexec.Request{ + SkillPrompt: s.cfg.ReaderPrompt, + TaskPrompt: readerTask, + Model: model, + Tools: "Read", + }) + if err != nil { + return nil, fmt.Errorf("reader agent: %w", err) + } + + // ── Step 2: Writer agent (receives reader candidates) ──────────────────── + writerTask := fmt.Sprintf( + "role: writer\nsession_id: %s\nbrain_dir: %s\n\nreader_summary: %s\nreader_candidates:\n%s", + a.SessionID, s.cfg.BrainDir, readerResult.Message, readerResult.RunnerOutput, + ) + writerResult, err := s.cfg.ExecutorFn(ctx, iexec.Request{ + SkillPrompt: s.cfg.WriterPrompt, + TaskPrompt: writerTask, + Model: model, + Tools: "Read,Write", + }) + if err != nil { + return nil, fmt.Errorf("writer agent: %w", err) + } + + b, err := json.Marshal(writerResult) + if err != nil { + return nil, fmt.Errorf("marshal result: %w", err) + } + return b, nil +} diff --git a/internal/skills/trainer/handlers_test.go b/internal/skills/trainer/handlers_test.go new file mode 100644 index 0000000..e20704b --- /dev/null +++ b/internal/skills/trainer/handlers_test.go @@ -0,0 +1,82 @@ +// internal/skills/trainer/handlers_test.go +package trainer_test + +import ( + "context" + "encoding/json" + "testing" + + iexec "github.com/mathiasbq/supervisor/internal/exec" + "github.com/mathiasbq/supervisor/internal/session" + "github.com/mathiasbq/supervisor/internal/skills/trainer" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTrainerToolRegistered(t *testing.T) { + sk := trainer.New(trainer.Config{ReaderPrompt: "r", WriterPrompt: "w"}) + names := make([]string, 0) + for _, tool := range sk.Tools() { + names = append(names, tool.Name) + } + assert.Contains(t, names, "trainer") +} + +func TestTrainerRequiresSessionID(t *testing.T) { + sk := trainer.New(trainer.Config{ReaderPrompt: "r", WriterPrompt: "w"}) + _, err := sk.Handle(context.Background(), "trainer", json.RawMessage(`{}`)) + assert.ErrorContains(t, err, "session_id") +} + +func TestTrainerCallsReaderThenWriter(t *testing.T) { + sessDir := t.TempDir() + require.NoError(t, session.Append(sessDir, "sess-1", session.Entry{ + SessionID: "sess-1", Skill: "tdd", Phase: "red", FinalStatus: "pass", + Message: "wrote failing test", FilePath: "internal/foo/foo_test.go", + })) + + callCount := 0 + var readerTask, writerTask string + + fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) { + callCount++ + if callCount == 1 { + // reader call + readerTask = req.TaskPrompt + return iexec.Result{ + Status: "pass", Phase: "trainer", Skill: "trainer", + RunnerOutput: `[{"type":"sft","moment":"first-pass clean TDD","score":4}]`, + Verified: true, ModelUsed: "self", Message: "1 sft candidate found", + }, nil + } + // writer call + writerTask = req.TaskPrompt + return iexec.Result{ + Status: "pass", Phase: "trainer", Skill: "trainer", + FilePath: sessDir + "/training-data/sft/sess-1.jsonl", + Verified: true, ModelUsed: "self", Message: "1 sft pair written", + }, nil + } + + sk := trainer.New(trainer.Config{ + ReaderPrompt: "reader rules", + WriterPrompt: "writer rules", + ExecutorFn: fakeFn, + SessionsDir: sessDir, + BrainDir: t.TempDir(), + }) + out, err := sk.Handle(context.Background(), "trainer", json.RawMessage(`{"session_id":"sess-1"}`)) + require.NoError(t, err) + + assert.Equal(t, 2, callCount, "executor must be called exactly twice: reader then writer") + assert.Contains(t, readerTask, "role: reader") + assert.Contains(t, readerTask, "sess-1") + assert.Contains(t, readerTask, "wrote failing test") // session history in reader prompt + assert.Contains(t, writerTask, "role: writer") + assert.Contains(t, writerTask, "sft candidate") // reader output passed to writer + + var result iexec.Result + require.NoError(t, json.Unmarshal(out, &result)) + assert.Equal(t, "trainer", result.Phase) + assert.Equal(t, "pass", result.Status) +} diff --git a/internal/skills/trainer/skill.go b/internal/skills/trainer/skill.go new file mode 100644 index 0000000..d5ecf86 --- /dev/null +++ b/internal/skills/trainer/skill.go @@ -0,0 +1,53 @@ +// internal/skills/trainer/skill.go +package trainer + +import ( + "context" + "encoding/json" + + iexec "github.com/mathiasbq/supervisor/internal/exec" + "github.com/mathiasbq/supervisor/internal/registry" +) + +// ExecutorFn is the function signature for running a worker subprocess. +type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error) + +// Config holds dependencies for the trainer skill. +type Config struct { + ReaderPrompt string + WriterPrompt string + DefaultModel string + ExecutorFn ExecutorFn + SessionsDir string + BrainDir string // root of brain/ directory; writer writes to BrainDir/training-data/ +} + +// Skill implements the trainer MCP tool. +type Skill struct{ cfg Config } + +// New creates a new trainer Skill. +func New(cfg Config) *Skill { return &Skill{cfg: cfg} } + +// Name returns the skill identifier. +func (s *Skill) Name() string { return "trainer" } + +// Tools returns the MCP tool definitions for this skill. +func (s *Skill) Tools() []registry.ToolDef { + schema := func(required []string, props map[string]any) json.RawMessage { + b, _ := json.Marshal(map[string]any{"type": "object", "required": required, "properties": props}) + return b + } + return []registry.ToolDef{ + { + Name: "trainer", + Description: "Extract SFT and DPO training pairs from a session log. Runs a reader→writer chain: reader identifies learning moments, writer formats and writes pairs to brain/training-data/.", + InputSchema: schema( + []string{"session_id"}, + map[string]any{ + "session_id": map[string]any{"type": "string"}, + "model": map[string]any{"type": "string"}, + }, + ), + }, + } +}