Files
hyperguild/internal/exec/orchestrator.go

198 lines
5.1 KiB
Go

package exec
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// ChainEntry describes a single tier of an escalation chain.
type ChainEntry struct {
	Model   string // model identifier, e.g. "ollama/phi4" or "claude-sonnet-4-6"
	Tier    string // tier label: "local" | "subagent" | "managed"
	IsCloud bool   // set for claude-* models; cloud output bypasses the verifier
}

// EntryFor derives a ChainEntry from a bare model name. Names beginning with
// the (case-sensitive) prefix "claude-" are classified as cloud models on the
// "subagent" tier; every other name becomes a non-cloud "local" tier.
// EntryFor never produces the "managed" tier.
func EntryFor(model string) ChainEntry {
	if !strings.HasPrefix(model, "claude-") {
		return ChainEntry{Model: model, Tier: "local"}
	}
	return ChainEntry{Model: model, Tier: "subagent", IsCloud: true}
}
// AttemptRecord is the session-log entry produced for each tier the
// Orchestrator tries, whether that tier was accepted, escalated, or failed.
type AttemptRecord struct {
	Model      string // model that was invoked for this tier
	Tier       string // tier label copied from the ChainEntry
	DurationMs int64  // wall time of the generation call, in milliseconds
	WarmStart  bool   // result of the llama-swap warm probe taken before the call
	Verdict    string // outcome: "accept" | "escalate" | "error"
	Feedback   string // verifier feedback or error text; empty for accepted tiers
}
// VerifierFn is the quality gate the Orchestrator consults after a local
// (non-cloud) tier has produced output. Despite the "Fn" suffix it is an
// interface, not a function type.
type VerifierFn interface {
	// Verify judges output produced for the given skill and task prompts.
	// The Orchestrator accepts the output when the returned Verdict has
	// Accept set, otherwise it escalates; a non-nil error is also treated
	// as an escalation.
	Verify(ctx context.Context, skillPrompt, taskPrompt string, output Result) (Verdict, error)
}
// ExecutorRunFn performs a single generation call for one model. Its signature
// matches Executor.Run and LiteLLMExecutor.Run, so either method value can be
// handed to NewOrchestrator directly.
type ExecutorRunFn func(ctx context.Context, req Request) (Result, error)
// Orchestrator pushes a request through an ordered escalation chain, using a
// per-kind executor for generation and a VerifierFn to gate local output.
// Its Run method has the ExecutorFn shape that skill handlers expect.
type Orchestrator struct {
	chain        []ChainEntry     // tiers, tried in order
	localRun     ExecutorRunFn    // executor for non-cloud tiers; may be nil if none exist
	cloudRun     ExecutorRunFn    // executor for cloud tiers; may be nil if none exist
	verifier     VerifierFn       // gate applied to local-tier output
	llamaSwapURL string           // llama-swap base URL for warm probes; "" disables probing
	attempts     *[]AttemptRecord // optional sink for per-tier records; nil disables logging
}
// NewOrchestrator wires an Orchestrator together from its collaborators.
//
// chain is walked in order by Run. localRun and cloudRun may be nil when the
// chain contains no tier of the corresponding kind. An empty llamaSwapURL
// disables warm probing. When attempts is non-nil, one AttemptRecord is
// appended to the slice it points at for every tier attempt.
func NewOrchestrator(
	chain []ChainEntry,
	localRun ExecutorRunFn,
	cloudRun ExecutorRunFn,
	verifier VerifierFn,
	llamaSwapURL string,
	attempts *[]AttemptRecord,
) *Orchestrator {
	o := new(Orchestrator)
	o.chain = chain
	o.localRun, o.cloudRun = localRun, cloudRun
	o.verifier = verifier
	o.llamaSwapURL = llamaSwapURL
	o.attempts = attempts
	return o
}
// Run walks the escalation chain in order and returns the first accepted
// result. It satisfies the ExecutorFn signature:
// func(context.Context, Request) (Result, error).
//
// Each tier receives a copy of req whose Model is the tier's model and whose
// TaskPrompt is the current, possibly feedback-augmented, prompt:
//   - Cloud tiers (IsCloud) are trusted: a successful generation is returned
//     without consulting the verifier.
//   - Local tiers are checked by the verifier. A rejection appends the
//     verifier's feedback to the prompt seen by later tiers. A verifier error
//     is treated as an escalation without adding feedback.
//
// Every attempt, including failures, is recorded via appendAttempt. A tier
// whose executor is nil is recorded as an error and skipped instead of
// panicking. If ctx is done, Run stops before starting the next tier and
// returns an error wrapping ctx.Err(). When every tier fails or is rejected,
// the returned error wraps the last failure.
func (o *Orchestrator) Run(ctx context.Context, req Request) (Result, error) {
	taskPrompt := req.TaskPrompt
	var lastErr error

	for _, entry := range o.chain {
		// Escalating is pointless once the caller has given up; without this
		// check a cancelled ctx would be handed to every remaining tier.
		if err := ctx.Err(); err != nil {
			return Result{}, fmt.Errorf("escalation stopped before tier %q: %w", entry.Model, err)
		}

		rec := AttemptRecord{Model: entry.Model, Tier: entry.Tier}

		run := o.localRun
		if entry.IsCloud {
			run = o.cloudRun
		}
		if run == nil {
			// NewOrchestrator permits a nil executor for a tier kind the chain
			// is not expected to contain; a mismatched chain must not panic.
			lastErr = fmt.Errorf("no executor configured for %s tier %q", entry.Tier, entry.Model)
			rec.Verdict = "error"
			rec.Feedback = lastErr.Error()
			o.appendAttempt(rec)
			continue
		}

		// Only local models can be resident in llama-swap, so the HTTP probe
		// is skipped for cloud tiers.
		rec.WarmStart = !entry.IsCloud && o.probeWarm(entry.Model)

		tierReq := req
		tierReq.Model = entry.Model
		tierReq.TaskPrompt = taskPrompt

		start := time.Now()
		result, genErr := run(ctx, tierReq)
		// Duration covers generation only; verifier time is excluded.
		rec.DurationMs = time.Since(start).Milliseconds()

		if genErr != nil {
			lastErr = fmt.Errorf("%s tier %q: %w", entry.Tier, entry.Model, genErr)
			rec.Verdict = "error"
			rec.Feedback = genErr.Error()
			o.appendAttempt(rec)
			continue
		}

		if entry.IsCloud {
			rec.Verdict = "accept"
			o.appendAttempt(rec)
			return result, nil
		}

		verdict, verErr := o.verifier.Verify(ctx, req.SkillPrompt, taskPrompt, result)
		switch {
		case verErr != nil:
			// Treat verifier failure as escalate (safe default).
			lastErr = fmt.Errorf("verifying %s output: %w", entry.Model, verErr)
			rec.Verdict = "escalate"
			rec.Feedback = "verifier error: " + verErr.Error()
			o.appendAttempt(rec)
		case verdict.Accept:
			rec.Verdict = "accept"
			o.appendAttempt(rec)
			return result, nil
		default:
			lastErr = fmt.Errorf("verifier rejected %s output: %q", entry.Model, verdict.Feedback)
			rec.Verdict = "escalate"
			rec.Feedback = verdict.Feedback
			o.appendAttempt(rec)
			// Inject verifier feedback into the next tier's task prompt; an
			// empty verdict adds nothing useful, so the prompt is left as is.
			if verdict.Feedback != "" {
				taskPrompt += "\n\nPrior attempt feedback: " + verdict.Feedback
			}
		}
	}

	if lastErr != nil {
		return Result{}, fmt.Errorf("all tiers exhausted after %d attempt(s): %w", len(o.chain), lastErr)
	}
	return Result{}, fmt.Errorf("all tiers exhausted after %d attempt(s)", len(o.chain))
}
// appendAttempt adds rec to the caller-supplied attempts slice. It is a no-op
// when the Orchestrator was built without an attempts sink. The append is not
// synchronized.
func (o *Orchestrator) appendAttempt(rec AttemptRecord) {
	if o.attempts == nil {
		return
	}
	recs := *o.attempts
	*o.attempts = append(recs, rec)
}
// probeWarm reports whether llama-swap's model listing mentions model. It is a
// best-effort signal used only to populate AttemptRecord.WarmStart.
//
// It sends GET <llamaSwapURL>/v1/models under its own 200ms timeout. It
// returns true only if the response is 200 OK and the body contains model as a
// complete JSON string ("model"), so "phi4" does not match "phi4-mini". It
// returns false when llamaSwapURL or model is empty and on any error.
//
// NOTE(review): the intent is "currently loaded", but /v1/models may list
// every configured model rather than only resident ones, depending on the
// llama-swap version. Confirm against the deployed server.
func (o *Orchestrator) probeWarm(model string) bool {
	// An empty model would trivially "match" any body.
	if o.llamaSwapURL == "" || model == "" {
		return false
	}
	// The probe has its own short deadline, so a slow or absent llama-swap
	// delays a tier by roughly this long at most.
	ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
	defer cancel()

	// Tolerate a configured base URL with a trailing slash.
	endpoint := strings.TrimSuffix(o.llamaSwapURL, "/") + "/v1/models"
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return false
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return false
	}
	defer resp.Body.Close() //nolint:errcheck
	// A non-200 body is not a model listing, so don't search it.
	if resp.StatusCode != http.StatusOK {
		return false
	}

	// The listing is small; cap the read so a misbehaving endpoint cannot make
	// the probe buffer an unbounded body.
	const maxListingBytes = 1 << 20
	body, err := io.ReadAll(io.LimitReader(resp.Body, maxListingBytes))
	if err != nil {
		return false
	}
	return strings.Contains(string(body), `"`+model+`"`)
}