diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6147740..390d375 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -37,6 +37,8 @@ jobs: with: go-version: ${{ env.GO_VERSION }} cache: true + - name: Boundary guard + run: bash ./scripts/check-ecosystem-boundaries.sh - name: gofumpt run: | go install mvdan.cc/gofumpt@v0.10.0 @@ -118,6 +120,8 @@ jobs: cache: true - name: Clone tok (workspace dep) run: git clone --depth=1 https://github.com/GrayCodeAI/tok.git ../tok + - name: Boundary guard + run: bash ./scripts/check-ecosystem-boundaries.sh - name: Run golangci-lint run: | go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.1.0 @@ -138,6 +142,8 @@ jobs: cache: true - name: Clone tok (workspace dep) run: git clone --depth=1 https://github.com/GrayCodeAI/tok.git ../tok + - name: Boundary guard + run: bash ./scripts/check-ecosystem-boundaries.sh - name: Test with race detector run: go test ./... -race -count=1 -shuffle=on -coverprofile=coverage.out -covermode=atomic -timeout=300s - name: Coverage summary diff --git a/AGENTS.md b/AGENTS.md index 22aad91..f71fee2 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -46,7 +46,7 @@ make ci # Full CI suite - Provider interface is the boundary — keep it stable - Streaming tests need careful goroutine management -- `go.work` here replaces only `github.com/GrayCodeAI/tok => ../tok`; hawk's own `go.work` adds an `external/eyrie` replace so hawk can develop against a local eyrie checkout. Do not add other `replace` directives here without coordinating with hawk's workspace. +- `go.work` here should stay minimal; hawk's own `go.work` adds an `external/eyrie` replace so hawk can develop against a local eyrie checkout. Do not add extra local `replace` directives here without coordinating with hawk's workspace. 
## Naming Conventions diff --git a/Makefile b/Makefile index 3d567cf..ba51a34 100644 --- a/Makefile +++ b/Makefile @@ -31,9 +31,12 @@ GOVULNCHECK := $(GOBIN_DIR)/govulncheck # --------------------------------------------------------------------------- # Phony declarations (alphabetical). # --------------------------------------------------------------------------- -.PHONY: all bench build ci clean cover fmt help lint lint-fix \ +.PHONY: all bench boundaries build ci clean cover fmt help lint lint-fix \ security test test-10x test-race tidy version vet +boundaries: ## Enforce support-repo import boundaries. + bash ./scripts/check-ecosystem-boundaries.sh + # --------------------------------------------------------------------------- # Default target. # --------------------------------------------------------------------------- @@ -97,7 +100,7 @@ tidy: ## Tidy go.mod / go.sum. # --------------------------------------------------------------------------- # Composite gate used by CI and pre-push. # --------------------------------------------------------------------------- -ci: tidy fmt vet lint test-race security ## Run everything CI runs. +ci: tidy fmt vet lint boundaries test-race security ## Run everything CI runs. @echo "All CI checks passed." # --------------------------------------------------------------------------- diff --git a/README.md b/README.md index 6060003..6688cde 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,14 @@ When your app calls a model, eyrie figures out which provider to use, how to tal **Your app never talks to an LLM API directly. eyrie does.** +## Ecosystem Boundaries + +eyrie is a Hawk support engine. 
// maxLiveResponseBytes caps how much of an external provider's HTTP response
// the live catalog fetchers will read, so a malicious or buggy provider cannot
// exhaust memory by returning an unbounded body.
const maxLiveResponseBytes = 10 * 1024 * 1024 // 10 MiB

// decodeJSONLimited decodes JSON from r into v, reading at most
// maxLiveResponseBytes. Use this instead of json.NewDecoder(resp.Body) for
// responses from untrusted/remote endpoints.
//
// When the body is longer than the cap, the returned error says so explicitly
// and wraps the decoder's error, so errors.Is/As on the cause keep working.
func decodeJSONLimited(r io.Reader, v any) error {
	return decodeJSONWithLimit(r, v, maxLiveResponseBytes)
}

