diff --git a/internal/skills/tdd/handlers.go b/internal/skills/tdd/handlers.go new file mode 100644 index 0000000..98d904a --- /dev/null +++ b/internal/skills/tdd/handlers.go @@ -0,0 +1,123 @@ +package tdd + +import ( + "context" + "encoding/json" + "fmt" + + iexec "github.com/mathiasbq/supervisor/internal/exec" +) + +func (s *Skill) Handle(ctx context.Context, tool string, args json.RawMessage) (json.RawMessage, error) { + switch tool { + case "tdd_red": + return s.handleRed(ctx, args) + case "tdd_green": + return s.handleGreen(ctx, args) + case "tdd_refactor": + return s.handleRefactor(ctx, args) + default: + return nil, fmt.Errorf("unknown tool: %s", tool) + } +} + +type redArgs struct { + ProjectRoot string `json:"project_root"` + Spec string `json:"spec"` + Model string `json:"model"` + TestCmd string `json:"test_cmd"` +} + +func (s *Skill) handleRed(ctx context.Context, raw json.RawMessage) (json.RawMessage, error) { + var args redArgs + if err := json.Unmarshal(raw, &args); err != nil { + return nil, fmt.Errorf("parse args: %w", err) + } + if args.ProjectRoot == "" { + return nil, fmt.Errorf("project_root is required") + } + if args.Spec == "" { + return nil, fmt.Errorf("spec is required") + } + task := fmt.Sprintf( + "phase: red\nproject_root: %s\nspec: %s\nmodel: %s\ntest_cmd: %s", + args.ProjectRoot, args.Spec, s.resolveModel(args.Model), args.TestCmd, + ) + return s.execute(ctx, task) +} + +type greenArgs struct { + ProjectRoot string `json:"project_root"` + TestPath string `json:"test_path"` + Model string `json:"model"` + TestCmd string `json:"test_cmd"` +} + +func (s *Skill) handleGreen(ctx context.Context, raw json.RawMessage) (json.RawMessage, error) { + var args greenArgs + if err := json.Unmarshal(raw, &args); err != nil { + return nil, fmt.Errorf("parse args: %w", err) + } + if args.ProjectRoot == "" { + return nil, fmt.Errorf("project_root is required") + } + if args.TestPath == "" { + return nil, fmt.Errorf("test_path is required") + } + task := fmt.Sprintf( + "phase: green\nproject_root: %s\ntest_path: %s\nmodel: %s\ntest_cmd: %s", + args.ProjectRoot, args.TestPath, s.resolveModel(args.Model), args.TestCmd, + ) + return s.execute(ctx, task) +} + +type refactorArgs struct { + ProjectRoot string `json:"project_root"` + TestPath string `json:"test_path"` + ImplPath string `json:"impl_path"` + Model string `json:"model"` + TestCmd string `json:"test_cmd"` +} + +func (s *Skill) handleRefactor(ctx context.Context, raw json.RawMessage) (json.RawMessage, error) { + var args refactorArgs + if err := json.Unmarshal(raw, &args); err != nil { + return nil, fmt.Errorf("parse args: %w", err) + } + if args.ProjectRoot == "" { + return nil, fmt.Errorf("project_root is required") + } + if args.TestPath == "" { + return nil, fmt.Errorf("test_path is required") + } + if args.ImplPath == "" { + return nil, fmt.Errorf("impl_path is required") + } + task := fmt.Sprintf( + "phase: refactor\nproject_root: %s\ntest_path: %s\nimpl_path: %s\nmodel: %s\ntest_cmd: %s", + args.ProjectRoot, args.TestPath, args.ImplPath, s.resolveModel(args.Model), args.TestCmd, + ) + return s.execute(ctx, task) +} + +func (s *Skill) resolveModel(override string) string { + if override != "" { + return override + } + return s.cfg.DefaultModel +} + +func (s *Skill) execute(ctx context.Context, task string) (json.RawMessage, error) { + if s.cfg.ExecutorFn == nil { + return nil, fmt.Errorf("no executor configured") + } + req := iexec.Request{ + SkillPrompt: s.cfg.SkillPrompt, + TaskPrompt: task, + } + result, err := s.cfg.ExecutorFn(ctx, req) + if err != nil { + return nil, err + } + return json.Marshal(result) +} diff --git a/internal/skills/tdd/handlers_test.go b/internal/skills/tdd/handlers_test.go new file mode 100644 index 0000000..d0490b6 --- /dev/null +++ b/internal/skills/tdd/handlers_test.go @@ -0,0 +1,45 @@ +package tdd_test + +import ( + "context" + "encoding/json" + "testing" + + "github.com/mathiasbq/supervisor/internal/skills/tdd" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTDDSkillTools(t *testing.T) { + skill := tdd.New(tdd.Config{ + SystemPrompt: "supervisor rules", + SkillPrompt: "tdd rules", + }) + tools := skill.Tools() + names := make([]string, len(tools)) + for i, tool := range tools { + names[i] = tool.Name + } + assert.ElementsMatch(t, []string{"tdd_red", "tdd_green", "tdd_refactor"}, names) +} + +func TestTDDSkillHandleUnknown(t *testing.T) { + skill := tdd.New(tdd.Config{SystemPrompt: "s", SkillPrompt: "t"}) + _, err := skill.Handle(context.Background(), "tdd_unknown", json.RawMessage(`{}`)) + assert.ErrorContains(t, err, "unknown tool") +} + +func TestTDDRedRequiresProjectRoot(t *testing.T) { + skill := tdd.New(tdd.Config{SystemPrompt: "s", SkillPrompt: "t"}) + _, err := skill.Handle(context.Background(), "tdd_red", json.RawMessage(`{"spec":"add two numbers"}`)) + assert.ErrorContains(t, err, "project_root") +} + +func TestTDDRedRequiresSpec(t *testing.T) { + skill := tdd.New(tdd.Config{SystemPrompt: "s", SkillPrompt: "t"}) + _, err := skill.Handle(context.Background(), "tdd_red", json.RawMessage(`{"project_root":"/tmp/proj"}`)) + assert.ErrorContains(t, err, "spec") +} + +// Ensure require is used (avoids import error). +var _ = require.New diff --git a/internal/skills/tdd/skill.go b/internal/skills/tdd/skill.go new file mode 100644 index 0000000..c3fdbc6 --- /dev/null +++ b/internal/skills/tdd/skill.go @@ -0,0 +1,84 @@ +package tdd + +import ( + "context" + "encoding/json" + + iexec "github.com/mathiasbq/supervisor/internal/exec" + "github.com/mathiasbq/supervisor/internal/registry" +) + +// ExecutorFn allows injecting a test double for the executor. +type ExecutorFn func(ctx context.Context, req iexec.Request) (iexec.Result, error) + +type Config struct { + SystemPrompt string + SkillPrompt string + ExecutorFn ExecutorFn // nil = no executor (tests that don't reach execute()) + DefaultModel string +} + +type Skill struct { + cfg Config +} + +func New(cfg Config) *Skill { + return &Skill{cfg: cfg} +} + +func (s *Skill) Name() string { return "tdd" } + +func (s *Skill) Tools() []registry.ToolDef { + schema := func(required []string, props map[string]any) json.RawMessage { + b, _ := json.Marshal(map[string]any{ + "type": "object", + "required": required, + "properties": props, + }) + return b + } + strProp := map[string]any{"type": "string"} + + return []registry.ToolDef{ + { + Name: "tdd_red", + Description: "Write a failing test for the described behavior. Verifies the test fails before returning.", + InputSchema: schema( + []string{"project_root", "spec"}, + map[string]any{ + "project_root": strProp, + "spec": strProp, + "model": strProp, + "test_cmd": strProp, + }, + ), + }, + { + Name: "tdd_green", + Description: "Write minimal implementation to make the test at test_path pass.", + InputSchema: schema( + []string{"project_root", "test_path"}, + map[string]any{ + "project_root": strProp, + "test_path": strProp, + "model": strProp, + "test_cmd": strProp, + }, + ), + }, + { + Name: "tdd_refactor", + Description: "Refactor the implementation at impl_path while keeping tests green.", + InputSchema: schema( + []string{"project_root", "test_path", "impl_path"}, + map[string]any{ + "project_root": strProp, + "test_path": strProp, + "impl_path": strProp, + "model": strProp, + "test_cmd": strProp, + }, + ), + }, + } +}