198 lines
5.1 KiB
Go
198 lines
5.1 KiB
Go
package exec
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// ChainEntry describes a single tier of an escalation chain.
type ChainEntry struct {
	Model   string // model identifier, e.g. "ollama/phi4" or "claude-sonnet-4-6"
	Tier    string // one of "local", "subagent", "managed"
	IsCloud bool   // set for claude-* models; the orchestrator skips verification for these
}

// EntryFor derives a ChainEntry from a bare model name.
//
// Names beginning with "claude-" are classified as cloud models on the
// "subagent" tier; every other name is a local model on the "local" tier.
// EntryFor never produces the "managed" tier.
func EntryFor(model string) ChainEntry {
	if strings.HasPrefix(model, "claude-") {
		return ChainEntry{Model: model, Tier: "subagent", IsCloud: true}
	}
	return ChainEntry{Model: model, Tier: "local"}
}
|
|
|
|
// AttemptRecord is the session-log entry describing how one tier attempt
// of an escalation chain turned out.
type AttemptRecord struct {
	Model, Tier string // copied from the ChainEntry that was attempted
	DurationMs  int64  // elapsed wall-clock time of the generation call, in milliseconds
	WarmStart   bool   // outcome of the warm-model probe taken before the attempt
	Verdict     string // "accept", "escalate", or "error"
	Feedback    string // verifier feedback or error text; may be empty
}
|
|
|
|
// VerifierFn is the interface the orchestrator uses to verify local output.
|
|
type VerifierFn interface {
|
|
Verify(ctx context.Context, skillPrompt, taskPrompt string, output Result) (Verdict, error)
|
|
}
|
|
|
|
// ExecutorRunFn is the signature of Executor.Run and LiteLLMExecutor.Run.
|
|
type ExecutorRunFn func(ctx context.Context, req Request) (Result, error)
|
|
|
|
// Orchestrator walks an escalation chain, delegating generation and verification.
|
|
// It implements the ExecutorFn shape expected by skill handlers.
|
|
type Orchestrator struct {
|
|
chain []ChainEntry
|
|
localRun ExecutorRunFn // for local (non-cloud) tiers; may be nil
|
|
cloudRun ExecutorRunFn // for cloud tiers; may be nil
|
|
verifier VerifierFn
|
|
llamaSwapURL string
|
|
attempts *[]AttemptRecord
|
|
}
|
|
|
|
// NewOrchestrator creates an Orchestrator.
|
|
// attempts is a pointer to a slice that will be appended to on each tier attempt.
|
|
// Pass nil for localRun or cloudRun if no tiers of that type exist in the chain.
|
|
func NewOrchestrator(
|
|
chain []ChainEntry,
|
|
localRun ExecutorRunFn,
|
|
cloudRun ExecutorRunFn,
|
|
verifier VerifierFn,
|
|
llamaSwapURL string,
|
|
attempts *[]AttemptRecord,
|
|
) *Orchestrator {
|
|
return &Orchestrator{
|
|
chain: chain,
|
|
localRun: localRun,
|
|
cloudRun: cloudRun,
|
|
verifier: verifier,
|
|
llamaSwapURL: llamaSwapURL,
|
|
attempts: attempts,
|
|
}
|
|
}
|
|
|
|
// Run walks the escalation chain and returns the first accepted result.
|
|
// Satisfies the ExecutorFn signature: func(context.Context, Request) (Result, error).
|
|
func (o *Orchestrator) Run(ctx context.Context, req Request) (Result, error) {
|
|
taskPrompt := req.TaskPrompt
|
|
|
|
for _, entry := range o.chain {
|
|
warm := o.probeWarm(entry.Model)
|
|
start := time.Now()
|
|
|
|
tierReq := req
|
|
tierReq.Model = entry.Model
|
|
tierReq.TaskPrompt = taskPrompt
|
|
|
|
if entry.IsCloud {
|
|
result, genErr := o.cloudRun(ctx, tierReq)
|
|
dur := time.Since(start).Milliseconds()
|
|
verdict := "accept"
|
|
if genErr != nil {
|
|
verdict = "error"
|
|
}
|
|
o.appendAttempt(AttemptRecord{
|
|
Model: entry.Model,
|
|
Tier: entry.Tier,
|
|
DurationMs: dur,
|
|
WarmStart: warm,
|
|
Verdict: verdict,
|
|
})
|
|
if genErr == nil {
|
|
return result, nil
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Local tier.
|
|
result, genErr := o.localRun(ctx, tierReq)
|
|
dur := time.Since(start).Milliseconds()
|
|
|
|
if genErr != nil {
|
|
o.appendAttempt(AttemptRecord{
|
|
Model: entry.Model,
|
|
Tier: entry.Tier,
|
|
DurationMs: dur,
|
|
WarmStart: warm,
|
|
Verdict: "error",
|
|
Feedback: genErr.Error(),
|
|
})
|
|
continue
|
|
}
|
|
|
|
verdict, verErr := o.verifier.Verify(ctx, req.SkillPrompt, taskPrompt, result)
|
|
if verErr != nil {
|
|
// Treat verifier failure as escalate (safe default).
|
|
o.appendAttempt(AttemptRecord{
|
|
Model: entry.Model,
|
|
Tier: entry.Tier,
|
|
DurationMs: dur,
|
|
WarmStart: warm,
|
|
Verdict: "escalate",
|
|
Feedback: "verifier error: " + verErr.Error(),
|
|
})
|
|
continue
|
|
}
|
|
|
|
if verdict.Accept {
|
|
o.appendAttempt(AttemptRecord{
|
|
Model: entry.Model,
|
|
Tier: entry.Tier,
|
|
DurationMs: dur,
|
|
WarmStart: warm,
|
|
Verdict: "accept",
|
|
})
|
|
return result, nil
|
|
}
|
|
|
|
o.appendAttempt(AttemptRecord{
|
|
Model: entry.Model,
|
|
Tier: entry.Tier,
|
|
DurationMs: dur,
|
|
WarmStart: warm,
|
|
Verdict: "escalate",
|
|
Feedback: verdict.Feedback,
|
|
})
|
|
// Inject verifier feedback into the next tier's task prompt.
|
|
taskPrompt = taskPrompt + "\n\nPrior attempt feedback: " + verdict.Feedback
|
|
}
|
|
|
|
return Result{}, fmt.Errorf("all tiers exhausted after %d attempt(s)", len(o.chain))
|
|
}
|
|
|
|
func (o *Orchestrator) appendAttempt(rec AttemptRecord) {
|
|
if o.attempts != nil {
|
|
*o.attempts = append(*o.attempts, rec)
|
|
}
|
|
}
|
|
|
|
// probeWarm checks whether the model is currently loaded in llama-swap.
|
|
// Returns false on any error or if llamaSwapURL is empty.
|
|
func (o *Orchestrator) probeWarm(model string) bool {
|
|
if o.llamaSwapURL == "" {
|
|
return false
|
|
}
|
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
|
defer cancel()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, o.llamaSwapURL+"/v1/models", nil)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
defer resp.Body.Close() //nolint:errcheck
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return strings.Contains(string(body), model)
|
|
}
|