diff --git a/internal/routing/passrate.go b/internal/routing/passrate.go new file mode 100644 index 0000000..97ae01d --- /dev/null +++ b/internal/routing/passrate.go @@ -0,0 +1,85 @@ +package routing + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "sync" + "time" +) + +// Fetcher reads /pass-rate from the brain pod with a per-skill TTL cache. +type Fetcher struct { + BaseURL string + Window string + TTL time.Duration + HTTP *http.Client + + mu sync.Mutex + cache map[string]cachedRate +} + +type cachedRate struct { + value *float64 + at time.Time +} + +type passRateResponse struct { + PassRate *float64 `json:"pass_rate"` +} + +// NewFetcher returns a Fetcher that calls baseURL + /pass-rate with the +// given window string. If ttl is zero, defaults to 60 seconds. The HTTP +// client uses a 1-second total timeout. +func NewFetcher(baseURL, window string, ttl time.Duration) *Fetcher { + if ttl == 0 { + ttl = 60 * time.Second + } + return &Fetcher{ + BaseURL: baseURL, + Window: window, + TTL: ttl, + HTTP: &http.Client{Timeout: time.Second}, + cache: make(map[string]cachedRate), + } +} + +// Get returns the pass rate for the named skill, or nil if no data exists, +// or an error if the brain is unreachable. Caches successful results. +func (f *Fetcher) Get(ctx context.Context, skill string) (*float64, error) { + f.mu.Lock() + if c, ok := f.cache[skill]; ok && time.Since(c.at) < f.TTL { + v := c.value + f.mu.Unlock() + return v, nil + } + f.mu.Unlock() + + u := fmt.Sprintf("%s/pass-rate?skill=%s&window=%s", + f.BaseURL, url.QueryEscape(skill), url.QueryEscape(f.Window)) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + if err != nil { + return nil, fmt.Errorf("passrate: build request: %w", err) + } + resp, err := f.HTTP.Do(req) + if err != nil { + return nil, fmt.Errorf("passrate: request: %w", err) + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("passrate: server returned status %d", resp.StatusCode) + } + + var body passRateResponse + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + return nil, fmt.Errorf("passrate: decode: %w", err) + } + + f.mu.Lock() + f.cache[skill] = cachedRate{value: body.PassRate, at: time.Now()} + f.mu.Unlock() + + return body.PassRate, nil +} diff --git a/internal/routing/passrate_test.go b/internal/routing/passrate_test.go new file mode 100644 index 0000000..bd2c50e --- /dev/null +++ b/internal/routing/passrate_test.go @@ -0,0 +1,73 @@ +package routing_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/mathiasbq/supervisor/internal/routing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFetcherGetReturnsPassRate(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodGet, r.Method) + assert.Equal(t, "/pass-rate", r.URL.Path) + assert.Equal(t, "tdd", r.URL.Query().Get("skill")) + assert.Equal(t, "7d", r.URL.Query().Get("window")) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{"skill": "tdd", "pass_rate": 0.94}) + })) + defer srv.Close() + + f := routing.NewFetcher(srv.URL, "7d", time.Minute) + pr, err := f.Get(context.Background(), "tdd") + require.NoError(t, err) + require.NotNil(t, pr) + assert.InDelta(t, 0.94, *pr, 1e-9) +} + +func TestFetcherGetReturnsNilWhenNoData(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{"skill": "novel", "pass_rate": nil}) + })) + defer srv.Close() + + f := routing.NewFetcher(srv.URL, "7d", time.Minute) + pr, err := f.Get(context.Background(), "novel") + require.NoError(t, err) + assert.Nil(t, pr) +} + +func TestFetcherCachesWithinTTL(t *testing.T) { + var calls int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + atomic.AddInt32(&calls, 1) + _ = json.NewEncoder(w).Encode(map[string]any{"pass_rate": 0.5}) + })) + defer srv.Close() + + f := routing.NewFetcher(srv.URL, "7d", time.Minute) + for i := 0; i < 5; i++ { + _, err := f.Get(context.Background(), "tdd") + require.NoError(t, err) + } + assert.Equal(t, int32(1), atomic.LoadInt32(&calls), "should hit upstream once and serve four times from cache") +} + +func TestFetcherSurfacesUpstreamError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "boom", http.StatusInternalServerError) + })) + defer srv.Close() + + f := routing.NewFetcher(srv.URL, "7d", time.Minute) + pr, err := f.Get(context.Background(), "tdd") + require.Error(t, err) + assert.Nil(t, pr) +}