// decodeJSONWithLimit decodes one JSON value from r into v while reading no
// more than limit bytes from r during decoding.
func decodeJSONWithLimit(r io.Reader, v any, limit int64) error {
	body := &io.LimitedReader{R: r, N: limit}
	err := json.NewDecoder(body).Decode(v)
	if err == nil || body.N > 0 {
		// Success, or a failure that happened before the cap was reached.
		return err
	}

	// The cap was fully consumed and decoding still failed. Probe one byte
	// past the cap to tell "body was cut off by the cap" apart from "body is
	// exactly limit bytes long and simply malformed".
	var probe [1]byte
	if n, _ := io.ReadFull(r, probe[:]); n > 0 {
		return fmt.Errorf("live catalog response exceeds %d bytes: %w", limit, err)
	}
	return err
}
@@ -166,7 +179,7 @@ func fetchOpenAICompatModels(ctx context.Context, baseURL, apiKey, authHeader st var payload struct { Data []json.RawMessage `json:"data"` } - if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + if err := decodeJSONLimited(resp.Body, &payload); err != nil { return nil, err } var entries []Entry diff --git a/catalog/live/fetchers_cloud.go b/catalog/live/fetchers_cloud.go index e5d2e73..f759eac 100644 --- a/catalog/live/fetchers_cloud.go +++ b/catalog/live/fetchers_cloud.go @@ -52,7 +52,7 @@ func enrichOpenAIWithOpenRouter(entries []Entry) { var payload struct { Data []openRouterModelEntry `json:"data"` } - if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + if err := decodeJSONLimited(resp.Body, &payload); err != nil { return } // Build lookup map: "gpt-4o" → openRouterModelEntry @@ -136,7 +136,7 @@ func enrichFromOpenRouter(entries []Entry, prefix string) { var payload struct { Data []openRouterModelEntry `json:"data"` } - if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + if err := decodeJSONLimited(resp.Body, &payload); err != nil { return } // Build lookup map by stripping prefix @@ -239,7 +239,7 @@ func FetchAzure(env map[string]string) ([]Entry, error) { var payload struct { Value []json.RawMessage `json:"value"` } - if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + if err := decodeJSONLimited(resp.Body, &payload); err != nil { return nil, err } var entries []Entry @@ -304,7 +304,7 @@ func FetchBedrock(env map[string]string) ([]Entry, error) { var payload struct { ModelSummaries []json.RawMessage `json:"modelSummaries"` } - if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + if err := decodeJSONLimited(resp.Body, &payload); err != nil { return nil, err } var entries []Entry @@ -377,7 +377,7 @@ func FetchVertex(env map[string]string) ([]Entry, error) { PublisherModels []json.RawMessage `json:"publisherModels"` Models []json.RawMessage `json:"models"` 
} - if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + if err := decodeJSONLimited(resp.Body, &payload); err != nil { return nil, err } rawModels := payload.PublisherModels diff --git a/catalog/live/fetchers_providers.go b/catalog/live/fetchers_providers.go index 8c88f29..32467ac 100644 --- a/catalog/live/fetchers_providers.go +++ b/catalog/live/fetchers_providers.go @@ -82,6 +82,7 @@ func FetchOpenCodeGo(env map[string]string) ([]Entry, error) { if err != nil { return nil, err } + protocolEntries := make([]struct{ ID, Protocol string }, 0, len(entries)) for i := range entries { entries[i].ID = opencodego.NativeModelID(entries[i].ID) // Merge with static metadata from docs (pricing, protocol, context windows). @@ -91,7 +92,9 @@ func FetchOpenCodeGo(env map[string]string) ([]Entry, error) { // Unknown model — derive protocol from name pattern. entries[i].Protocol = opencodego.ProtocolForModel(entries[i].ID) } + protocolEntries = append(protocolEntries, struct{ ID, Protocol string }{ID: entries[i].ID, Protocol: entries[i].Protocol}) } + opencodego.UpdateProtocolMap(protocolEntries) return entries, nil } @@ -230,7 +233,7 @@ func FetchOpenRouter(env map[string]string) ([]Entry, error) { var payload struct { Data []json.RawMessage `json:"data"` } - if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + if err := decodeJSONLimited(resp.Body, &payload); err != nil { return nil, err } var entries []Entry @@ -350,7 +353,7 @@ func FetchAnthropic(env map[string]string) ([]Entry, error) { var payload struct { Data []json.RawMessage `json:"data"` } - if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + if err := decodeJSONLimited(resp.Body, &payload); err != nil { return nil, err } var entries []Entry @@ -471,7 +474,7 @@ func FetchGemini(env map[string]string) ([]Entry, error) { var payload struct { Models []json.RawMessage `json:"models"` } - if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + if err := 
decodeJSONLimited(resp.Body, &payload); err != nil { return nil, err } var entries []Entry @@ -539,7 +542,7 @@ func FetchOllama(env map[string]string) ([]Entry, error) { var payload struct { Models []json.RawMessage `json:"models"` } - if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + if err := decodeJSONLimited(resp.Body, &payload); err != nil { return nil, err } var entries []Entry diff --git a/catalog/live/opencodego_test.go b/catalog/live/opencodego_test.go index 54f3072..8c6a6c7 100644 --- a/catalog/live/opencodego_test.go +++ b/catalog/live/opencodego_test.go @@ -5,6 +5,8 @@ import ( "net/http" "net/http/httptest" "testing" + + "github.com/GrayCodeAI/eyrie/catalog/opencodego" ) func TestFetchOpenCodeGo_MockHTTPServer(t *testing.T) { @@ -45,6 +47,44 @@ func TestFetchOpenCodeGo_MockHTTPServer(t *testing.T) { } } +func TestFetchOpenCodeGoUpdatesProtocolMap(t *testing.T) { + opencodego.ResetProtocolMap() + defer opencodego.ResetProtocolMap() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/models" { + http.NotFound(w, r) + return + } + resp := struct { + Data []json.RawMessage `json:"data"` + }{ + Data: []json.RawMessage{ + json.RawMessage(`{"id":"new-live-model","owned_by":"opencode","api_type":"anthropic"}`), + }, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer srv.Close() + + if opencodego.UsesMessagesAPI("new-live-model") { + t.Fatal("expected heuristic fallback to use OpenAI before live fetch") + } + entries, err := FetchOpenCodeGo(map[string]string{ + "OPENCODEGO_API_KEY": "test-ocg-key", + "OPENCODEGO_BASE_URL": srv.URL, + }) + if err != nil { + t.Fatal(err) + } + if len(entries) != 1 { + t.Fatalf("expected 1 model, got %d", len(entries)) + } + if !opencodego.UsesMessagesAPI("new-live-model") { + t.Fatal("expected live fetch to route anthropic model to Messages API") + } +} + func TestFetchOpenCodeGo_NoKey(t *testing.T) { entries, err := 
FetchOpenCodeGo(map[string]string{}) if err != nil { diff --git a/client/budget_provider.go b/client/budget_provider.go index 69af354..e7c026e 100644 --- a/client/budget_provider.go +++ b/client/budget_provider.go @@ -103,7 +103,34 @@ func (bp *BudgetProvider) StreamChat(ctx context.Context, messages []EyrieMessag if err := bp.store.CheckBudget(ctx, vk, est.TotalCostUSD); err != nil { return nil, err } - return bp.inner.StreamChat(ctx, messages, opts) + + result, err := bp.inner.StreamChat(ctx, messages, opts) + if err != nil { + return nil, err + } + + // Wrap the events channel to record actual spend from the final usage + // event. Without this, streamed calls under a virtual key never debit the + // budget (unlike the non-streaming Chat path), so streaming-heavy clients + // would underreport spend. Mirrors UsageLimitProvider.StreamChat. + wrappedCh := make(chan EyrieStreamEvent, cap(result.Events)) + go func() { + defer close(wrappedCh) + for evt := range result.Events { + if evt.Type == "usage" && evt.Usage != nil { + cost := ActualCostUSD(opts.Model, evt.Usage) + _ = bp.store.RecordUsage(ctx, vk, cost, evt.Usage.PromptTokens, evt.Usage.CompletionTokens) + } + select { + case wrappedCh <- evt: + case <-ctx.Done(): + result.Close() + return + } + } + }() + + return NewStreamResult(wrappedCh, result.Close), nil } func (bp *BudgetProvider) recordUsage(ctx context.Context, vk, model string, resp *EyrieResponse) { diff --git a/client/condenser.go b/client/condenser.go index 28bbbfd..705f65b 100644 --- a/client/condenser.go +++ b/client/condenser.go @@ -148,7 +148,11 @@ func (c *LLMSummarizingCondenser) summarize(ctx context.Context, span []EyrieMes for _, m := range span { b.WriteString(m.Role) b.WriteString(": ") - b.WriteString(m.Content) + // Indent embedded newlines so a message body cannot forge a new + // "role:" turn at column 0 (e.g. content containing "\nassistant: ..."). 
+ // Without this, a flat "role: content" transcript is a prompt-injection + // vector into the summarization call. + b.WriteString(strings.ReplaceAll(m.Content, "\n", "\n ")) b.WriteByte('\n') } diff --git a/client/cost_estimator.go b/client/cost_estimator.go index fa1165e..2f77efb 100644 --- a/client/cost_estimator.go +++ b/client/cost_estimator.go @@ -4,8 +4,6 @@ import ( "fmt" "strings" "sync" - - "github.com/GrayCodeAI/tok" ) // CostEstimator estimates the cost of an API call BEFORE sending it. @@ -65,14 +63,14 @@ func (ce *CostEstimator) IsExpensive(est CostEstimate, threshold float64) bool { func (ce *CostEstimator) countInputTokens(messages []EyrieMessage) int { total := 0 for _, m := range messages { - total += tok.EstimateTokens(m.Content) + total += estimateTextTokens(m.Content) for _, tr := range m.ToolResults { - total += tok.EstimateTokens(tr.Content) + total += estimateTextTokens(tr.Content) } for _, tc := range m.ToolUse { total += 50 // tool call overhead for _, v := range tc.Arguments { - total += tok.EstimateTokens(fmt.Sprintf("%v", v)) + total += estimateTextTokens(fmt.Sprintf("%v", v)) } } } @@ -100,7 +98,7 @@ func NewStreamingTokenCounter(model string, inputTokens int) *StreamingTokenCoun // AddOutput records streamed output tokens. 
func (stc *StreamingTokenCounter) AddOutput(text string) { stc.mu.Lock() - stc.outputTokens += tok.EstimateTokens(text) + stc.outputTokens += estimateTextTokens(text) stc.mu.Unlock() } @@ -157,7 +155,7 @@ func NewPromptOptimizer(maxInputTokens int) *PromptOptimizer { func (po *PromptOptimizer) Optimize(messages []EyrieMessage) []EyrieMessage { totalTokens := 0 for _, m := range messages { - totalTokens += tok.EstimateTokens(m.Content) + 10 // +10 for overhead + totalTokens += estimateTextTokens(m.Content) + 10 // +10 for overhead } if totalTokens <= po.maxInputTokens { @@ -196,8 +194,7 @@ func compressMessages(messages []EyrieMessage) string { } raw := strings.Join(parts, "\n") - // Use tok compression pipeline for intelligent summarization - compressed, _ := tok.Compress(raw, tok.Minimal) + compressed := compressForSummary(raw) if len(compressed) > 0 && len(compressed) < len(raw) { return compressed } diff --git a/client/cost_estimator_test.go b/client/cost_estimator_test.go index 57d8217..b415761 100644 --- a/client/cost_estimator_test.go +++ b/client/cost_estimator_test.go @@ -3,8 +3,6 @@ package client import ( "math" "testing" - - "github.com/GrayCodeAI/tok" ) func TestCostEstimateForKnownModels(t *testing.T) { @@ -73,7 +71,7 @@ func TestCostEstimateUnknownModelReturnsNonZero(t *testing.T) { expectedInPrice := 1.0 / 1_000_000 expectedOutPrice := 3.0 / 1_000_000 - inputTokens := tok.EstimateTokens("test message here") + inputTokens := estimateTextTokens("test message here") expectedInput := float64(inputTokens) * expectedInPrice expectedOutput := float64(1000) * expectedOutPrice diff --git a/client/token_utils.go b/client/token_utils.go new file mode 100644 index 0000000..142c1c8 --- /dev/null +++ b/client/token_utils.go @@ -0,0 +1,85 @@ +package client + +import ( + "regexp" + "strings" + "sync" + "unicode/utf8" + + tiktoken "github.com/tiktoken-go/tokenizer" +) + +var whitespacePattern = regexp.MustCompile(`\s+`) + +var ( + tokenizerOnce sync.Once + 
// fallbackTokenCount estimates the token count of text with a character-based
// heuristic, for use when an exact tokenizer count is unavailable.
//
// Short inputs use fixed ratios: under 30 runes is ~3 runes per token, under
// 100 runes is ~4 runes per token. Longer inputs sample up to the first 200
// runes to measure the whitespace ratio, then count each whitespace rune as
// one token and every 3.5 non-whitespace runes as one token.
//
// All lengths and the sample are measured in runes. Sampling bytes while
// scaling by a rune count would under-weight whitespace in multi-byte text.
// For pure-ASCII input bytes and runes coincide.
func fallbackTokenCount(text string) int {
	if text == "" {
		return 0
	}
	length := utf8.RuneCountInString(text)
	if length < 30 {
		return (length + 2) / 3
	}
	if length < 100 {
		return (length + 3) / 4
	}

	const sampleRunes = 200
	spaces, sampled := 0, 0
	for _, r := range text {
		if sampled == sampleRunes {
			break
		}
		sampled++
		switch r {
		case ' ', '\n', '\t', '\r':
			spaces++
		}
	}

	// length >= 100 here, so sampled is never zero.
	spaceRatio := float64(spaces) / float64(sampled)
	nonSpaceChars := float64(length) * (1 - spaceRatio)
	spaceTokens := float64(length) * spaceRatio
	return int(nonSpaceChars/3.5 + spaceTokens)
}
+func compressForSummary(text string) string { + text = strings.TrimSpace(text) + if text == "" { + return "" + } + return whitespacePattern.ReplaceAllString(text, " ") +} diff --git a/client/tracing.go b/client/tracing.go index 132ad57..96e177d 100644 --- a/client/tracing.go +++ b/client/tracing.go @@ -93,12 +93,14 @@ func (tp *TracingProvider) StreamChat(ctx context.Context, messages []EyrieMessa wrappedEvents := make(chan EyrieStreamEvent, cap(origEvents)) go func() { defer span.End() + defer close(wrappedEvents) for evt := range origEvents { switch evt.Type { case "error": span.SetStatus(codes.Error, evt.Error) span.SetAttributes(attribute.Bool("error", true)) - case "done": + case "usage": + // Token usage is delivered on the "usage" event, not "done". if evt.Usage != nil { span.SetAttributes( attribute.Int("usage.prompt_tokens", evt.Usage.PromptTokens), @@ -106,11 +108,19 @@ func (tp *TracingProvider) StreamChat(ctx context.Context, messages []EyrieMessa attribute.Int("usage.total_tokens", evt.Usage.TotalTokens), ) } + case "done": span.SetStatus(codes.Ok, "") } - wrappedEvents <- evt + // Respect cancellation on the send: if the consumer abandons the + // stream, this goroutine must not block forever forwarding events + // (which would leak the goroutine and keep the span open). + select { + case wrappedEvents <- evt: + case <-ctx.Done(): + sr.Close() + return + } } - close(wrappedEvents) }() return &StreamResult{ diff --git a/client/usage_limit.go b/client/usage_limit.go index d7f0c31..88b7b78 100644 --- a/client/usage_limit.go +++ b/client/usage_limit.go @@ -4,12 +4,10 @@ import ( "context" "errors" "fmt" - - "github.com/GrayCodeAI/tok" ) // UsageLimitProvider wraps any Provider and enforces token/cost budgets -// via a tok.UsageTracker. It calls CanProceed() before each Chat/StreamChat +// via a UsageTracker. It calls CanProceed() before each Chat/StreamChat // request and Record() after successful responses. 
// // If the budget is exhausted, calls return a non-nil error immediately @@ -19,7 +17,7 @@ import ( // UsageTracker is internally synchronised). type UsageLimitProvider struct { inner Provider - tracker *tok.UsageTracker + tracker *UsageTracker } // Compile-time check that UsageLimitProvider implements Provider. @@ -27,7 +25,7 @@ var _ Provider = (*UsageLimitProvider)(nil) // NewUsageLimitProvider wraps inner with budget enforcement via tracker. // Both arguments must be non-nil; an error is returned otherwise. -func NewUsageLimitProvider(inner Provider, tracker *tok.UsageTracker) (*UsageLimitProvider, error) { +func NewUsageLimitProvider(inner Provider, tracker *UsageTracker) (*UsageLimitProvider, error) { if inner == nil { return nil, errors.New("eyrie: NewUsageLimitProvider inner provider must not be nil") } @@ -43,7 +41,7 @@ func (u *UsageLimitProvider) Name() string { } // Tracker returns the underlying UsageTracker for inspection or configuration. -func (u *UsageLimitProvider) Tracker() *tok.UsageTracker { +func (u *UsageLimitProvider) Tracker() *UsageTracker { return u.tracker } @@ -103,10 +101,7 @@ func (u *UsageLimitProvider) StreamChat(ctx context.Context, messages []EyrieMes } }() - return &StreamResult{ - Events: wrappedCh, - RequestID: result.RequestID, - }, nil + return NewStreamResult(wrappedCh, result.Close), nil } // recordUsage extracts token count from an EyrieResponse and records it. diff --git a/client/usage_tracker.go b/client/usage_tracker.go new file mode 100644 index 0000000..693d156 --- /dev/null +++ b/client/usage_tracker.go @@ -0,0 +1,367 @@ +package client + +import ( + "fmt" + "math" + "strconv" + "strings" + "sync" + "time" +) + +// UsageEntry represents a single recorded usage event. +type UsageEntry struct { + Tokens int + CostUSD float64 + Timestamp time.Time + Provider string + Model string +} + +// Alert represents a usage threshold alert. 
+type Alert struct { + Level string + Message string + Timestamp time.Time + Threshold float64 +} + +// UsageSummary provides a snapshot of current usage across all windows. +type UsageSummary struct { + HourlyTokens int + HourlyRemaining int + DailyTokens int + DailyRemaining int + SessionTokens int + SessionRemaining int + DailyCostUSD float64 + CostRemaining float64 + HourlyPct float64 + DailyPct float64 +} + +// UsageTracker tracks API usage across sessions and prevents surprise bills. +type UsageTracker struct { + DailyLimit int + HourlyLimit int + SessionLimit int + CostLimitUSD float64 + + hourlyUsage []UsageEntry + dailyUsage []UsageEntry + sessionUsage int + mu sync.Mutex + Alerts []Alert + + firedThresholds map[string]bool +} + +// NewUsageTracker creates a UsageTracker with sensible defaults. +func NewUsageTracker() *UsageTracker { + return &UsageTracker{ + DailyLimit: 1_000_000, + HourlyLimit: 200_000, + SessionLimit: 500_000, + CostLimitUSD: 10.00, + hourlyUsage: make([]UsageEntry, 0), + dailyUsage: make([]UsageEntry, 0), + firedThresholds: make(map[string]bool), + } +} + +func (u *UsageTracker) Record(tokens int, costUSD float64, provider, model string) { + u.mu.Lock() + defer u.mu.Unlock() + + entry := UsageEntry{ + Tokens: tokens, + CostUSD: costUSD, + Timestamp: time.Now(), + Provider: provider, + Model: model, + } + + u.hourlyUsage = append(u.hourlyUsage, entry) + u.dailyUsage = append(u.dailyUsage, entry) + u.sessionUsage += tokens + u.checkThresholdsLocked() +} + +func (u *UsageTracker) CanProceed() (bool, string) { + u.mu.Lock() + defer u.mu.Unlock() + + u.pruneOldLocked() + + hourlyTokens := u.hourlyTokensLocked() + if hourlyTokens >= u.HourlyLimit { + return false, fmt.Sprintf("hourly token limit reached (%d/%d)", hourlyTokens, u.HourlyLimit) + } + + dailyTokens := u.dailyTokensLocked() + if dailyTokens >= u.DailyLimit { + return false, fmt.Sprintf("daily token limit reached (%d/%d)", dailyTokens, u.DailyLimit) + } + + if u.sessionUsage >= 
u.SessionLimit { + return false, fmt.Sprintf("session token limit reached (%d/%d)", u.sessionUsage, u.SessionLimit) + } + + dailyCost := u.dailyCostLocked() + if dailyCost >= u.CostLimitUSD { + return false, fmt.Sprintf("daily cost limit reached ($%.2f/$%.2f)", dailyCost, u.CostLimitUSD) + } + + return true, "" +} + +func (u *UsageTracker) GetUsage() UsageSummary { + u.mu.Lock() + defer u.mu.Unlock() + + u.pruneOldLocked() + + hourlyTokens := u.hourlyTokensLocked() + dailyTokens := u.dailyTokensLocked() + dailyCost := u.dailyCostLocked() + + hourlyRemaining := max(0, u.HourlyLimit-hourlyTokens) + dailyRemaining := max(0, u.DailyLimit-dailyTokens) + sessionRemaining := max(0, u.SessionLimit-u.sessionUsage) + costRemaining := u.CostLimitUSD - dailyCost + if costRemaining < 0 { + costRemaining = 0 + } + + var hourlyPct, dailyPct float64 + if u.HourlyLimit > 0 { + hourlyPct = float64(hourlyTokens) / float64(u.HourlyLimit) * 100 + } + if u.DailyLimit > 0 { + dailyPct = float64(dailyTokens) / float64(u.DailyLimit) * 100 + } + + return UsageSummary{ + HourlyTokens: hourlyTokens, + HourlyRemaining: hourlyRemaining, + DailyTokens: dailyTokens, + DailyRemaining: dailyRemaining, + SessionTokens: u.sessionUsage, + SessionRemaining: sessionRemaining, + DailyCostUSD: dailyCost, + CostRemaining: costRemaining, + HourlyPct: hourlyPct, + DailyPct: dailyPct, + } +} + +func (u *UsageTracker) CheckThresholds() { + u.mu.Lock() + defer u.mu.Unlock() + u.checkThresholdsLocked() +} + +func (u *UsageTracker) Reset() { + u.mu.Lock() + defer u.mu.Unlock() + + u.sessionUsage = 0 + u.Alerts = nil + u.firedThresholds = make(map[string]bool) +} + +func (u *UsageTracker) PruneOld() { + u.mu.Lock() + defer u.mu.Unlock() + u.pruneOldLocked() +} + +func (u *UsageTracker) EstimateRemaining(tokensPerRequest int) int { + u.mu.Lock() + defer u.mu.Unlock() + + if tokensPerRequest <= 0 { + return 0 + } + + u.pruneOldLocked() + + minRemaining := u.HourlyLimit - u.hourlyTokensLocked() + if dailyRemaining := 
// FormatUsageBar renders pct as a fixed-width text progress bar followed by
// the truncated integer percentage, e.g. "[████░░░░] 50%".
//
// pct is clamped to [0, 100]; NaN is treated as 0. A non-positive width
// yields the empty string.
func FormatUsageBar(pct float64, width int) string {
	if width <= 0 {
		return ""
	}
	// NaN fails every ordered comparison, so it would slip past the clamps
	// below and convert to an implementation-defined (typically negative)
	// int, which makes strings.Repeat panic.
	if math.IsNaN(pct) {
		pct = 0
	}
	if pct < 0 {
		pct = 0
	}
	if pct > 100 {
		pct = 100
	}

	filled := int(math.Round(pct / 100 * float64(width)))
	if filled > width {
		filled = width
	}

	bar := strings.Repeat("█", filled) + strings.Repeat("░", width-filled)
	return fmt.Sprintf("[%s] %d%%", bar, int(pct))
}
float64(u.HourlyLimit) * 100 + u.emitAlert("hourly", pct, "hourly token usage") + } + if u.DailyLimit > 0 { + pct := float64(u.dailyTokensLocked()) / float64(u.DailyLimit) * 100 + u.emitAlert("daily", pct, "daily token usage") + } + if u.SessionLimit > 0 { + pct := float64(u.sessionUsage) / float64(u.SessionLimit) * 100 + u.emitAlert("session", pct, "session token usage") + } + if u.CostLimitUSD > 0 { + pct := u.dailyCostLocked() / u.CostLimitUSD * 100 + u.emitAlert("cost", pct, "daily cost") + } +} + +func (u *UsageTracker) emitAlert(category string, pct float64, label string) { + type threshold struct { + pct float64 + level string + } + for _, t := range []threshold{{100, "limit_reached"}, {80, "critical"}, {50, "warning"}} { + if pct >= t.pct { + key := fmt.Sprintf("%s_%d", category, int(t.pct)) + if !u.firedThresholds[key] { + u.firedThresholds[key] = true + u.Alerts = append(u.Alerts, Alert{ + Level: t.level, + Message: fmt.Sprintf("%s at %.0f%% of limit", label, pct), + Timestamp: time.Now(), + Threshold: t.pct, + }) + } + return + } + } +} + +func (u *UsageTracker) pruneOldLocked() { + now := time.Now() + hourAgo := now.Add(-1 * time.Hour) + dayAgo := now.Add(-24 * time.Hour) + + prunedHourly := u.hourlyUsage[:0] + for _, e := range u.hourlyUsage { + if !e.Timestamp.Before(hourAgo) { + prunedHourly = append(prunedHourly, e) + } + } + u.hourlyUsage = prunedHourly + + prunedDaily := u.dailyUsage[:0] + for _, e := range u.dailyUsage { + if !e.Timestamp.Before(dayAgo) { + prunedDaily = append(prunedDaily, e) + } + } + u.dailyUsage = prunedDaily +} + +func (u *UsageTracker) hourlyTokensLocked() int { + total := 0 + for _, e := range u.hourlyUsage { + total += e.Tokens + } + return total +} + +func (u *UsageTracker) dailyTokensLocked() int { + total := 0 + for _, e := range u.dailyUsage { + total += e.Tokens + } + return total +} + +func (u *UsageTracker) dailyCostLocked() float64 { + total := 0.0 + for _, e := range u.dailyUsage { + total += e.CostUSD + } + 
// formatNumber renders n in base 10 with a comma between every group of three
// digits, e.g. 1234567 -> "1,234,567".
//
// A leading minus sign is kept outside the grouping so it is never counted as
// a digit (-123 -> "-123", -123456 -> "-123,456").
func formatNumber(n int) string {
	digits := strconv.Itoa(n)
	sign := ""
	if n < 0 {
		// Split the sign off the string rather than negating n, so the
		// minimum int value does not overflow.
		sign, digits = "-", digits[1:]
	}
	if len(digits) <= 3 {
		return sign + digits
	}

	out := make([]byte, 0, len(sign)+len(digits)+(len(digits)-1)/3)
	out = append(out, sign...)

	// The first group holds the 1-3 leading digits left over after grouping
	// the rest in threes.
	head := len(digits) % 3
	if head == 0 {
		head = 3
	}
	out = append(out, digits[:head]...)
	for i := head; i < len(digits); i += 3 {
		out = append(out, ',')
		out = append(out, digits[i:i+3]...)
	}
	return string(out)
}
golang.org/x/sys v0.45.0 // indirect - golang.org/x/text v0.37.0 // indirect + golang.org/x/tools v0.44.0 // indirect modernc.org/libc v1.72.5 // indirect modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.11.0 // indirect diff --git a/go.sum b/go.sum index fe8e8a9..257f8ba 100644 --- a/go.sum +++ b/go.sum @@ -1,70 +1,40 @@ -github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk= -github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= -github.com/GrayCodeAI/tok v0.1.0 h1:6lhxIGg1eDsnOtAuGOZf803aqj4CrPmVmTwKRw25Zio= -github.com/GrayCodeAI/tok v0.1.0/go.mod h1:oqA7HXbXuyrZ3+uJC+TKJWmYYPlyShaXGDQpftEJ9OE= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/danieljoos/wincred v1.2.3 h1:v7dZC2x32Ut3nEfRH+vhoZGvN72+dQ/snVXo/vMFLdQ= github.com/danieljoos/wincred v1.2.3/go.mod h1:6qqX0WNrS4RzPZ1tnroDzq9kY3fu1KwE7MRLQK4X0bs= -github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= -github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dlclark/regexp2/v2 v2.1.0 h1:jHXRmHRZGbuQzDZjMlCAXOvQb75iv3HyLDzXGj5H1AY= github.com/dlclark/regexp2/v2 v2.1.0/go.mod h1:Bz5TMy5d8fPK0ximH0Yi9KvsRHNnvXqUx9XG6a4wB+I= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= -github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= -github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= 
-github.com/fsnotify/fsnotify v1.10.1 h1:b0/UzAf9yR5rhf3RPm9gf3ehBPpf0oZKIjtpKrx59Ho= -github.com/fsnotify/fsnotify v1.10.1/go.mod h1:TLheqan6HD6GBK6PrDWyDPBaEV8LspOxvPSjC+bVfgo= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro= -github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ= github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 h1:BHT72Gu3keYf3ZEu2J0b1vyeLSOYI8bm5wbJM/8yDe8= -github.com/google/pprof v0.0.0-20250403155104-27863c87afa6/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= -github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 
-github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mattn/go-isatty v0.0.22 h1:j8l17JJ9i6VGPUFUYoTUKPSgKe/83EYU2zBC7YNKMw4= github.com/mattn/go-isatty v0.0.22/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/pelletier/go-toml/v2 v2.3.1 h1:MYEvvGnQjeNkRF1qUuGolNtNExTDwct51yp7olPtrEc= -github.com/pelletier/go-toml/v2 v2.3.1/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= -github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= -github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= -github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= -github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4= -github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI= -github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= -github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= -github.com/spf13/cast 
v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= -github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= -github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= -github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= -github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= github.com/stretchr/objx v0.5.3 h1:jmXUvGomnU1o3W/V5h2VEradbpJDwGrzugQQvL0POH4= github.com/stretchr/objx v0.5.3/go.mod h1:rDQraq+vQZU7Fde9LOZLr8Tax6zZvy4kuNKF+QYS+U0= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= -github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/tiktoken-go/tokenizer v0.8.0 h1:drHWno2Zx3eAm/hk/LmvBKXPpSImB7BRyh/ru4+3Q7Y= github.com/tiktoken-go/tokenizer v0.8.0/go.mod h1:pTmPz4r14MV3JkUGAmAcdLdYhSxN68MCjrP+EoxBdx0= github.com/zalando/go-keyring v0.2.8 h1:6sD/Ucpl7jNq10rM2pgqTs0sZ9V3qMrqfIIy5YPccHs= @@ -77,21 +47,14 @@ go.opentelemetry.io/otel/metric v1.44.0 h1:1w0gILTcHdr3YI+ixLyjemwrVnsMURbTZFrSY go.opentelemetry.io/otel/metric v1.44.0/go.mod h1:8O7hanEPBNgEMmybD3s2VBKcgWOCsA6tzHBPODAiquo= go.opentelemetry.io/otel/trace v1.44.0 h1:jxF5CsGYCe74MCRx2X4g7WsY/VBKRqqpNvXlX/6gtIk= go.opentelemetry.io/otel/trace v1.44.0/go.mod h1:oLl1jrMQAVo6v3GAggN+1VH9VIz9iUSvW53sW1Q8PIE= -go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= -go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM= golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU= golang.org/x/sync v0.20.0 
h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY= golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= -golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= -golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= golang.org/x/tools v0.44.0 h1:UP4ajHPIcuMjT1GqzDWRlalUEoY+uzoZKnhOjbIPD2c= golang.org/x/tools v0.44.0/go.mod h1:KA0AfVErSdxRZIsOVipbv3rQhVXTnlU6UhKxHd1seDI= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= modernc.org/cc/v4 v4.28.2 h1:3tQ0lf2ADtoby2EtSP+J7IE2SHwEJdP8ioR59wx7XpY= diff --git a/go.work b/go.work index b7f7ecf..4b21626 100644 --- a/go.work +++ b/go.work @@ -1,5 +1,3 @@ go 1.26.4 use . - -replace github.com/GrayCodeAI/tok => ../tok diff --git a/internal/api/rerank.go b/internal/api/rerank.go index cf9265e..ddfbff9 100644 --- a/internal/api/rerank.go +++ b/internal/api/rerank.go @@ -15,8 +15,7 @@ import ( // When no provider-backed reranker is configured (the default), eyrie falls // back to a zero-dependency lexical scorer (cosine similarity over // term-frequency vectors) so the endpoint is always functional. A -// provider-backed path (e.g. Cohere rerank) can be wired in later via the -// Reranker interface below; see the TODO on Server.reranker. +// provider-backed path (e.g. Cohere rerank) can be injected via Config.Reranker. // rerankRequest is the request body for POST /rerank. 
It mirrors the common // Cohere/LiteLLM rerank shape. @@ -40,14 +39,10 @@ type rerankResponse struct { Results []rerankResult `json:"results"` } -// Reranker is the provider-backed reranking interface. A real implementation -// (e.g. backed by Cohere's /rerank API or a cross-encoder model) can be -// injected so /rerank uses model-quality scores instead of the local lexical -// fallback. -// -// TODO(provider-rerank): wire a concrete Reranker into Server (via Config) and -// prefer it in handleRerank when non-nil. Until then s.reranker stays nil and -// the lexical fallback is used. +// Reranker is the provider-backed reranking interface. A concrete +// implementation (e.g. backed by Cohere's /rerank API or a cross-encoder model) +// can be injected so /rerank uses model-quality scores instead of the local +// lexical fallback. type Reranker interface { // Rerank returns relevance scores in [0,1] for each document, in the same // order as the input documents. @@ -78,6 +73,10 @@ func (s *Server) handleRerank(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } + if len(scores) != len(req.Documents) { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "reranker returned invalid score count"}) + return + } } else { // Zero-dependency lexical fallback. 
scores = lexicalRerankScores(req.Query, req.Documents) diff --git a/internal/api/rerank_test.go b/internal/api/rerank_test.go index 7b51288..94d9751 100644 --- a/internal/api/rerank_test.go +++ b/internal/api/rerank_test.go @@ -2,13 +2,48 @@ package api import ( + "context" "encoding/json" + "errors" "net/http" "net/http/httptest" + "path/filepath" + "reflect" "strings" "testing" + + "github.com/GrayCodeAI/eyrie/storage" ) +type mockReranker struct { + scores []float64 + err error + model string + query string + documents []string +} + +func (m *mockReranker) Rerank(_ context.Context, model, query string, documents []string) ([]float64, error) { + m.model = model + m.query = query + m.documents = append([]string(nil), documents...) + if m.err != nil { + return nil, m.err + } + return append([]float64(nil), m.scores...), nil +} + +func testServerWithReranker(t *testing.T, r Reranker) *httptest.Server { + t.Helper() + store, err := storage.Open(filepath.Join(t.TempDir(), "test.db")) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = store.Close() }) + srv := NewServer(Config{Store: store, Provider: &mockProv{}, Reranker: r}) + return httptest.NewServer(srv) +} + func TestRerankLexicalFallback(t *testing.T) { ts := testServer(t) defer ts.Close() @@ -46,6 +81,70 @@ func TestRerankLexicalFallback(t *testing.T) { } } +func TestRerankUsesConfiguredReranker(t *testing.T) { + reranker := &mockReranker{scores: []float64{0.1, 0.9, 0.4}} + ts := testServerWithReranker(t, reranker) + defer ts.Close() + + body := `{"model":"rerank-test","query":"cats","documents":["cats purr","dogs bark","birds fly"]}` + resp, err := http.Post(ts.URL+"/rerank", "application/json", strings.NewReader(body)) + if err != nil { + t.Fatal(err) + } + defer drainBody(t, resp) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + + var out rerankResponse + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + t.Fatal(err) + } + if 
out.Model != "rerank-test" { + t.Fatalf("model = %q", out.Model) + } + gotOrder := []int{out.Results[0].Index, out.Results[1].Index, out.Results[2].Index} + if want := []int{1, 2, 0}; !reflect.DeepEqual(gotOrder, want) { + t.Fatalf("order = %v, want %v", gotOrder, want) + } + if reranker.model != "rerank-test" || reranker.query != "cats" { + t.Fatalf("reranker saw model/query = %q/%q", reranker.model, reranker.query) + } + if want := []string{"cats purr", "dogs bark", "birds fly"}; !reflect.DeepEqual(reranker.documents, want) { + t.Fatalf("reranker documents = %v, want %v", reranker.documents, want) + } +} + +func TestRerankConfiguredRerankerError(t *testing.T) { + ts := testServerWithReranker(t, &mockReranker{err: errors.New("rerank failed")}) + defer ts.Close() + + body := `{"query":"cats","documents":["cats purr"]}` + resp, err := http.Post(ts.URL+"/rerank", "application/json", strings.NewReader(body)) + if err != nil { + t.Fatal(err) + } + defer drainBody(t, resp) + if resp.StatusCode != http.StatusInternalServerError { + t.Fatalf("expected 500, got %d", resp.StatusCode) + } +} + +func TestRerankConfiguredRerankerScoreLengthMismatch(t *testing.T) { + ts := testServerWithReranker(t, &mockReranker{scores: []float64{0.2}}) + defer ts.Close() + + body := `{"query":"cats","documents":["cats purr","dogs bark"]}` + resp, err := http.Post(ts.URL+"/rerank", "application/json", strings.NewReader(body)) + if err != nil { + t.Fatal(err) + } + defer drainBody(t, resp) + if resp.StatusCode != http.StatusInternalServerError { + t.Fatalf("expected 500, got %d", resp.StatusCode) + } +} + func TestRerankTopN(t *testing.T) { ts := testServer(t) defer ts.Close() diff --git a/internal/api/server.go b/internal/api/server.go index 95f2297..008b0e7 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -36,6 +36,7 @@ type Config struct { Analytics storage.AnalyticsStore // optional: enables /api/usage, /api/costs Provider client.Provider HealthChecker *eyrie.HealthChecker 
// optional: enables /api/health/providers + Reranker Reranker // optional: provider-backed /rerank; nil => lexical fallback APIKey string Port int // VirtualKeyResolver optionally maps an inbound bearer/API-key token to a @@ -52,6 +53,7 @@ func NewServer(cfg Config) *Server { store: cfg.Store, analytics: cfg.Analytics, healthChecker: cfg.HealthChecker, + reranker: cfg.Reranker, apiKey: cfg.APIKey, virtualKeyFor: cfg.VirtualKeyResolver, mux: http.NewServeMux(), diff --git a/scripts/check-ecosystem-boundaries.sh b/scripts/check-ecosystem-boundaries.sh new file mode 100644 index 0000000..57f81ba --- /dev/null +++ b/scripts/check-ecosystem-boundaries.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +if command -v rg >/dev/null 2>&1; then + violations="$(rg -n 'github\.com/GrayCodeAI/hawk/(internal/|shared/types)' --glob '*.go' . || true)" +else + violations="$(grep -rn --include='*.go' -E 'github\.com/GrayCodeAI/hawk/(internal/|shared/types)' . || true)" +fi + +if [[ -n "${violations}" ]]; then + echo "forbidden Hawk imports found:" + echo "${violations}" + echo + echo "support repos must use hawk-core-contracts or local contracts, not hawk/internal or removed hawk/shared/types" + exit 1 +fi + +echo "ecosystem boundary guard passed"