// internal/skills/trainer/handlers_test.go package trainer_test import ( "context" "encoding/json" "testing" iexec "github.com/mathiasbq/supervisor/internal/exec" "github.com/mathiasbq/supervisor/internal/session" "github.com/mathiasbq/supervisor/internal/skills/trainer" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestTrainerToolRegistered(t *testing.T) { sk := trainer.New(trainer.Config{ReaderPrompt: "r", WriterPrompt: "w"}) names := make([]string, 0) for _, tool := range sk.Tools() { names = append(names, tool.Name) } assert.Contains(t, names, "trainer") } func TestTrainerRequiresSessionID(t *testing.T) { sk := trainer.New(trainer.Config{ReaderPrompt: "r", WriterPrompt: "w"}) _, err := sk.Handle(context.Background(), "trainer", json.RawMessage(`{}`)) assert.ErrorContains(t, err, "session_id") } func TestTrainerCallsReaderThenWriter(t *testing.T) { sessDir := t.TempDir() require.NoError(t, session.Append(sessDir, "sess-1", session.Entry{ SessionID: "sess-1", Skill: "tdd", Phase: "red", FinalStatus: "pass", Message: "wrote failing test", FilePath: "internal/foo/foo_test.go", })) callCount := 0 var readerTask, writerTask string fakeFn := func(_ context.Context, req iexec.Request) (iexec.Result, error) { callCount++ if callCount == 1 { // reader call readerTask = req.TaskPrompt return iexec.Result{ Status: "pass", Phase: "trainer", Skill: "trainer", RunnerOutput: `[{"type":"sft","moment":"first-pass clean TDD","score":4}]`, Verified: true, ModelUsed: "self", Message: "1 sft candidate found", }, nil } // writer call writerTask = req.TaskPrompt return iexec.Result{ Status: "pass", Phase: "trainer", Skill: "trainer", FilePath: sessDir + "/training-data/sft/sess-1.jsonl", Verified: true, ModelUsed: "self", Message: "1 sft pair written", }, nil } sk := trainer.New(trainer.Config{ ReaderPrompt: "reader rules", WriterPrompt: "writer rules", ExecutorFn: fakeFn, SessionsDir: sessDir, BrainDir: t.TempDir(), }) out, err := sk.Handle(context.Background(), "trainer", json.RawMessage(`{"session_id":"sess-1"}`)) require.NoError(t, err) assert.Equal(t, 2, callCount, "executor must be called exactly twice: reader then writer") assert.Contains(t, readerTask, "role: reader") assert.Contains(t, readerTask, "sess-1") assert.Contains(t, readerTask, "wrote failing test") // session history in reader prompt assert.Contains(t, writerTask, "role: writer") assert.Contains(t, writerTask, "sft candidate") // reader output passed to writer var result iexec.Result require.NoError(t, json.Unmarshal(out, &result)) assert.Equal(t, "trainer", result.Phase) assert.Equal(t, "pass", result.Status) }