From 7213462bb27bcf79c19b45d634121a5ddedff6f9 Mon Sep 17 00:00:00 2001 From: Eason Liang Date: Wed, 1 Jul 2026 13:55:54 +0800 Subject: [PATCH 1/9] feat(api): add abort signal support to all completePrompt --- src/api/index.ts | 11 +- .../__tests__/anthropic-vertex.spec.ts | 91 ++++++++-- src/api/providers/__tests__/anthropic.spec.ts | 99 ++++++++++- ...openai-compatible-provider-timeout.spec.ts | 54 ++++++ src/api/providers/__tests__/bedrock.spec.ts | 101 +++++++++++ .../__tests__/complete-prompt-options.spec.ts | 29 ++++ src/api/providers/__tests__/deepseek.spec.ts | 41 +++++ src/api/providers/__tests__/fireworks.spec.ts | 35 ++++ .../__tests__/gemini-handler.spec.ts | 38 +++++ src/api/providers/__tests__/gemini.spec.ts | 57 +++++++ src/api/providers/__tests__/lite-llm.spec.ts | 37 ++++ .../__tests__/lmstudio-native-tools.spec.ts | 51 ++++++ src/api/providers/__tests__/lmstudio.spec.ts | 131 +++++++++++++- src/api/providers/__tests__/mimo.spec.ts | 31 ++++ src/api/providers/__tests__/minimax.spec.ts | 57 +++++++ src/api/providers/__tests__/mistral.spec.ts | 65 ++++++- src/api/providers/__tests__/moonshot.spec.ts | 67 ++++++++ .../providers/__tests__/native-ollama.spec.ts | 49 ++++++ .../openai-codex-native-tool-calls.spec.ts | 82 +++++++++ .../providers/__tests__/openai-native.spec.ts | 101 +++++++++++ src/api/providers/__tests__/openai.spec.ts | 39 +++++ .../providers/__tests__/opencode-go.spec.ts | 126 +++++++++++++- .../providers/__tests__/openrouter.spec.ts | 43 +++++ src/api/providers/__tests__/poe.spec.ts | 116 +++++++++++++ .../__tests__/qwen-code-native-tools.spec.ts | 52 ++++++ src/api/providers/__tests__/requesty.spec.ts | 59 +++++-- src/api/providers/__tests__/sambanova.spec.ts | 25 +++ src/api/providers/__tests__/unbound.spec.ts | 53 ++++++ .../__tests__/vercel-ai-gateway.spec.ts | 73 ++++++++ src/api/providers/__tests__/vertex.spec.ts | 29 ++++ src/api/providers/__tests__/vscode-lm.spec.ts | 160 +++++++++++++++++- src/api/providers/__tests__/xai.spec.ts | 51 ++++++ src/api/providers/__tests__/zai.spec.ts | 25 +++ .../providers/__tests__/zoo-gateway.spec.ts | 38 +++++ src/api/providers/anthropic-vertex.ts | 17 +- src/api/providers/anthropic.ts | 32 ++-- .../base-openai-compatible-provider.ts | 15 +- src/api/providers/bedrock.ts | 40 ++++- src/api/providers/fake-ai.ts | 13 +- src/api/providers/gemini.ts | 21 ++- src/api/providers/lite-llm.ts | 15 +- src/api/providers/lm-studio.ts | 15 +- src/api/providers/minimax.ts | 30 +++- src/api/providers/mistral.ts | 26 ++- src/api/providers/native-ollama.ts | 5 +- src/api/providers/openai-codex.ts | 39 ++++- src/api/providers/openai-compatible.ts | 52 +++++- src/api/providers/openai-native.ts | 18 +- src/api/providers/openai.ts | 13 +- src/api/providers/opencode-go.ts | 55 ++++-- src/api/providers/openrouter.ts | 15 +- src/api/providers/poe.ts | 31 +++- src/api/providers/qwen-code.ts | 18 +- src/api/providers/requesty.ts | 14 +- src/api/providers/unbound.ts | 14 +- src/api/providers/vercel-ai-gateway.ts | 17 +- src/api/providers/vscode-lm.ts | 42 ++++- src/api/providers/xai.ts | 26 ++- src/api/providers/zoo-gateway.ts | 15 +- src/core/task/__tests__/Task.spec.ts | 97 ++++++++++- src/utils/__tests__/enhance-prompt.spec.ts | 4 +- src/utils/single-completion-handler.ts | 10 +- 62 files changed, 2645 insertions(+), 180 deletions(-) create mode 100644 src/api/providers/__tests__/complete-prompt-options.spec.ts diff --git a/src/api/index.ts b/src/api/index.ts index 9e4ba3bfb5..228b207179 100644 --- a/src/api/index.ts +++ b/src/api/index.ts @@ -40,8 +40,17 @@ import { } from "./providers" import { NativeOllamaHandler } from "./providers/native-ollama" +/** + * Options for completePrompt — unified with ApiHandlerCreateMessageMetadata. + * Uses abortSignal (not signal) to match the metadata pattern used in stream path. + */ +export interface CompletePromptOptions extends Pick { + /** Optional timeout override (ms) — falls back to provider default if omitted */ + timeoutMs?: number +} + export interface SingleCompletionHandler { - completePrompt(prompt: string): Promise + completePrompt(prompt: string, metadata?: CompletePromptOptions): Promise } export interface ApiHandlerCreateMessageMetadata { diff --git a/src/api/providers/__tests__/anthropic-vertex.spec.ts b/src/api/providers/__tests__/anthropic-vertex.spec.ts index 6b56c5af98..16dd0e8235 100644 --- a/src/api/providers/__tests__/anthropic-vertex.spec.ts +++ b/src/api/providers/__tests__/anthropic-vertex.spec.ts @@ -834,18 +834,22 @@ describe("VertexHandler", () => { const result = await handler.completePrompt("Test prompt") expect(result).toBe("Test response") - expect(handler["client"].messages.create).toHaveBeenCalledWith({ - model: "claude-3-5-sonnet-v2@20241022", - max_tokens: 8192, - temperature: 0, - messages: [ - { - role: "user", - content: [{ type: "text", text: "Test prompt", cache_control: { type: "ephemeral" } }], - }, - ], - stream: false, - }) + expect(handler["client"].messages.create).toHaveBeenCalledWith( + { + model: "claude-3-5-sonnet-v2@20241022", + max_tokens: 8192, + temperature: 0, + messages: [ + { + role: "user", + content: [{ type: "text", text: "Test prompt", cache_control: { type: "ephemeral" } }], + }, + ], + stream: false, + thinking: undefined, + }, + undefined, + ) }) it("should handle API errors for Claude", async () => { @@ -895,6 +899,69 @@ describe("VertexHandler", () => { const result = await handler.completePrompt("Test prompt") expect(result).toBe("") }) + + it("should pass abort signal through to client", async () => { + handler = new AnthropicVertexHandler({ + apiModelId: "claude-3-5-sonnet-v2@20241022", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + }) + + const controller = new AbortController() + const mockCreate = vitest.fn().mockResolvedValue({ + content: [{ type: "text", text: "response" }], + }) + ;(handler["client"].messages as any).create = mockCreate + + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ model: expect.any(String) }), { + signal: controller.signal, + }) + }) + + it("should work without options (backward compatible)", async () => { + handler = new AnthropicVertexHandler({ + apiModelId: "claude-3-5-sonnet-v2@20241022", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + }) + + const mockCreate = vitest.fn().mockResolvedValue({ + content: [{ type: "text", text: "response" }], + }) + ;(handler["client"].messages as any).create = mockCreate + + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ model: expect.any(String) }), undefined) + }) + + it("completePrompt should pass signal through to client", async () => { + const controller = new AbortController() + const mockCreate = vitest.fn().mockResolvedValue({ + content: [{ type: "text", text: "response" }], + }) + ;(handler["client"].messages as any).create = mockCreate + + await handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 5000 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ signal: controller.signal, timeout: 5000 }), + ) + }) + + it("completePrompt should pass timeoutMs when provided", async () => { + const mockCreate = vitest.fn().mockResolvedValue({ + content: [{ type: "text", text: "response" }], + }) + ;(handler["client"].messages as any).create = mockCreate + + await handler.completePrompt("test prompt", { timeoutMs: 3000 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ timeout: 3000 }), + ) + }) }) describe("getModel", () => { diff --git a/src/api/providers/__tests__/anthropic.spec.ts b/src/api/providers/__tests__/anthropic.spec.ts index 2b944b7db5..d89e2649ab 100644 --- a/src/api/providers/__tests__/anthropic.spec.ts +++ b/src/api/providers/__tests__/anthropic.spec.ts @@ -461,14 +461,17 @@ describe("AnthropicHandler", () => { it("should complete prompt successfully", async () => { const result = await handler.completePrompt("Test prompt") expect(result).toBe("Test response") - expect(mockCreate).toHaveBeenCalledWith({ - model: mockOptions.apiModelId, - messages: [{ role: "user", content: "Test prompt" }], - max_tokens: 8192, - temperature: 0, - thinking: undefined, - stream: false, - }) + expect(mockCreate).toHaveBeenCalledWith( + { + model: mockOptions.apiModelId, + messages: [{ role: "user", content: "Test prompt" }], + max_tokens: 8192, + temperature: 0, + thinking: undefined, + stream: false, + }, + undefined, + ) }) it("should handle API errors", async () => { @@ -491,6 +494,86 @@ describe("AnthropicHandler", () => { const result = await handler.completePrompt("Test prompt") expect(result).toBe("") }) + + it("should pass abort signal through to client", async () => { + const controller = new AbortController() + mockCreate.mockResolvedValueOnce({ content: [{ type: "text", text: "response" }] }) + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(mockCreate).toHaveBeenCalledWith( + { + model: mockOptions.apiModelId, + messages: [{ role: "user", content: "test prompt" }], + max_tokens: 8192, + temperature: 0, + thinking: undefined, + stream: false, + }, + { signal: controller.signal }, + ) + }) + + it("should work without options (backward compatible)", async () => { + mockCreate.mockResolvedValueOnce({ content: [{ type: "text", text: "response" }] }) + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + expect(mockCreate).toHaveBeenCalledWith( + { + model: mockOptions.apiModelId, + messages: [{ role: "user", content: "test prompt" }], + max_tokens: 8192, + temperature: 0, + thinking: undefined, + stream: false, + }, + undefined, + ) + }) + + it("should merge signal and timeout together", async () => { + const controller = new AbortController() + mockCreate.mockResolvedValueOnce({ content: [{ type: "text", text: "response" }] }) + await handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 10000 }) + expect(mockCreate).toHaveBeenCalledWith( + { + model: mockOptions.apiModelId, + messages: [{ role: "user", content: "test prompt" }], + max_tokens: 8192, + temperature: 0, + thinking: undefined, + stream: false, + }, + expect.objectContaining({ signal: controller.signal, timeout: 10000 }), + ) + }) + + it("should pass timeoutMs through to client alongside abortSignal", async () => { + const controller = new AbortController() + mockCreate.mockResolvedValueOnce({ content: [{ type: "text", text: "response" }] }) + await handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 5000 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: mockOptions.apiModelId }), + expect.objectContaining({ signal: controller.signal, timeout: 5000 }), + ) + }) + + it("should pass the same signal instance", async () => { + const controller = new AbortController() + mockCreate.mockResolvedValueOnce({ content: [{ type: "text", text: "response" }] }) + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(mockCreate).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ signal: controller.signal }), + ) + // Verify it's the exact same instance, not just equal + const callOptions = mockCreate.mock.calls[0][1] + expect(callOptions?.signal).toBe(controller.signal) + }) + + it("should not include signal-related options when not provided", async () => { + mockCreate.mockResolvedValueOnce({ content: [{ type: "text", text: "response" }] }) + await handler.completePrompt("test prompt") + expect(mockCreate).toHaveBeenCalledWith(expect.any(Object), undefined) + }) }) describe("getModel", () => { diff --git a/src/api/providers/__tests__/base-openai-compatible-provider-timeout.spec.ts b/src/api/providers/__tests__/base-openai-compatible-provider-timeout.spec.ts index 6b0c0dca31..bea8b496d7 100644 --- a/src/api/providers/__tests__/base-openai-compatible-provider-timeout.spec.ts +++ b/src/api/providers/__tests__/base-openai-compatible-provider-timeout.spec.ts @@ -116,4 +116,58 @@ describe("BaseOpenAiCompatibleProvider Timeout Configuration", () => { }), ) }) + + describe("completePrompt", () => { + it("should pass timeout through to client when both signal and timeoutMs provided", async () => { + const handler = new TestOpenAiCompatibleProvider("test-api-key") + const controller = new AbortController() + const mockCreate = vitest.fn().mockResolvedValue({ + choices: [{ message: { content: "response" } }], + }) + handler["client"].chat.completions.create = mockCreate + + await handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 5000 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: "test-model" }), + expect.objectContaining({ signal: expect.any(AbortSignal), timeout: 5000 }), + ) + }) + + it("should pass only timeoutMs when no signal provided", async () => { + const handler = new TestOpenAiCompatibleProvider("test-api-key") + const mockCreate = vitest.fn().mockResolvedValue({ + choices: [{ message: { content: "response" } }], + }) + handler["client"].chat.completions.create = mockCreate + + await handler.completePrompt("test prompt", { timeoutMs: 3000 }) + expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ model: "test-model" }), { timeout: 3000 }) + }) + + it("should handle timeoutMs=0 as valid value (!== undefined check)", async () => { + const handler = new TestOpenAiCompatibleProvider("test-api-key") + const mockCreate = vitest.fn().mockResolvedValue({ + choices: [{ message: { content: "response" } }], + }) + handler["client"].chat.completions.create = mockCreate + + await handler.completePrompt("test prompt", { timeoutMs: 0 }) + expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ model: "test-model" }), { timeout: 0 }) + }) + + it("should work without options (backward compatible)", async () => { + const handler = new TestOpenAiCompatibleProvider("test-api-key") + const mockCreate = vitest.fn().mockResolvedValue({ + choices: [{ message: { content: "response" } }], + }) + handler["client"].chat.completions.create = mockCreate + + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: "test-model" }), + {}, // empty object when no options + ) + }) + }) }) diff --git a/src/api/providers/__tests__/bedrock.spec.ts b/src/api/providers/__tests__/bedrock.spec.ts index e2b17e4e87..496f49b15d 100644 --- a/src/api/providers/__tests__/bedrock.spec.ts +++ b/src/api/providers/__tests__/bedrock.spec.ts @@ -1639,6 +1639,107 @@ describe("AwsBedrockHandler", () => { expect(isAdaptiveThinkingModel("anthropic.claude-3-5-sonnet-20241022-v2:0")).toBe(false) expect(isAdaptiveThinkingModel("amazon.nova-lite-v1:0")).toBe(false) }) + + it("should pass abort signal through to client.send", async () => { + const mockSend = vi.fn() + + const handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + }) + + // Set up the mock on the handler's client instance directly + const clientInstance = (handler as any).client + expect(clientInstance).toBeDefined() + clientInstance.send = mockSend + + const controller = new AbortController() + mockSend.mockResolvedValueOnce({ + output: { message: { content: [{ type: "text", text: "response" }] }, stopReason: null }, + }) + + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + + expect(mockSend).toHaveBeenCalledWith(expect.any(Object), { abortSignal: controller.signal }) + }) + + it("should work without options (backward compatible)", async () => { + const mockSend = vi.fn() + + const handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + }) + + const clientInstance = (handler as any).client + expect(clientInstance).toBeDefined() + clientInstance.send = mockSend + + mockSend.mockResolvedValueOnce({ + output: { message: { content: [{ type: "text", text: "response" }] }, stopReason: null }, + }) + + const result = await handler.completePrompt("test prompt") + + expect(result).toBe("response") + expect(mockSend).toHaveBeenCalledWith(expect.any(Object), undefined) + }) + + it("completePrompt should pass timeoutMs through to client", async () => { + const mockSend = vi.fn() + + const handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + }) + + const clientInstance = (handler as any).client + expect(clientInstance).toBeDefined() + clientInstance.send = mockSend + + mockSend.mockResolvedValueOnce({ + output: { message: { content: [{ type: "text", text: "response" }] }, stopReason: null }, + }) + + await handler.completePrompt("test prompt", { timeoutMs: 5000 }) + + expect(mockSend).toHaveBeenCalled() + // Verify the second argument (sendOptions) contains an abortSignal derived from timeoutMs + const sendOptions = mockSend.mock.calls[0][1] + expect(sendOptions).toBeDefined() + expect(sendOptions?.abortSignal).toBeDefined() + }) + + it("completePrompt should merge abortSignal and timeoutMs", async () => { + const mockSend = vi.fn() + + const handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + }) + + const clientInstance = (handler as any).client + clientInstance.send = mockSend + + mockSend.mockResolvedValueOnce({ + output: { message: { content: [{ type: "text", text: "response" }] }, stopReason: null }, + }) + + const controller = new AbortController() + await handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 5000 }) + + expect(mockSend).toHaveBeenCalled() + const sendOptions = mockSend.mock.calls[0][1] + expect(sendOptions?.abortSignal).toBeDefined() + }) }) }) }) diff --git a/src/api/providers/__tests__/complete-prompt-options.spec.ts b/src/api/providers/__tests__/complete-prompt-options.spec.ts new file mode 100644 index 0000000000..f9925cd119 --- /dev/null +++ b/src/api/providers/__tests__/complete-prompt-options.spec.ts @@ -0,0 +1,29 @@ +import { describe, it, expect } from "vitest" + +import type { CompletePromptOptions } from "../../index" + +describe("CompletePromptOptions", () => { + it("should allow abortSignal property", () => { + const controller = new AbortController() + const options: CompletePromptOptions = { abortSignal: controller.signal } + expect(options.abortSignal).toBe(controller.signal) + }) + + it("should allow timeoutMs property", () => { + const options: CompletePromptOptions = { timeoutMs: 5000 } + expect(options.timeoutMs).toBe(5000) + }) + + it("should allow both abortSignal and timeoutMs together", () => { + const controller = new AbortController() + const options: CompletePromptOptions = { abortSignal: controller.signal, timeoutMs: 10000 } + expect(options.abortSignal).toBe(controller.signal) + expect(options.timeoutMs).toBe(10000) + }) + + it("should allow empty options object", () => { + const options: CompletePromptOptions = {} + expect(options.abortSignal).toBeUndefined() + expect(options.timeoutMs).toBeUndefined() + }) +}) diff --git a/src/api/providers/__tests__/deepseek.spec.ts b/src/api/providers/__tests__/deepseek.spec.ts index 2f0482eeef..60d68d495e 100644 --- a/src/api/providers/__tests__/deepseek.spec.ts +++ b/src/api/providers/__tests__/deepseek.spec.ts @@ -710,4 +710,45 @@ describe("DeepSeekHandler", () => { expect(toolCallChunks[0].name).toBe("get_weather") }) }) + + describe("completePrompt", () => { + it("should complete prompt successfully", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + }) + + it("should pass abort signal through to client", async () => { + const controller = new AbortController() + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ signal: controller.signal }), + ) + }) + + it("should pass timeout through to client", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + await handler.completePrompt("test prompt", { timeoutMs: 5000 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ timeout: 5000 }), + ) + }) + + it("should work without options (backward compatible)", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + }) + }) }) diff --git a/src/api/providers/__tests__/fireworks.spec.ts b/src/api/providers/__tests__/fireworks.spec.ts index 33d50ab7b2..21efd89c06 100644 --- a/src/api/providers/__tests__/fireworks.spec.ts +++ b/src/api/providers/__tests__/fireworks.spec.ts @@ -609,6 +609,41 @@ describe("FireworksHandler", () => { expect(result).toBe("") }) + it("completePrompt should pass abort signal through to client", async () => { + const controller = new AbortController() + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ signal: controller.signal }), + ) + }) + + it("completePrompt should pass timeout through to client", async () => { + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) + await handler.completePrompt("test prompt", { timeoutMs: 5000 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ timeout: 5000 }), + ) + }) + + it("completePrompt should merge signal and timeoutMs together", async () => { + const controller = new AbortController() + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) + await handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 10000 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ signal: controller.signal, timeout: 10000 }), + ) + }) + + it("completePrompt should work without options (backward compatible)", async () => { + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + }) + it("createMessage should handle stream with multiple chunks", async () => { mockCreate.mockImplementationOnce(async () => ({ [Symbol.asyncIterator]: async function* () { diff --git a/src/api/providers/__tests__/gemini-handler.spec.ts b/src/api/providers/__tests__/gemini-handler.spec.ts index 110f60289c..364f62e23d 100644 --- a/src/api/providers/__tests__/gemini-handler.spec.ts +++ b/src/api/providers/__tests__/gemini-handler.spec.ts @@ -55,6 +55,44 @@ describe("GeminiHandler backend support", () => { expect(promptConfig.tools).toBeUndefined() }) + it("completePrompt should pass abort signal through to client via httpOptions", async () => { + const options = { + apiProvider: "gemini", + enableUrlContext: false, + enableGrounding: false, + } as ApiHandlerOptions + const handler = new GeminiHandler(options) + + const controller = new AbortController() + const stub = vi.fn().mockResolvedValue({ text: "response" }) + handler["client"].models.generateContent = stub + + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + + expect(stub).toHaveBeenCalledWith( + expect.objectContaining({ + config: expect.objectContaining({ + abortSignal: controller.signal, + }), + }), + ) + }) + + it("completePrompt should work without options (backward compatible)", async () => { + const options = { + apiProvider: "gemini", + enableUrlContext: false, + enableGrounding: false, + } as ApiHandlerOptions + const handler = new GeminiHandler(options) + + const stub = vi.fn().mockResolvedValue({ text: "response" }) + handler["client"].models.generateContent = stub + + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + }) + describe("error scenarios", () => { it("should handle grounding metadata extraction failure gracefully", async () => { const options = { diff --git a/src/api/providers/__tests__/gemini.spec.ts b/src/api/providers/__tests__/gemini.spec.ts index 76b0b3abbe..f14293bfb9 100644 --- a/src/api/providers/__tests__/gemini.spec.ts +++ b/src/api/providers/__tests__/gemini.spec.ts @@ -365,6 +365,63 @@ describe("GeminiHandler", () => { const result = await handler.completePrompt("Test prompt") expect(result).toBe("") }) + + it("should pass abort signal through to client via config.abortSignal", async () => { + const controller = new AbortController() + ;(handler["client"].models.generateContent as any).mockResolvedValue({ text: "response" }) + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(handler["client"].models.generateContent).toHaveBeenCalledWith({ + model: GEMINI_MODEL_NAME, + contents: [{ role: "user", parts: [{ text: "test prompt" }] }], + config: { + abortSignal: controller.signal, + httpOptions: undefined, + temperature: 1, + }, + }) + }) + + it("should work without options (backward compatible)", async () => { + ;(handler["client"].models.generateContent as any).mockResolvedValue({ text: "response" }) + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + expect(handler["client"].models.generateContent).toHaveBeenCalledWith({ + model: GEMINI_MODEL_NAME, + contents: [{ role: "user", parts: [{ text: "test prompt" }] }], + config: { + httpOptions: undefined, + temperature: 1, + }, + }) + }) + + it("should pass timeoutMs through to client via httpOptions with abortSignal on config", async () => { + const controller = new AbortController() + ;(handler["client"].models.generateContent as any).mockResolvedValue({ text: "response" }) + await handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 10000 }) + expect(handler["client"].models.generateContent).toHaveBeenCalledWith({ + model: GEMINI_MODEL_NAME, + contents: [{ role: "user", parts: [{ text: "test prompt" }] }], + config: { + abortSignal: controller.signal, + httpOptions: { timeout: 10000 }, + temperature: 1, + }, + }) + }) + + it("should pass only timeoutMs when no signal is provided", async () => { + ;(handler["client"].models.generateContent as any).mockResolvedValue({ text: "response" }) + await handler.completePrompt("test prompt", { timeoutMs: 5000 }) + expect(handler["client"].models.generateContent).toHaveBeenCalledWith({ + model: GEMINI_MODEL_NAME, + contents: [{ role: "user", parts: [{ text: "test prompt" }] }], + config: { + httpOptions: { timeout: 5000 }, + temperature: 1, + }, + }) + }) }) describe("getModel", () => { diff --git a/src/api/providers/__tests__/lite-llm.spec.ts b/src/api/providers/__tests__/lite-llm.spec.ts index ab2f261058..1c0ddbda52 100644 --- a/src/api/providers/__tests__/lite-llm.spec.ts +++ b/src/api/providers/__tests__/lite-llm.spec.ts @@ -1180,4 +1180,41 @@ describe("LiteLLMHandler", () => { expect(requestHeaders).not.toHaveProperty("X-Zoo-Session-ID") }) }) + + describe("completePrompt", () => { + it("should pass abort signal through to client", async () => { + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) + const controller = new AbortController() + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ signal: controller.signal }), + ) + }) + + it("should pass timeout through to client", async () => { + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) + await handler.completePrompt("test prompt", { timeoutMs: 5000 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ timeout: 5000 }), + ) + }) + + it("should merge signal and timeoutMs together", async () => { + const controller = new AbortController() + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) + await handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 10000 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ signal: controller.signal, timeout: 10000 }), + ) + }) + + it("should work without options (backward compatible)", async () => { + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + }) + }) }) diff --git a/src/api/providers/__tests__/lmstudio-native-tools.spec.ts b/src/api/providers/__tests__/lmstudio-native-tools.spec.ts index c6870e80ef..5bce358d29 100644 --- a/src/api/providers/__tests__/lmstudio-native-tools.spec.ts +++ b/src/api/providers/__tests__/lmstudio-native-tools.spec.ts @@ -376,4 +376,55 @@ describe("LmStudioHandler Native Tools", () => { expect(endChunks).toHaveLength(1) }) }) + + describe("completePrompt", () => { + it("should complete prompt successfully", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "This is a test response" } }], + }) + + const result = await handler.completePrompt("test prompt") + expect(result).toBe("This is a test response") + }) + + it("should handle errors in completePrompt", async () => { + mockCreate.mockRejectedValueOnce(new Error("API error")) + + await expect(handler.completePrompt("test prompt")).rejects.toThrow() + }) + + it("completePrompt should pass abort signal through to client", async () => { + const controller = new AbortController() + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ signal: controller.signal }), + ) + }) + + it("completePrompt should pass timeout through to client", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + + await handler.completePrompt("test prompt", { timeoutMs: 5000 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ timeout: 5000 }), + ) + }) + + it("completePrompt should work without options (backward compatible)", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + }) + }) }) diff --git a/src/api/providers/__tests__/lmstudio.spec.ts b/src/api/providers/__tests__/lmstudio.spec.ts index c6ebd8a6e9..56511f5c28 100644 --- a/src/api/providers/__tests__/lmstudio.spec.ts +++ b/src/api/providers/__tests__/lmstudio.spec.ts @@ -133,12 +133,15 @@ describe("LmStudioHandler", () => { it("should complete prompt successfully", async () => { const result = await handler.completePrompt("Test prompt") expect(result).toBe("Test response") - expect(mockCreate).toHaveBeenCalledWith({ - model: mockOptions.lmStudioModelId, - messages: [{ role: "user", content: "Test prompt" }], - temperature: 0, - stream: false, - }) + expect(mockCreate).toHaveBeenCalledWith( + { + model: mockOptions.lmStudioModelId, + messages: [{ role: "user", content: "Test prompt" }], + temperature: 0, + stream: false, + }, + {}, + ) }) it("should handle API errors", async () => { @@ -155,6 +158,60 @@ describe("LmStudioHandler", () => { const result = await handler.completePrompt("Test prompt") expect(result).toBe("") }) + + it("should pass abort signal through to client", async () => { + const controller = new AbortController() + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ signal: controller.signal }), + ) + }) + + it("should pass timeout through to client", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + await handler.completePrompt("test prompt", { timeoutMs: 5000 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ timeout: 5000 }), + ) + }) + + it("should pass timeoutMs=0 through to client", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + await handler.completePrompt("test prompt", { timeoutMs: 0 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ timeout: 0 }), + ) + }) + + it("should merge signal and timeoutMs together", async () => { + const controller = new AbortController() + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + await handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 10000 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ signal: controller.signal, timeout: 10000 }), + ) + }) + + it("should work without options (backward compatible)", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + }) }) describe("getModel", () => { @@ -166,4 +223,66 @@ describe("LmStudioHandler", () => { expect(modelInfo.info.contextWindow).toBe(128_000) }) }) + + describe("speculative decoding", () => { + it("should include draft_model in completePrompt when speculative decoding is enabled", async () => { + const handlerWithSpeculative = new LmStudioHandler({ + apiModelId: "local-model", + lmStudioModelId: "local-model", + lmStudioBaseUrl: "http://localhost:1234", + lmStudioSpeculativeDecodingEnabled: true, + lmStudioDraftModelId: "draft-model", + }) + + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + + await handlerWithSpeculative.completePrompt("test prompt") + + expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ draft_model: "draft-model" }), {}) + }) + + it("should not include draft_model when speculative decoding is disabled", async () => { + const handlerWithoutSpeculative = new LmStudioHandler({ + apiModelId: "local-model", + lmStudioModelId: "local-model", + lmStudioBaseUrl: "http://localhost:1234", + lmStudioSpeculativeDecodingEnabled: false, + lmStudioDraftModelId: "draft-model", + }) + + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + + await handlerWithoutSpeculative.completePrompt("test prompt") + + // Verify draft_model is NOT in the params when speculative decoding is disabled + const calledParams = mockCreate.mock.calls[0][0] as Record + expect(calledParams.model).toBe("local-model") + expect(calledParams).not.toHaveProperty("draft_model") + }) + + it("should not include draft_model when draft model id is empty", async () => { + const handlerEmptyDraft = new LmStudioHandler({ + apiModelId: "local-model", + lmStudioModelId: "local-model", + lmStudioBaseUrl: "http://localhost:1234", + lmStudioSpeculativeDecodingEnabled: true, + lmStudioDraftModelId: "", + }) + + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + + await handlerEmptyDraft.completePrompt("test prompt") + + // Verify draft_model is NOT in the params when draft model id is empty + const calledParamsEmpty = mockCreate.mock.calls[0][0] as Record + expect(calledParamsEmpty.model).toBe("local-model") + expect(calledParamsEmpty).not.toHaveProperty("draft_model") + }) + }) }) diff --git a/src/api/providers/__tests__/mimo.spec.ts b/src/api/providers/__tests__/mimo.spec.ts index 7da1c84463..c8ddc4d73c 100644 --- a/src/api/providers/__tests__/mimo.spec.ts +++ b/src/api/providers/__tests__/mimo.spec.ts @@ -998,5 +998,36 @@ describe("MimoHandler", () => { const params = mockCreate.mock.calls[0][0] expect(params.model).toBe("mimo-v2.5") }) + + it("should pass abort signal through to client", async () => { + const controller = new AbortController() + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ signal: controller.signal }), + ) + }) + + it("should pass timeout through to client", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + await handler.completePrompt("test prompt", { timeoutMs: 5000 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ timeout: 5000 }), + ) + }) + + it("should work without options (backward compatible)", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + }) }) }) diff --git a/src/api/providers/__tests__/minimax.spec.ts b/src/api/providers/__tests__/minimax.spec.ts index d87ae1190b..ad8eb4d529 100644 --- a/src/api/providers/__tests__/minimax.spec.ts +++ b/src/api/providers/__tests__/minimax.spec.ts @@ -220,6 +220,63 @@ describe("MiniMaxHandler", () => { await expect(handler.completePrompt("test prompt")).rejects.toThrow() }) + it("should pass abort signal through to client", async () => { + const controller = new AbortController() + mockCreate.mockResolvedValueOnce({ + content: [{ type: "text", text: "response" }], + }) + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + { signal: controller.signal }, // second arg (options) + ) + }) + + it("should work without options (backward compatible)", async () => { + mockCreate.mockResolvedValueOnce({ + content: [{ type: "text", text: "response" }], + }) + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + undefined, // second arg (options) + ) + }) + + it("should pass timeout through to client", async () => { + const controller = new AbortController() + mockCreate.mockResolvedValueOnce({ + content: [{ type: "text", text: "response" }], + }) + await handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 5000 }) + expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ model: expect.any(String) }), { + signal: controller.signal, + timeout: 5000, + }) + }) + + it("should pass only timeoutMs when no signal provided", async () => { + mockCreate.mockResolvedValueOnce({ + content: [{ type: "text", text: "response" }], + }) + await handler.completePrompt("test prompt", { timeoutMs: 3000 }) + expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ model: expect.any(String) }), { + timeout: 3000, + }) + }) + + it("should pass timeout when timeoutMs=0 (defined check)", async () => { + mockCreate.mockResolvedValueOnce({ + content: [{ type: "text", text: "response" }], + }) + await handler.completePrompt("test prompt", { timeoutMs: 0 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + { timeout: 0 }, // !== undefined check means 0 is passed through + ) + }) + it("createMessage should yield text content from stream", async () => { const testContent = "This is test content from MiniMax stream" diff --git a/src/api/providers/__tests__/mistral.spec.ts b/src/api/providers/__tests__/mistral.spec.ts index 96e42e356b..db278e721c 100644 --- a/src/api/providers/__tests__/mistral.spec.ts +++ b/src/api/providers/__tests__/mistral.spec.ts @@ -461,11 +461,14 @@ describe("MistralHandler", () => { const prompt = "Test prompt" const result = await handler.completePrompt(prompt) - expect(mockComplete).toHaveBeenCalledWith({ - model: mockOptions.apiModelId, - messages: [{ role: "user", content: prompt }], - temperature: 0, - }) + expect(mockComplete).toHaveBeenCalledWith( + { + model: mockOptions.apiModelId, + messages: [{ role: "user", content: prompt }], + temperature: 0, + }, + undefined, + ) expect(result).toBe("Test response") }) @@ -497,5 +500,57 @@ describe("MistralHandler", () => { mockComplete.mockRejectedValueOnce(new Error("API Error")) await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Mistral completion error: API Error") }) + + it("should pass abort signal through to client", async () => { + const controller = new AbortController() + mockComplete.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(mockComplete).toHaveBeenCalledWith(expect.objectContaining({ model: expect.any(String) }), { + fetchOptions: { signal: controller.signal }, + }) + }) + + it("should work without options (backward compatible)", async () => { + mockComplete.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + expect(mockComplete).toHaveBeenCalledWith(expect.objectContaining({ model: expect.any(String) }), undefined) + }) + + it("should pass timeout through to client", async () => { + const controller = new AbortController() + mockComplete.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + await handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 5000 }) + expect(mockComplete).toHaveBeenCalledWith(expect.objectContaining({ model: expect.any(String) }), { + fetchOptions: { signal: controller.signal }, + timeoutMs: 5000, + }) + }) + + it("should pass only timeoutMs when no signal provided", async () => { + mockComplete.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + await handler.completePrompt("test prompt", { timeoutMs: 3000 }) + expect(mockComplete).toHaveBeenCalledWith(expect.objectContaining({ model: expect.any(String) }), { + timeoutMs: 3000, + }) + }) + + it("should not set timeout when timeoutMs=0 (truthy check)", async () => { + mockComplete.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + await handler.completePrompt("test prompt", { timeoutMs: 0 }) + expect(mockComplete).toHaveBeenCalledWith(expect.objectContaining({ model: expect.any(String) }), { + timeoutMs: 0, + }) + }) }) }) diff --git a/src/api/providers/__tests__/moonshot.spec.ts b/src/api/providers/__tests__/moonshot.spec.ts index c0fd832a19..6203d9f772 100644 --- a/src/api/providers/__tests__/moonshot.spec.ts +++ b/src/api/providers/__tests__/moonshot.spec.ts @@ -238,6 +238,73 @@ describe("MoonshotHandler", () => { }), ) }) + + it("should pass abort signal through to generateText", async () => { + const controller = new AbortController() + mockGenerateText.mockResolvedValueOnce({ text: "response" }) + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + prompt: "test prompt", + abortSignal: controller.signal, + }), + ) + }) + + it("should work without options (backward compatible)", async () => { + mockGenerateText.mockResolvedValueOnce({ text: "response" }) + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + }) + + it("should merge signal and timeoutMs into combined abortSignal", async () => { + const controller = new AbortController() + mockGenerateText.mockResolvedValueOnce({ text: "response" }) + + await handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 5000 }) + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + prompt: "test prompt", + abortSignal: expect.any(AbortSignal), + }), + ) + const callArgs = mockGenerateText.mock.calls[0][0] + expect(callArgs.abortSignal).toBeDefined() + expect(callArgs.abortSignal).toBeInstanceOf(AbortSignal) + }) + + it("should use AbortSignal.timeout when only timeoutMs is provided", async () => { + mockGenerateText.mockResolvedValueOnce({ text: "response" }) + + await handler.completePrompt("test prompt", { timeoutMs: 3000 }) + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + prompt: "test prompt", + abortSignal: expect.any(AbortSignal), + }), + ) + }) + + it("should set an immediately-aborted abortSignal when timeoutMs is 0", async () => { + mockGenerateText.mockResolvedValueOnce({ text: "response" }) + + await handler.completePrompt("test prompt", { timeoutMs: 0 }) + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + prompt: "test prompt", + }), + ) + const callArgs = mockGenerateText.mock.calls[0][0] + // With !== undefined check, timeoutMs=0 creates an immediately-aborted signal + expect(callArgs.abortSignal).toBeDefined() + expect(callArgs.abortSignal.aborted).toBe(true) + }) + + it("should propagate errors from generateText", async () => { + mockGenerateText.mockRejectedValueOnce(new Error("API error")) + + await expect(handler.completePrompt("test prompt")).rejects.toThrow("API error") + }) }) describe("processUsageMetrics", () => { diff --git a/src/api/providers/__tests__/native-ollama.spec.ts b/src/api/providers/__tests__/native-ollama.spec.ts index 8a6b025c8f..2df9155544 100644 --- a/src/api/providers/__tests__/native-ollama.spec.ts +++ b/src/api/providers/__tests__/native-ollama.spec.ts @@ -317,6 +317,55 @@ describe("NativeOllamaHandler", () => { }), ) }) + + it("should accept options param but ignore it (no signal support)", async () => { + mockChat.mockResolvedValue({ + message: { content: "Response" }, + }) + + const controller = new AbortController() + await handler.completePrompt("Test prompt", { abortSignal: controller.signal }) + + // Verify that the call does NOT include any signal-related options + // Ollama implementation only passes the payload, not a second options argument + expect(mockChat).toHaveBeenCalledWith( + expect.objectContaining({ + model: "llama2", + messages: [{ role: "user", content: "Test prompt" }], + stream: false, + options: { temperature: 0 }, + }), + ) + // Verify no second argument was passed (no signal/options forwarded) + expect(mockChat).toHaveBeenCalledTimes(1) + expect(mockChat.mock.calls[0]).toHaveLength(1) + }) + + it("should not include signal-related options when not provided", async () => { + mockChat.mockResolvedValue({ + message: { content: "Response" }, + }) + + await handler.completePrompt("Test prompt") + + expect(mockChat).toHaveBeenCalledWith( + expect.objectContaining({ + model: "llama2", + messages: [{ role: "user", content: "Test prompt" }], + stream: false, + options: { temperature: 0 }, + }), + ) + }) + + it("should work without options (backward compatible)", async () => { + mockChat.mockResolvedValue({ + message: { content: "Response" }, + }) + + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("Response") + }) }) describe("error handling", () => { diff --git a/src/api/providers/__tests__/openai-codex-native-tool-calls.spec.ts b/src/api/providers/__tests__/openai-codex-native-tool-calls.spec.ts index 80ab4e1887..ecaba2704b 100644 --- a/src/api/providers/__tests__/openai-codex-native-tool-calls.spec.ts +++ b/src/api/providers/__tests__/openai-codex-native-tool-calls.spec.ts @@ -527,4 +527,86 @@ describe("OpenAiCodexHandler native tool calls", () => { }), ) }) + + it("completePrompt should pass abort signal through to fetch", async () => { + vi.spyOn(openAiCodexOAuthManager, "getAccessToken").mockResolvedValue("test-token") + vi.spyOn(openAiCodexOAuthManager, "getAccountId").mockResolvedValue("acct_test") + + const mockFetch = vi.fn().mockResolvedValue({ + ok: true, + json: vi.fn().mockResolvedValue({ + output: [ + { + type: "message", + content: [{ type: "output_text", text: "done" }], + }, + ], + }), + }) + global.fetch = mockFetch as any + + const controller = new AbortController() + await handler.completePrompt("Test prompt", { abortSignal: controller.signal }) + + const fetchCallArgs = mockFetch.mock.calls[0] + // The implementation merges signals using AbortSignal.any(), + // which creates a new merged signal when both primary and secondary are provided. + // The merged signal should abort when the user's signal aborts. + let signalAborted = false + fetchCallArgs[1]?.signal.addEventListener( + "abort", + () => { + signalAborted = true + }, + { once: true }, + ) + controller.abort() + await new Promise((resolve) => setTimeout(resolve, 10)) + expect(signalAborted).toBe(true) + }) + + it("completePrompt should work without options (backward compatible)", async () => { + vi.spyOn(openAiCodexOAuthManager, "getAccessToken").mockResolvedValue("test-token") + vi.spyOn(openAiCodexOAuthManager, "getAccountId").mockResolvedValue("acct_test") + + const mockFetch = vi.fn().mockResolvedValue({ + ok: true, + json: vi.fn().mockResolvedValue({ + output: [ + { + type: "message", + content: [{ type: "output_text", text: "done" }], + }, + ], + }), + }) + global.fetch = mockFetch as any + + await handler.completePrompt("Test prompt") + + const fetchCallArgs = mockFetch.mock.calls[0] + expect(fetchCallArgs[1]).toBeDefined() + expect(fetchCallArgs[1]?.method).toBe("POST") + }) + + it("completePrompt should work without options (backward compatible)", async () => { + vi.spyOn(openAiCodexOAuthManager, "getAccessToken").mockResolvedValue("test-token") + vi.spyOn(openAiCodexOAuthManager, "getAccountId").mockResolvedValue("acct_test") + + const mockFetch = vi.fn().mockResolvedValue({ + ok: true, + json: vi.fn().mockResolvedValue({ + output: [ + { + type: "message", + content: [{ type: "output_text", text: "done" }], + }, + ], + }), + }) + global.fetch = mockFetch as any + + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("done") + }) }) diff --git a/src/api/providers/__tests__/openai-native.spec.ts b/src/api/providers/__tests__/openai-native.spec.ts index 2a3ec0afb2..2166123611 100644 --- a/src/api/providers/__tests__/openai-native.spec.ts +++ b/src/api/providers/__tests__/openai-native.spec.ts @@ -245,6 +245,107 @@ describe("OpenAiNativeHandler", () => { expect(result).toBe("") }) + + it("should merge incoming signal with existing controller", async () => { + mockResponsesCreate.mockResolvedValue({ + output: [ + { + type: "message", + content: [{ type: "output_text", text: "response" }], + }, + ], + }) + + const controller = new AbortController() + await handler.completePrompt("Test prompt", { abortSignal: controller.signal }) + + expect(mockResponsesCreate).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ signal: expect.any(AbortSignal) }), + ) + }) + + it("should work without options (backward compatible)", async () => { + mockResponsesCreate.mockResolvedValue({ + output: [ + { + type: "message", + content: [{ type: "output_text", text: "response" }], + }, + ], + }) + + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("response") + }) + + it("should pass signal through to client via createOptions", async () => { + mockResponsesCreate.mockResolvedValue({ + output: [ + { + type: "message", + content: [{ type: "output_text", text: "response" }], + }, + ], + }) + + const controller = new AbortController() + await handler.completePrompt("Test prompt", { abortSignal: controller.signal }) + + expect(mockResponsesCreate).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ signal: expect.any(AbortSignal) }), + ) + }) + + it("should work without options (backward compatible)", async () => { + mockResponsesCreate.mockResolvedValue({ + output: [ + { + type: "message", + content: [{ type: "output_text", text: "response" }], + }, + ], + }) + + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("response") + }) + + it("completePrompt should pass timeoutMs through to client", async () => { + mockResponsesCreate.mockResolvedValue({ + output: [ + { + type: "message", + content: [{ type: "output_text", text: "response" }], + }, + ], + }) + + await handler.completePrompt("Test prompt", { timeoutMs: 5000 }) + expect(mockResponsesCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ signal: expect.any(AbortSignal) }), + ) + }) + + it("completePrompt should merge signal and timeoutMs together", async () => { + const controller = new AbortController() + mockResponsesCreate.mockResolvedValue({ + output: [ + { + type: "message", + content: [{ type: "output_text", text: "response" }], + }, + ], + }) + + await handler.completePrompt("Test prompt", { abortSignal: controller.signal, timeoutMs: 10000 }) + expect(mockResponsesCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ signal: expect.any(AbortSignal) }), + ) + }) }) describe("getModel", () => { diff --git a/src/api/providers/__tests__/openai.spec.ts b/src/api/providers/__tests__/openai.spec.ts index 708a131957..6bb3f2f5c4 100644 --- a/src/api/providers/__tests__/openai.spec.ts +++ b/src/api/providers/__tests__/openai.spec.ts @@ -873,6 +873,45 @@ describe("OpenAiHandler", () => { const result = await handler.completePrompt("Test prompt") expect(result).toBe("") }) + + it("should pass abort signal through to client", async () => { + const controller = new AbortController() + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(mockCreate).toHaveBeenCalledWith( + { model: mockOptions.openAiModelId, messages: [{ role: "user", content: "test prompt" }] }, + { signal: controller.signal }, + ) + }) + + it("should pass timeout through to client", async () => { + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) + await handler.completePrompt("test prompt", { timeoutMs: 5000 }) + expect(mockCreate).toHaveBeenCalledWith( + { model: mockOptions.openAiModelId, messages: [{ role: "user", content: "test prompt" }] }, + { timeout: 5000 }, + ) + }) + + it("should work without options (backward compatible)", async () => { + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + expect(mockCreate).toHaveBeenCalledWith( + { model: mockOptions.openAiModelId, messages: [{ role: "user", content: "test prompt" }] }, + {}, + ) + }) + + it("should merge signal and timeout together", async () => { + const controller = new AbortController() + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) + await handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 10000 }) + expect(mockCreate).toHaveBeenCalledWith( + { model: mockOptions.openAiModelId, messages: [{ role: "user", content: "test prompt" }] }, + { signal: controller.signal, timeout: 10000 }, + ) + }) }) describe("getModel", () => { diff --git a/src/api/providers/__tests__/opencode-go.spec.ts b/src/api/providers/__tests__/opencode-go.spec.ts index 38be399c9d..ecd1ad1d86 100644 --- a/src/api/providers/__tests__/opencode-go.spec.ts +++ b/src/api/providers/__tests__/opencode-go.spec.ts @@ -394,6 +394,7 @@ describe("OpencodeGoHandler", () => { max_completion_tokens: 40_960, reasoning_effort: "medium", }), + {}, ) }) @@ -419,7 +420,7 @@ describe("OpencodeGoHandler", () => { mockCreate.mockResolvedValue({ choices: [{ message: { content: "ok" } }] }) const handler = new OpencodeGoHandler({ ...mockOptions, includeMaxTokens: true, modelMaxTokens: 4321 }) await handler.completePrompt("ping") - expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ max_completion_tokens: 4321 })) + expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ max_completion_tokens: 4321 }), {}) }) }) @@ -569,6 +570,7 @@ describe("OpencodeGoHandler", () => { // so the model default is used. max_tokens: 65_536, }), + undefined, ) expect(mockCreate).not.toHaveBeenCalled() }) @@ -584,7 +586,7 @@ describe("OpencodeGoHandler", () => { modelMaxTokens: 2048, }) await handler.completePrompt("ping") - expect(mockAnthropicCreate).toHaveBeenCalledWith(expect.objectContaining({ max_tokens: 2048 })) + expect(mockAnthropicCreate).toHaveBeenCalledWith(expect.objectContaining({ max_tokens: 2048 }), undefined) }) it("completePrompt rethrows non-Error values unchanged from the Anthropic path", async () => { @@ -599,6 +601,126 @@ describe("OpencodeGoHandler", () => { expect(await handler.completePrompt("ping")).toBe("") }) + it("completePrompt passes abort signal through to Anthropic client", async () => { + mockAnthropicCreate.mockResolvedValue({ content: [{ type: "text", text: "response" }] }) + const controller = new AbortController() + const handler = new OpencodeGoHandler(anthropicOptions) + await handler.completePrompt("ping", { abortSignal: controller.signal }) + expect(mockAnthropicCreate).toHaveBeenCalledWith(expect.objectContaining({ model: expect.any(String) }), { + signal: controller.signal, + }) + }) + + it("completePrompt passes both signal and timeoutMs through to Anthropic client", async () => { + mockAnthropicCreate.mockResolvedValue({ content: [{ type: "text", text: "response" }] }) + const controller = new AbortController() + const handler = new OpencodeGoHandler(anthropicOptions) + await handler.completePrompt("ping", { abortSignal: controller.signal, timeoutMs: 10000 }) + expect(mockAnthropicCreate).toHaveBeenCalledWith(expect.objectContaining({ model: expect.any(String) }), { + signal: controller.signal, + timeout: 10000, + }) + }) + + it("completePrompt passes only timeoutMs when no signal is provided", async () => { + mockAnthropicCreate.mockResolvedValue({ content: [{ type: "text", text: "response" }] }) + const handler = new OpencodeGoHandler(anthropicOptions) + await handler.completePrompt("ping", { timeoutMs: 5000 }) + expect(mockAnthropicCreate).toHaveBeenCalledWith(expect.objectContaining({ model: expect.any(String) }), { + timeout: 5000, + }) + }) + + it("completePrompt works without options (backward compatible, Anthropic path)", async () => { + mockAnthropicCreate.mockResolvedValue({ content: [{ type: "text", text: "response" }] }) + const handler = new OpencodeGoHandler(anthropicOptions) + const result = await handler.completePrompt("ping") + expect(result).toBe("response") + expect(mockAnthropicCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + undefined, + ) + }) + + describe("completePrompt (OpenAI path)", () => { + const openaiOptions: ApiHandlerOptions = { + opencodeGoApiKey: "test-key", + apiModelId: "glm-5.1", // OpenAI-format model + } + + beforeEach(() => { + vitest.clearAllMocks() + }) + + it("completePrompt returns text for OpenAI path", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + + const handler = new OpencodeGoHandler(openaiOptions) + expect(await handler.completePrompt("ping")).toBe("response") + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String), stream: false }), + {}, // empty object when no options + ) + }) + + it("completePrompt passes abort signal through to OpenAI client", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + const controller = new AbortController() + const handler = new OpencodeGoHandler(openaiOptions) + + await handler.completePrompt("ping", { abortSignal: controller.signal }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String), stream: false }), + { signal: controller.signal }, + ) + }) + + it("completePrompt passes both signal and timeoutMs through to OpenAI client", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + const controller = new AbortController() + const handler = new OpencodeGoHandler(openaiOptions) + + await handler.completePrompt("ping", { abortSignal: controller.signal, timeoutMs: 10000 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String), stream: false }), + { signal: controller.signal, timeout: 10000 }, + ) + }) + + it("completePrompt passes only timeoutMs when no signal is provided", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + const handler = new OpencodeGoHandler(openaiOptions) + + await handler.completePrompt("ping", { timeoutMs: 5000 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String), stream: false }), + { timeout: 5000 }, + ) + }) + + it("completePrompt works without options (backward compatible, OpenAI path)", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + const handler = new OpencodeGoHandler(openaiOptions) + + const result = await handler.completePrompt("ping") + expect(result).toBe("response") + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String), stream: false }), + {}, // empty object when no options + ) + }) + }) + it("omits tools and tool_choice from the Anthropic request when no tools are provided", async () => { const handler = new OpencodeGoHandler(anthropicOptions) const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }] diff --git a/src/api/providers/__tests__/openrouter.spec.ts b/src/api/providers/__tests__/openrouter.spec.ts index b21d409d0a..fc57804efb 100644 --- a/src/api/providers/__tests__/openrouter.spec.ts +++ b/src/api/providers/__tests__/openrouter.spec.ts @@ -714,5 +714,48 @@ describe("OpenRouterHandler", () => { }), ) }) + + it("should pass abort signal through to client", async () => { + const handler = new OpenRouterHandler(mockOptions) + const controller = new AbortController() + const mockResponse = { choices: [{ message: { content: "response" } }] } + const mockCreate = vitest.fn().mockResolvedValue(mockResponse) + ;(OpenAI as any).prototype.chat = { + completions: { create: mockCreate }, + } as any + + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ signal: controller.signal }), + ) + }) + + it("should pass timeout through to client", async () => { + const handler = new OpenRouterHandler(mockOptions) + const mockResponse = { choices: [{ message: { content: "response" } }] } + const mockCreate = vitest.fn().mockResolvedValue(mockResponse) + ;(OpenAI as any).prototype.chat = { + completions: { create: mockCreate }, + } as any + + await handler.completePrompt("test prompt", { timeoutMs: 5000 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ timeout: 5000 }), + ) + }) + + it("should work without options (backward compatible)", async () => { + const handler = new OpenRouterHandler(mockOptions) + const mockResponse = { choices: [{ message: { content: "response" } }] } + const mockCreate = vitest.fn().mockResolvedValue(mockResponse) + ;(OpenAI as any).prototype.chat = { + completions: { create: mockCreate }, + } as any + + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + }) }) }) diff --git a/src/api/providers/__tests__/poe.spec.ts b/src/api/providers/__tests__/poe.spec.ts index b22d42179c..21f9ba9359 100644 --- a/src/api/providers/__tests__/poe.spec.ts +++ b/src/api/providers/__tests__/poe.spec.ts @@ -309,5 +309,121 @@ describe("PoeHandler", () => { }), ) }) + + it("completePrompt should pass abort signal through to generateText", async () => { + const handler = new PoeHandler({ poeApiKey: "key", apiModelId: "openai/gpt-4o" }) + const controller = new AbortController() + mockGenerateText.mockResolvedValueOnce({ text: "response" }) + + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + model: mockLanguageModel, + prompt: "test prompt", + abortSignal: controller.signal, + }), + ) + }) + + it("completePrompt should work without options (backward compatible)", async () => { + const handler = new PoeHandler({ poeApiKey: "key", apiModelId: "openai/gpt-4o" }) + mockGenerateText.mockResolvedValueOnce({ text: "response" }) + + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + model: mockLanguageModel, + prompt: "test prompt", + }), + ) + }) + + it("completePrompt should merge signal and timeoutMs into combined abortSignal", async () => { + const handler = new PoeHandler({ poeApiKey: "key", apiModelId: "openai/gpt-4o" }) + const controller = new AbortController() + mockGenerateText.mockResolvedValueOnce({ text: "response" }) + + await handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 5000 }) + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + model: mockLanguageModel, + prompt: "test prompt", + abortSignal: expect.any(AbortSignal), + }), + ) + // The abortSignal should be a merged signal (not the original controller.signal) + const callArgs = mockGenerateText.mock.calls[0][0] + expect(callArgs.abortSignal).toBeDefined() + expect(callArgs.abortSignal).toBeInstanceOf(AbortSignal) + }) + + it("completePrompt should use AbortSignal.timeout when only timeoutMs is provided", async () => { + const handler = new PoeHandler({ poeApiKey: "key", apiModelId: "openai/gpt-4o" }) + mockGenerateText.mockResolvedValueOnce({ text: "response" }) + + await handler.completePrompt("test prompt", { timeoutMs: 3000 }) + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + model: mockLanguageModel, + prompt: "test prompt", + abortSignal: expect.any(AbortSignal), + }), + ) + const callArgs = mockGenerateText.mock.calls[0][0] + expect(callArgs.abortSignal).toBeDefined() + expect(callArgs.abortSignal).not.toBeUndefined() + }) + + it("completePrompt should prefer signal over timeoutMs when both are provided", async () => { + const handler = new PoeHandler({ poeApiKey: "key", apiModelId: "openai/gpt-4o" }) + const controller = new AbortController() + mockGenerateText.mockResolvedValueOnce({ text: "response" }) + + await handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 5000 }) + const callArgs = mockGenerateText.mock.calls[0][0] + // Should have a merged abortSignal (not the original controller.signal) + expect(callArgs.abortSignal).toBeInstanceOf(AbortSignal) + expect(callArgs.abortSignal).not.toBe(controller.signal) + }) + + it("completePrompt should clear timeout when user signal aborts", async () => { + const handler = new PoeHandler({ poeApiKey: "key", apiModelId: "openai/gpt-4o" }) + const controller = new AbortController() + mockGenerateText.mockResolvedValueOnce({ text: "response" }) + + const promise = handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 5000 }) + const callArgs = mockGenerateText.mock.calls[0][0] + expect(callArgs.abortSignal).toBeDefined() + + // Abort the user signal before resolve + controller.abort() + + await promise + // Merged signal should be aborted when user signal aborts + expect(callArgs.abortSignal.aborted).toBe(true) + }) + + it("completePrompt should handle timeoutMs=0 as no timeout", async () => { + const handler = new PoeHandler({ poeApiKey: "key", apiModelId: "openai/gpt-4o" }) + mockGenerateText.mockResolvedValueOnce({ text: "response" }) + + await handler.completePrompt("test prompt", { timeoutMs: 0 }) + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + model: mockLanguageModel, + prompt: "test prompt", + }), + ) + const callArgs = mockGenerateText.mock.calls[0][0] + expect(callArgs.abortSignal).toBeUndefined() + }) + + it("completePrompt should handle non-Error values in catch block", async () => { + const handler = new PoeHandler({ poeApiKey: "key", apiModelId: "openai/gpt-4o" }) + mockGenerateText.mockRejectedValueOnce("not an error") + + await expect(handler.completePrompt("test prompt")).rejects.toThrow() + }) }) }) diff --git a/src/api/providers/__tests__/qwen-code-native-tools.spec.ts b/src/api/providers/__tests__/qwen-code-native-tools.spec.ts index 3615c0f92d..c76851dea0 100644 --- a/src/api/providers/__tests__/qwen-code-native-tools.spec.ts +++ b/src/api/providers/__tests__/qwen-code-native-tools.spec.ts @@ -444,4 +444,56 @@ describe("QwenCodeHandler Native Tools", () => { expect(endChunks).toHaveLength(1) }) }) + + describe("completePrompt", () => { + it("should complete prompt successfully", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "This is a test response" } }], + }) + + const result = await handler.completePrompt("test prompt") + expect(result).toBe("This is a test response") + }) + + it("should handle errors in completePrompt", async () => { + const errorMessage = "Qwen API error" + mockCreate.mockRejectedValueOnce(new Error(errorMessage)) + + await expect(handler.completePrompt("test prompt")).rejects.toThrow(errorMessage) + }) + + it("completePrompt should pass abort signal through to client", async () => { + const controller = new AbortController() + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ signal: controller.signal }), + ) + }) + + it("completePrompt should pass timeout through to client", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + + await handler.completePrompt("test prompt", { timeoutMs: 5000 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ timeout: 5000 }), + ) + }) + + it("completePrompt should work without options (backward compatible)", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "response" } }], + }) + + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + }) + }) }) diff --git a/src/api/providers/__tests__/requesty.spec.ts b/src/api/providers/__tests__/requesty.spec.ts index 5d829f2374..80fd8eed4d 100644 --- a/src/api/providers/__tests__/requesty.spec.ts +++ b/src/api/providers/__tests__/requesty.spec.ts @@ -545,12 +545,15 @@ describe("RequestyHandler", () => { expect(result).toBe("test completion") - expect(mockCreate).toHaveBeenCalledWith({ - model: mockOptions.requestyModelId, - max_tokens: 8192, - messages: [{ role: "system", content: "test prompt" }], - temperature: 0, - }) + expect(mockCreate).toHaveBeenCalledWith( + { + model: mockOptions.requestyModelId, + max_tokens: 8192, + messages: [{ role: "system", content: "test prompt" }], + temperature: 0, + }, + {}, + ) }) it("omits temperature for Claude Fable 5 in completePrompt", async () => { @@ -562,12 +565,15 @@ describe("RequestyHandler", () => { await handler.completePrompt("test prompt") - expect(mockCreate).toHaveBeenCalledWith({ - model: "anthropic/claude-fable-5", - max_tokens: 8192, - messages: [{ role: "system", content: "test prompt" }], - temperature: undefined, - }) + expect(mockCreate).toHaveBeenCalledWith( + { + model: "anthropic/claude-fable-5", + max_tokens: 8192, + messages: [{ role: "system", content: "test prompt" }], + temperature: undefined, + }, + {}, + ) }) it("omits temperature for Claude Sonnet 5 in completePrompt", async () => { @@ -601,5 +607,34 @@ describe("RequestyHandler", () => { await expect(handler.completePrompt("test prompt")).rejects.toThrow("Unexpected error") }) + + it("should pass abort signal through to client", async () => { + const handler = new RequestyHandler(mockOptions) + const controller = new AbortController() + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) + + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ model: expect.any(String) }), { + signal: controller.signal, + }) + }) + + it("should pass timeout through to client", async () => { + const handler = new RequestyHandler(mockOptions) + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) + + await handler.completePrompt("test prompt", { timeoutMs: 5000 }) + expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ model: expect.any(String) }), { + timeout: 5000, + }) + }) + + it("should work without options (backward compatible)", async () => { + const handler = new RequestyHandler(mockOptions) + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) + + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + }) }) }) diff --git a/src/api/providers/__tests__/sambanova.spec.ts b/src/api/providers/__tests__/sambanova.spec.ts index 1455fc7f07..6410ec75a0 100644 --- a/src/api/providers/__tests__/sambanova.spec.ts +++ b/src/api/providers/__tests__/sambanova.spec.ts @@ -69,6 +69,31 @@ describe("SambaNovaHandler", () => { ) }) + it("completePrompt should pass abort signal through to client", async () => { + const controller = new AbortController() + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ signal: controller.signal }), + ) + }) + + it("completePrompt should pass timeout through to client", async () => { + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) + await handler.completePrompt("test prompt", { timeoutMs: 5000 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ timeout: 5000 }), + ) + }) + + it("completePrompt should work without options (backward compatible)", async () => { + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + }) + it("createMessage should yield text content from stream", async () => { const testContent = "This is test content from SambaNova stream" diff --git a/src/api/providers/__tests__/unbound.spec.ts b/src/api/providers/__tests__/unbound.spec.ts index d8f75fe85b..0810ff9fb3 100644 --- a/src/api/providers/__tests__/unbound.spec.ts +++ b/src/api/providers/__tests__/unbound.spec.ts @@ -199,6 +199,59 @@ describe("UnboundHandler", () => { expect.objectContaining({ messages: [{ role: "system", content: "Write a haiku" }], }), + {}, ) }) + + it("completePrompt should pass abort signal through to client", async () => { + const mockCreate = (OpenAI as unknown as any)().chat.completions.create + const controller = new AbortController() + mockCreate.mockResolvedValue({ + choices: [{ message: { content: "completed text" } }], + }) + + const handler = new UnboundHandler({ + unboundApiKey: "test-key", + unboundModelId: "openai/gpt-4o", + }) + + await handler.completePrompt("Write a haiku", { abortSignal: controller.signal }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ signal: controller.signal }), + ) + }) + + it("completePrompt should pass timeout through to client", async () => { + const mockCreate = (OpenAI as unknown as any)().chat.completions.create + mockCreate.mockResolvedValue({ + choices: [{ message: { content: "completed text" } }], + }) + + const handler = new UnboundHandler({ + unboundApiKey: "test-key", + unboundModelId: "openai/gpt-4o", + }) + + await handler.completePrompt("Write a haiku", { timeoutMs: 5000 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ timeout: 5000 }), + ) + }) + + it("completePrompt should work without options (backward compatible)", async () => { + const mockCreate = (OpenAI as unknown as any)().chat.completions.create + mockCreate.mockResolvedValue({ + choices: [{ message: { content: "completed text" } }], + }) + + const handler = new UnboundHandler({ + unboundApiKey: "test-key", + unboundModelId: "openai/gpt-4o", + }) + + const result = await handler.completePrompt("Write a haiku") + expect(result).toBe("completed text") + }) }) diff --git a/src/api/providers/__tests__/vercel-ai-gateway.spec.ts b/src/api/providers/__tests__/vercel-ai-gateway.spec.ts index 4669904e2f..7605e784f0 100644 --- a/src/api/providers/__tests__/vercel-ai-gateway.spec.ts +++ b/src/api/providers/__tests__/vercel-ai-gateway.spec.ts @@ -628,6 +628,7 @@ describe("VercelAiGatewayHandler", () => { temperature: VERCEL_AI_GATEWAY_DEFAULT_TEMPERATURE, max_completion_tokens: 64000, }), + undefined, ) }) @@ -644,6 +645,7 @@ describe("VercelAiGatewayHandler", () => { expect.objectContaining({ temperature: customTemp, }), + undefined, ) }) @@ -676,10 +678,80 @@ describe("VercelAiGatewayHandler", () => { const result = await handler.completePrompt("Test") expect(result).toBe("") }) + + it("should pass abort signal through to client", async () => { + const handler = new VercelAiGatewayHandler(mockOptions) + const controller = new AbortController() + mockCreate.mockImplementation(async () => ({ + choices: [ + { + message: { role: "assistant", content: "response" }, + finish_reason: "stop", + index: 0, + }, + ], + })) + + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ signal: controller.signal }), + ) + }) + + it("should pass timeout through to client", async () => { + const handler = new VercelAiGatewayHandler(mockOptions) + mockCreate.mockImplementation(async () => ({ + choices: [ + { + message: { role: "assistant", content: "response" }, + finish_reason: "stop", + index: 0, + }, + ], + })) + + await handler.completePrompt("test prompt", { timeoutMs: 5000 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ timeout: 5000 }), + ) + }) + + it("should work without options (backward compatible)", async () => { + const handler = new VercelAiGatewayHandler(mockOptions) + mockCreate.mockImplementation(async () => ({ + choices: [ + { + message: { role: "assistant", content: "response" }, + finish_reason: "stop", + index: 0, + }, + ], + })) + + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + }) }) describe("temperature support", () => { it("applies temperature for supported models", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [ + { + message: { role: "assistant", content: "Test completion response" }, + finish_reason: "stop", + index: 0, + }, + ], + usage: { + prompt_tokens: 8, + completion_tokens: 4, + total_tokens: 12, + }, + }) + const handler = new VercelAiGatewayHandler({ ...mockOptions, vercelAiGatewayModelId: "anthropic/claude-sonnet-4", @@ -692,6 +764,7 @@ describe("VercelAiGatewayHandler", () => { expect.objectContaining({ temperature: 0.9, }), + undefined, ) }) }) diff --git a/src/api/providers/__tests__/vertex.spec.ts b/src/api/providers/__tests__/vertex.spec.ts index a304518ca7..8f8773b6ea 100644 --- a/src/api/providers/__tests__/vertex.spec.ts +++ b/src/api/providers/__tests__/vertex.spec.ts @@ -137,6 +137,35 @@ describe("VertexHandler", () => { const result = await handler.completePrompt("Test prompt") expect(result).toBe("") }) + + it("should pass abort signal through to client via config.abortSignal", async () => { + const controller = new AbortController() + ;(handler["client"].models.generateContent as any).mockResolvedValue({ + text: "response", + }) + + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(handler["client"].models.generateContent).toHaveBeenCalledWith( + expect.objectContaining({ + model: expect.any(String), + contents: [{ role: "user", parts: [{ text: "test prompt" }] }], + config: expect.objectContaining({ + abortSignal: controller.signal, + httpOptions: undefined, + temperature: 1, + }), + }), + ) + }) + + it("should work without options (backward compatible)", async () => { + ;(handler["client"].models.generateContent as any).mockResolvedValue({ + text: "response", + }) + + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + }) }) describe("getModel", () => { diff --git a/src/api/providers/__tests__/vscode-lm.spec.ts b/src/api/providers/__tests__/vscode-lm.spec.ts index 5227c4b289..5d8feb61d8 100644 --- a/src/api/providers/__tests__/vscode-lm.spec.ts +++ b/src/api/providers/__tests__/vscode-lm.spec.ts @@ -619,7 +619,165 @@ describe("VsCodeLmHandler", () => { handler["client"] = mockLanguageModelChat const promise = handler.completePrompt("Test prompt") - await expect(promise).rejects.toThrow("VSCode LM completion error: Completion failed") + await expect(promise).rejects.toThrow("Completion failed") + }) + + it("should bridge abort signal to CancellationToken", async () => { + const mockModel = { ...mockLanguageModelChat } + ;(vscode.lm.selectChatModels as Mock).mockResolvedValueOnce([mockModel]) + + const responseText = "Completed text" + mockLanguageModelChat.sendRequest.mockResolvedValueOnce({ + stream: (async function* () { + yield new vscode.LanguageModelTextPart(responseText) + return + })(), + text: (async function* () { + yield responseText + return + })(), + }) + + handler["client"] = mockLanguageModelChat + + const controller = new AbortController() + await handler.completePrompt("Test prompt", { abortSignal: controller.signal }) + + // Verify that tokenSource.dispose was called (via the mock) + const TokenSourceInstance = (vscode.CancellationTokenSource as any).mock.results[0].value + expect(TokenSourceInstance.dispose).toHaveBeenCalled() + }) + + it("should cancel token when signal is already aborted", async () => { + const mockModel = { ...mockLanguageModelChat } + ;(vscode.lm.selectChatModels as Mock).mockResolvedValueOnce([mockModel]) + + const responseText = "Completed text" + mockLanguageModelChat.sendRequest.mockResolvedValueOnce({ + stream: (async function* () { + yield new vscode.LanguageModelTextPart(responseText) + return + })(), + text: (async function* () { + yield responseText + return + })(), + }) + + handler["client"] = mockLanguageModelChat + + const controller = new AbortController() + controller.abort() + await handler.completePrompt("Test prompt", { abortSignal: controller.signal }) + + const TokenSourceInstance = (vscode.CancellationTokenSource as any).mock.results[0].value + expect(TokenSourceInstance.cancel).toHaveBeenCalled() + }) + + it("should work without options (backward compatible)", async () => { + const mockModel = { ...mockLanguageModelChat } + ;(vscode.lm.selectChatModels as Mock).mockResolvedValueOnce([mockModel]) + + const responseText = "Completed text" + mockLanguageModelChat.sendRequest.mockResolvedValueOnce({ + stream: (async function* () { + yield new vscode.LanguageModelTextPart(responseText) + return + })(), + text: (async function* () { + yield responseText + return + })(), + }) + + handler["client"] = mockLanguageModelChat + + const result = await handler.completePrompt("Test prompt") + expect(result).toBe(responseText) + }) + + it("should handle timeoutMs by creating a timeout-based cancellation", async () => { + const mockModel = { ...mockLanguageModelChat } + ;(vscode.lm.selectChatModels as Mock).mockResolvedValueOnce([mockModel]) + + const responseText = "Completed text" + mockLanguageModelChat.sendRequest.mockResolvedValueOnce({ + stream: (async function* () { + yield new vscode.LanguageModelTextPart(responseText) + return + })(), + text: (async function* () { + yield responseText + return + })(), + }) + + handler["client"] = mockLanguageModelChat + + await handler.completePrompt("Test prompt", { timeoutMs: 5000 }) + + expect(mockLanguageModelChat.sendRequest).toHaveBeenCalled() + }) + + it("should handle both signal and timeoutMs together", async () => { + const mockModel = { ...mockLanguageModelChat } + ;(vscode.lm.selectChatModels as Mock).mockResolvedValueOnce([mockModel]) + + const responseText = "Completed text" + mockLanguageModelChat.sendRequest.mockResolvedValueOnce({ + stream: (async function* () { + yield new vscode.LanguageModelTextPart(responseText) + return + })(), + text: (async function* () { + yield responseText + return + })(), + }) + + handler["client"] = mockLanguageModelChat + + const controller = new AbortController() + await handler.completePrompt("Test prompt", { abortSignal: controller.signal, timeoutMs: 10000 }) + + expect(mockLanguageModelChat.sendRequest).toHaveBeenCalled() + }) + + it("should handle errors in completePrompt", async () => { + const mockModel = { ...mockLanguageModelChat } + ;(vscode.lm.selectChatModels as Mock).mockResolvedValueOnce([mockModel]) + + mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error("LM error")) + + handler["client"] = mockLanguageModelChat + + await expect(handler.completePrompt("Test prompt")).rejects.toThrow("LM error") + }) + + it("should cancel token immediately when signal is already aborted", async () => { + const mockModel = { ...mockLanguageModelChat } + ;(vscode.lm.selectChatModels as Mock).mockResolvedValueOnce([mockModel]) + + const responseText = "Completed text" + mockLanguageModelChat.sendRequest.mockResolvedValueOnce({ + stream: (async function* () { + yield new vscode.LanguageModelTextPart(responseText) + return + })(), + text: (async function* () { + yield responseText + return + })(), + }) + + handler["client"] = mockLanguageModelChat + + const controller = new AbortController() + controller.abort() // Abort before calling completePrompt + + await handler.completePrompt("Test prompt", { abortSignal: controller.signal }) + + expect(mockLanguageModelChat.sendRequest).toHaveBeenCalled() }) }) }) diff --git a/src/api/providers/__tests__/xai.spec.ts b/src/api/providers/__tests__/xai.spec.ts index 78f091b8e3..042fb3d0f5 100644 --- a/src/api/providers/__tests__/xai.spec.ts +++ b/src/api/providers/__tests__/xai.spec.ts @@ -255,6 +255,57 @@ describe("XAIHandler", () => { await expect(handler.completePrompt("test prompt")).rejects.toThrow(`xAI completion error: ${errorMessage}`) }) + it("completePrompt should pass abort signal through to client", async () => { + const controller = new AbortController() + mockResponsesCreate.mockResolvedValueOnce({ output_text: "response" }) + + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(mockResponsesCreate).toHaveBeenCalledWith(expect.objectContaining({ model: expect.any(String) }), { + signal: controller.signal, + }) + }) + + it("completePrompt should work without options (backward compatible)", async () => { + mockResponsesCreate.mockResolvedValueOnce({ output_text: "response" }) + + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + expect(mockResponsesCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + undefined, + ) + }) + + it("completePrompt should pass timeout through to client", async () => { + const controller = new AbortController() + mockResponsesCreate.mockResolvedValueOnce({ output_text: "response" }) + + await handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 5000 }) + expect(mockResponsesCreate).toHaveBeenCalledWith(expect.objectContaining({ model: expect.any(String) }), { + signal: controller.signal, + timeout: 5000, + }) + }) + + it("completePrompt should pass only timeoutMs when no signal provided", async () => { + mockResponsesCreate.mockResolvedValueOnce({ output_text: "response" }) + + await handler.completePrompt("test prompt", { timeoutMs: 3000 }) + expect(mockResponsesCreate).toHaveBeenCalledWith(expect.objectContaining({ model: expect.any(String) }), { + timeout: 3000, + }) + }) + + it("completePrompt should pass timeout when timeoutMs=0 (defined check)", async () => { + mockResponsesCreate.mockResolvedValueOnce({ output_text: "response" }) + + await handler.completePrompt("test prompt", { timeoutMs: 0 }) + expect(mockResponsesCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + { timeout: 0 }, // !== undefined check means 0 is passed through + ) + }) + it("should include reasoning_effort for mini models", async () => { const miniModelHandler = new XAIHandler({ apiModelId: "grok-3-mini", diff --git a/src/api/providers/__tests__/zai.spec.ts b/src/api/providers/__tests__/zai.spec.ts index 66266a2fee..505408ff95 100644 --- a/src/api/providers/__tests__/zai.spec.ts +++ b/src/api/providers/__tests__/zai.spec.ts @@ -427,6 +427,31 @@ describe("ZAiHandler", () => { ) }) + it("completePrompt should pass abort signal through to client", async () => { + const controller = new AbortController() + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ signal: controller.signal }), + ) + }) + + it("completePrompt should pass timeout through to client", async () => { + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) + await handler.completePrompt("test prompt", { timeoutMs: 5000 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ timeout: 5000 }), + ) + }) + + it("completePrompt should work without options (backward compatible)", async () => { + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + }) + it("createMessage should yield text content from stream", async () => { const testContent = "This is test content from Z AI stream" diff --git a/src/api/providers/__tests__/zoo-gateway.spec.ts b/src/api/providers/__tests__/zoo-gateway.spec.ts index e0c060db3b..92ea81a2df 100644 --- a/src/api/providers/__tests__/zoo-gateway.spec.ts +++ b/src/api/providers/__tests__/zoo-gateway.spec.ts @@ -445,6 +445,7 @@ describe("ZooGatewayHandler", () => { temperature: ZOO_GATEWAY_DEFAULT_TEMPERATURE, max_completion_tokens: 64000, }), + {}, ) }) @@ -467,6 +468,43 @@ describe("ZooGatewayHandler", () => { await expect(handler.completePrompt("Test")).resolves.toBe("") }) + + it("should pass abort signal through to client", async () => { + const handler = new ZooGatewayHandler(mockOptions) + const controller = new AbortController() + mockCreate.mockImplementation(async () => ({ + choices: [{ message: { role: "assistant", content: "response" } }], + })) + + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ signal: controller.signal }), + ) + }) + + it("should pass timeout through to client", async () => { + const handler = new ZooGatewayHandler(mockOptions) + mockCreate.mockImplementation(async () => ({ + choices: [{ message: { role: "assistant", content: "response" } }], + })) + + await handler.completePrompt("test prompt", { timeoutMs: 5000 }) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: expect.any(String) }), + expect.objectContaining({ timeout: 5000 }), + ) + }) + + it("should work without options (backward compatible)", async () => { + const handler = new ZooGatewayHandler(mockOptions) + mockCreate.mockImplementation(async () => ({ + choices: [{ message: { role: "assistant", content: "response" } }], + })) + + const result = await handler.completePrompt("test prompt") + expect(result).toBe("response") + }) }) describe("classifyGatewayApiError", () => { diff --git a/src/api/providers/anthropic-vertex.ts b/src/api/providers/anthropic-vertex.ts index 9089562f4f..7dc8f0e4aa 100644 --- a/src/api/providers/anthropic-vertex.ts +++ b/src/api/providers/anthropic-vertex.ts @@ -25,7 +25,7 @@ import { import { BaseProvider } from "./base-provider" import { parseVertexJsonCredentials } from "./utils/vertex-credentials" -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata, CompletePromptOptions } from "../index" // https://docs.anthropic.com/en/api/claude-on-vertex-ai export class AnthropicVertexHandler extends BaseProvider implements SingleCompletionHandler { @@ -270,7 +270,7 @@ export class AnthropicVertexHandler extends BaseProvider implements SingleComple } } - async completePrompt(prompt: string) { + async completePrompt(prompt: string, options?: CompletePromptOptions) { try { const { id, @@ -296,7 +296,18 @@ export class AnthropicVertexHandler extends BaseProvider implements SingleComple stream: false, } as Anthropic.Messages.MessageCreateParamsNonStreaming - const response = await this.client.messages.create(params) + const requestOptions: Anthropic.RequestOptions = {} + if (options?.abortSignal) { + requestOptions.signal = options.abortSignal + } + if (options?.timeoutMs !== undefined) { + requestOptions.timeout = options.timeoutMs + } + + const response = await this.client.messages.create( + params, + Object.keys(requestOptions).length > 0 ? requestOptions : undefined, + ) const content = response.content[0] if (content.type === "text") { diff --git a/src/api/providers/anthropic.ts b/src/api/providers/anthropic.ts index f9e2ee7d2b..7434cb796e 100644 --- a/src/api/providers/anthropic.ts +++ b/src/api/providers/anthropic.ts @@ -22,7 +22,7 @@ import { getAnthropicProviderReasoning } from "../transform/reasoning" import { handleProviderError } from "./utils/error-handler" import { BaseProvider } from "./base-provider" -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata, CompletePromptOptions } from "../index" import { calculateApiCostAnthropic } from "../../shared/cost" import { convertOpenAIToolsToAnthropic, @@ -400,19 +400,31 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa } } - async completePrompt(prompt: string) { + async completePrompt(prompt: string, options?: CompletePromptOptions) { const { id: model, temperature } = this.getModel() let message try { - message = await this.client.messages.create({ - model, - max_tokens: ANTHROPIC_DEFAULT_MAX_TOKENS, - thinking: undefined, - temperature, - messages: [{ role: "user", content: prompt }], - stream: false, - }) + // Build request options with both abortSignal and timeout handling + const requestOptions: Anthropic.RequestOptions = {} + if (options?.abortSignal) { + requestOptions.signal = options.abortSignal + } + if (options?.timeoutMs) { + requestOptions.timeout = options.timeoutMs + } + + message = await this.client.messages.create( + { + model, + max_tokens: ANTHROPIC_DEFAULT_MAX_TOKENS, + thinking: undefined, + temperature, + messages: [{ role: "user", content: prompt }], + stream: false, + }, + Object.keys(requestOptions).length > 0 ? requestOptions : undefined, + ) } catch (error) { TelemetryService.instance.captureException( new ApiProviderError( diff --git a/src/api/providers/base-openai-compatible-provider.ts b/src/api/providers/base-openai-compatible-provider.ts index b6094f9cc4..314a39bba5 100644 --- a/src/api/providers/base-openai-compatible-provider.ts +++ b/src/api/providers/base-openai-compatible-provider.ts @@ -8,7 +8,7 @@ import { TagMatcher } from "../../utils/tag-matcher" import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" import { convertToOpenAiMessages } from "../transform/openai-format" -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata, CompletePromptOptions } from "../index" import { DEFAULT_HEADERS } from "./constants" import { BaseProvider } from "./base-provider" import { handleOpenAIError } from "./utils/openai-error-handler" @@ -212,7 +212,7 @@ export abstract class BaseOpenAiCompatibleProvider } } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, options?: CompletePromptOptions): Promise { const { id: modelId, info: modelInfo } = this.getModel() const params: OpenAI.Chat.Completions.ChatCompletionCreateParams = { @@ -226,7 +226,16 @@ export abstract class BaseOpenAiCompatibleProvider } try { - const response = await this.client.chat.completions.create(params) + // Build request options with abortSignal and/or timeout + const requestOptions: OpenAI.RequestOptions = {} + if (options?.abortSignal) { + requestOptions.signal = options.abortSignal + } + if (options?.timeoutMs !== undefined) { + requestOptions.timeout = options.timeoutMs + } + + const response = await this.client.chat.completions.create(params, requestOptions || undefined) // Check for provider-specific error responses (e.g., MiniMax base_resp) const responseAny = response as any diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts index 8c2b5ace68..c2092a2284 100644 --- a/src/api/providers/bedrock.ts +++ b/src/api/providers/bedrock.ts @@ -44,7 +44,7 @@ import { convertToBedrockConverseMessages as sharedConverter } from "../transfor import { getModelParams } from "../transform/model-params" import { shouldUseReasoningBudget } from "../../shared/api" import { normalizeToolSchema } from "../../utils/json-schema" -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata, CompletePromptOptions } from "../index" /************************************************************************************ * @@ -798,7 +798,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH } } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, options?: CompletePromptOptions): Promise { try { const modelConfig = this.getModel() @@ -840,7 +840,41 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH } const command = new ConverseCommand(payload) - const response = await this.client.send(command) + + // Build request options with abortSignal and/or timeoutMs + const sendOptions: { abortSignal?: AbortSignal } | undefined = (() => { + let signal: AbortSignal | undefined = options?.abortSignal + if (options?.timeoutMs !== undefined) { + if (signal) { + // When both are provided, create a merged signal that aborts when either fires + const controller = new AbortController() + if (signal.aborted) { + controller.abort() + } else if (options.timeoutMs > 0) { + const timeoutId = setTimeout(() => controller.abort(), options.timeoutMs) + signal.addEventListener( + "abort", + () => { + clearTimeout(timeoutId) + controller.abort() + }, + { once: true }, + ) + } else { + // timeoutMs is 0, abort immediately + controller.abort() + } + signal = controller.signal + } else if (options.timeoutMs !== undefined) { + if (options.timeoutMs > 0) { + signal = AbortSignal.timeout(options.timeoutMs) + } + } + } + return signal ? { abortSignal: signal } : undefined + })() + + const response = await this.client.send(command, sendOptions) if ( response?.output?.message?.content && diff --git a/src/api/providers/fake-ai.ts b/src/api/providers/fake-ai.ts index e69a1c84e8..b2ea1c2b60 100644 --- a/src/api/providers/fake-ai.ts +++ b/src/api/providers/fake-ai.ts @@ -2,7 +2,12 @@ import { Anthropic } from "@anthropic-ai/sdk" import type { ModelInfo } from "@roo-code/types" -import type { ApiHandler, SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import type { + ApiHandler, + SingleCompletionHandler, + ApiHandlerCreateMessageMetadata, + CompletePromptOptions, +} from "../index" import type { ApiHandlerOptions } from "../../shared/api" import { ApiStream } from "../transform/stream" @@ -28,7 +33,7 @@ interface FakeAI { ): ApiStream getModel(): { id: string; info: ModelInfo } countTokens(content: Array): Promise - completePrompt(prompt: string): Promise + completePrompt(prompt: string, options?: CompletePromptOptions): Promise } /** @@ -75,7 +80,7 @@ export class FakeAIHandler implements ApiHandler, SingleCompletionHandler { return this.ai.countTokens(content) } - completePrompt(prompt: string): Promise { - return this.ai.completePrompt(prompt) + completePrompt(prompt: string, options?: CompletePromptOptions): Promise { + return (this.ai as any).completePrompt(prompt, options) } } diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts index 6c8168cae2..8e9dc1fc4c 100644 --- a/src/api/providers/gemini.ts +++ b/src/api/providers/gemini.ts @@ -23,7 +23,7 @@ import { t } from "i18next" import type { ApiStream, GroundingSource } from "../transform/stream" import { getModelParams } from "../transform/model-params" -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata, CompletePromptOptions } from "../index" import { BaseProvider } from "./base-provider" import { parseVertexJsonCredentials } from "./utils/vertex-credentials" @@ -576,7 +576,7 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl return citationLinks.join(", ") } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, options?: CompletePromptOptions): Promise { const { id: model, info } = this.getModel() try { @@ -584,14 +584,25 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl const temperatureConfig: number | undefined = supportsTemperature ? (this.options.modelTemperature ?? info.defaultTemperature ?? 1) : info.defaultTemperature + const httpOpts: Record = {} + if (options?.timeoutMs !== undefined) { + httpOpts.timeout = options.timeoutMs + } + if (this.options.googleGeminiBaseUrl) { + httpOpts.baseUrl = this.options.googleGeminiBaseUrl + } const promptConfig: GenerateContentConfig = { - httpOptions: this.options.googleGeminiBaseUrl - ? { baseUrl: this.options.googleGeminiBaseUrl } - : undefined, + httpOptions: Object.keys(httpOpts).length > 0 ? httpOpts : undefined, temperature: temperatureConfig, } + // Pass abortSignal directly to config.abortSignal (not httpOptions.signal) + // as @google/genai expects request cancellation on this property + if (options?.abortSignal) { + promptConfig.abortSignal = options.abortSignal + } + const request = { model, contents: [{ role: "user", parts: [{ text: prompt }] }], diff --git a/src/api/providers/lite-llm.ts b/src/api/providers/lite-llm.ts index 3f9b3732e5..a0885fe4a0 100644 --- a/src/api/providers/lite-llm.ts +++ b/src/api/providers/lite-llm.ts @@ -12,7 +12,7 @@ import { convertToOpenAiMessages } from "../transform/openai-format" import { GEMINI_THOUGHT_SIGNATURE_BYPASS } from "../transform/gemini-format" import { sanitizeOpenAiCallId } from "../../utils/tool-id" -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata, CompletePromptOptions } from "../index" import { RouterProvider } from "./router-provider" import { extractReasoningFromDelta } from "./utils/extract-reasoning" @@ -311,7 +311,7 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa } } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, options?: CompletePromptOptions): Promise { const { id: modelId, info } = await this.fetchModel() // Check if this is a GPT-5 model that requires max_completion_tokens instead of max_tokens @@ -334,7 +334,16 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa requestOptions.max_tokens = info.maxTokens } - const response = await this.client.chat.completions.create(requestOptions) + // Build request options with abortSignal and/or timeout + const createOptions: OpenAI.RequestOptions = {} + if (options?.abortSignal) { + createOptions.signal = options.abortSignal + } + if (options?.timeoutMs !== undefined) { + createOptions.timeout = options.timeoutMs + } + + const response = await this.client.chat.completions.create(requestOptions, createOptions || undefined) return response.choices[0]?.message.content || "" } catch (error) { if (error instanceof Error) { diff --git a/src/api/providers/lm-studio.ts b/src/api/providers/lm-studio.ts index 4bc9497719..e6e6ee55c5 100644 --- a/src/api/providers/lm-studio.ts +++ b/src/api/providers/lm-studio.ts @@ -13,7 +13,7 @@ import { convertToOpenAiMessages } from "../transform/openai-format" import { ApiStream } from "../transform/stream" import { BaseProvider } from "./base-provider" -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata, CompletePromptOptions } from "../index" import { getModelsFromCache } from "./fetchers/modelCache" import { handleOpenAIError } from "./utils/openai-error-handler" @@ -187,7 +187,7 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan } } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, options?: CompletePromptOptions): Promise { try { // Create params object with optional draft model const params: any = { @@ -202,9 +202,18 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan params.draft_model = this.options.lmStudioDraftModelId } + // Build request options with abortSignal and/or timeout + const createOptions: OpenAI.RequestOptions = {} + if (options?.abortSignal) { + createOptions.signal = options.abortSignal + } + if (options?.timeoutMs !== undefined) { + createOptions.timeout = options.timeoutMs + } + let response try { - response = await this.client.chat.completions.create(params) + response = await this.client.chat.completions.create(params, createOptions || undefined) } catch (error) { throw handleOpenAIError(error, this.providerName) } diff --git a/src/api/providers/minimax.ts b/src/api/providers/minimax.ts index 93aa7ea8f1..fc366594fc 100644 --- a/src/api/providers/minimax.ts +++ b/src/api/providers/minimax.ts @@ -12,7 +12,7 @@ import { getModelParams } from "../transform/model-params" import { mergeEnvironmentDetailsForMiniMax } from "../transform/minimax-format" import { BaseProvider } from "./base-provider" -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata, CompletePromptOptions } from "../index" import { calculateApiCostAnthropic } from "../../shared/cost" import { convertOpenAIToolsToAnthropic } from "../../core/prompts/tools/native-tools/converters" @@ -289,16 +289,28 @@ export class MiniMaxHandler extends BaseProvider implements SingleCompletionHand } } - async completePrompt(prompt: string) { + async completePrompt(prompt: string, options?: CompletePromptOptions) { const { id: model, temperature } = this.getModel() - const message = await this.client.messages.create({ - model, - max_tokens: 16_384, - temperature: temperature ?? 1.0, - messages: [{ role: "user", content: prompt }], - stream: false, - }) + // Build request options with abortSignal and/or timeout handling + const requestOptions: Anthropic.RequestOptions = {} + if (options?.abortSignal) { + requestOptions.signal = options.abortSignal + } + if (options?.timeoutMs !== undefined) { + requestOptions.timeout = options.timeoutMs + } + + const message = await this.client.messages.create( + { + model, + max_tokens: 16_384, + temperature: temperature ?? 1.0, + messages: [{ role: "user", content: prompt }], + stream: false, + }, + Object.keys(requestOptions).length > 0 ? requestOptions : undefined, + ) const content = message.content.find(({ type }) => type === "text") return content?.type === "text" ? content.text : "" diff --git a/src/api/providers/mistral.ts b/src/api/providers/mistral.ts index e0e19298f4..700da6de75 100644 --- a/src/api/providers/mistral.ts +++ b/src/api/providers/mistral.ts @@ -18,7 +18,7 @@ import { ApiStream } from "../transform/stream" import { handleProviderError } from "./utils/error-handler" import { BaseProvider } from "./base-provider" -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata, CompletePromptOptions } from "../index" // Type helper to handle thinking chunks from Mistral API // The SDK includes ThinkChunk but TypeScript has trouble with the discriminated union @@ -193,15 +193,27 @@ export class MistralHandler extends BaseProvider implements SingleCompletionHand return { id, info, maxTokens, temperature } } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, options?: CompletePromptOptions): Promise { const { id: model, temperature } = this.getModel() try { - const response = await this.client.chat.complete({ - model, - messages: [{ role: "user", content: prompt }], - temperature, - }) + // Build Mistral SDK RequestOptions + const requestOptions: Parameters[1] = {} + if (options?.abortSignal) { + requestOptions.fetchOptions = { signal: options.abortSignal } + } + if (options?.timeoutMs !== undefined) { + requestOptions.timeoutMs = options.timeoutMs + } + + const response = await this.client.chat.complete( + { + model, + messages: [{ role: "user", content: prompt }], + temperature, + }, + Object.keys(requestOptions).length > 0 ? requestOptions : undefined, + ) const content = response.choices?.[0]?.message.content diff --git a/src/api/providers/native-ollama.ts b/src/api/providers/native-ollama.ts index 4574c184f2..58faf295d2 100644 --- a/src/api/providers/native-ollama.ts +++ b/src/api/providers/native-ollama.ts @@ -7,7 +7,7 @@ import { BaseProvider } from "./base-provider" import type { ApiHandlerOptions } from "../../shared/api" import { getOllamaModels } from "./fetchers/ollama" import { TagMatcher } from "../../utils/tag-matcher" -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata, CompletePromptOptions } from "../index" interface OllamaChatOptions { temperature: number @@ -347,7 +347,8 @@ export class NativeOllamaHandler extends BaseProvider implements SingleCompletio } } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, _options?: CompletePromptOptions): Promise { + // Ollama native client doesn't support abort signals at all — accept param but ignore try { const client = this.ensureClient() const { id: modelId } = await this.fetchModel() diff --git a/src/api/providers/openai-codex.ts b/src/api/providers/openai-codex.ts index 2b2a599c0f..784dc17c45 100644 --- a/src/api/providers/openai-codex.ts +++ b/src/api/providers/openai-codex.ts @@ -21,7 +21,7 @@ import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" import { getModelParams } from "../transform/model-params" import { BaseProvider } from "./base-provider" -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata, CompletePromptOptions } from "../index" import { isMcpTool } from "../../utils/mcp-name" import { sanitizeOpenAiCallId } from "../../utils/tool-id" import { openAiCodexOAuthManager } from "../../integrations/openai-codex/oauth" @@ -1154,8 +1154,39 @@ export class OpenAiCodexHandler extends BaseProvider implements SingleCompletion return this.lastResponseId } - async completePrompt(prompt: string): Promise { - this.abortController = new AbortController() + async completePrompt(prompt: string, options?: CompletePromptOptions): Promise { + // Build a request-local abort controller with timeout support (don't mutate this.abortController) + let localAbortController: AbortController | undefined + let timeoutId: ReturnType | undefined + + if (options?.timeoutMs !== undefined || options?.abortSignal) { + localAbortController = new AbortController() + + // Handle timeout first + if (options.timeoutMs !== undefined) { + if (options.timeoutMs > 0) { + timeoutId = setTimeout(() => localAbortController?.abort(), options.timeoutMs) + } else { + // timeoutMs is 0 or negative, abort immediately + localAbortController.abort() + } + } + + // Merge with incoming abortSignal if provided using AbortSignal.any + if (options.abortSignal) { + const mergedSignal = AbortSignal.any([localAbortController.signal, options.abortSignal]) + mergedSignal.addEventListener( + "abort", + () => { + localAbortController?.abort() + clearTimeout(timeoutId) + }, + { once: true }, + ) + } + } + + const requestSignal = localAbortController?.signal ?? new AbortController().signal try { const model = this.getModel() @@ -1216,7 +1247,7 @@ export class OpenAiCodexHandler extends BaseProvider implements SingleCompletion method: "POST", headers, body: JSON.stringify(requestBody), - signal: this.abortController.signal, + signal: requestSignal, }) if (!response.ok) { diff --git a/src/api/providers/openai-compatible.ts b/src/api/providers/openai-compatible.ts index d129e72452..b221e211f2 100644 --- a/src/api/providers/openai-compatible.ts +++ b/src/api/providers/openai-compatible.ts @@ -17,7 +17,7 @@ import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" import { DEFAULT_HEADERS } from "./constants" import { BaseProvider } from "./base-provider" -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata, CompletePromptOptions } from "../index" /** * Configuration options for creating an OpenAI-compatible provider. @@ -197,16 +197,58 @@ export abstract class OpenAICompatibleHandler extends BaseProvider implements Si /** * Complete a prompt using the AI SDK generateText. */ - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, options?: CompletePromptOptions): Promise { const languageModel = this.getLanguageModel() - const { text } = await generateText({ + const generateOptions: Parameters[0] & { abortSignal?: AbortSignal } = { model: languageModel, prompt, maxOutputTokens: this.getMaxOutputTokens(), temperature: this.config.temperature ?? 0, - }) + } - return text + // Merge abortSignal and timeoutMs into a single abortSignal + let timeoutId: ReturnType | undefined + if (options?.abortSignal && options?.timeoutMs !== undefined) { + // When both are provided, create a merged signal that aborts when either fires + const controller = new AbortController() + if (options.abortSignal.aborted) { + controller.abort() + } else if (options.timeoutMs > 0) { + timeoutId = setTimeout(() => controller.abort(), options.timeoutMs) + options.abortSignal.addEventListener( + "abort", + () => { + clearTimeout(timeoutId) + controller.abort() + }, + { once: true }, + ) + } else { + // timeoutMs is 0 or negative, abort immediately + controller.abort() + } + + generateOptions.abortSignal = controller.signal + } else if (options?.abortSignal) { + generateOptions.abortSignal = options.abortSignal + } else if (options?.timeoutMs !== undefined) { + if (options.timeoutMs > 0) { + generateOptions.abortSignal = AbortSignal.timeout(options.timeoutMs) + } else if (options.timeoutMs === 0) { + const controller = new AbortController() + controller.abort() + generateOptions.abortSignal = controller.signal + } + } + + try { + const { text } = await generateText(generateOptions) + return text + } finally { + if (timeoutId !== undefined) { + clearTimeout(timeoutId) + } + } } } diff --git a/src/api/providers/openai-native.ts b/src/api/providers/openai-native.ts index 8f86f46751..bcd106ce90 100644 --- a/src/api/providers/openai-native.ts +++ b/src/api/providers/openai-native.ts @@ -26,7 +26,7 @@ import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" import { getModelParams } from "../transform/model-params" import { BaseProvider } from "./base-provider" -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata, CompletePromptOptions } from "../index" import { isMcpTool } from "../../utils/mcp-name" import { sanitizeOpenAiCallId } from "../../utils/tool-id" @@ -1485,9 +1485,19 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio return this.lastResponseId } - async completePrompt(prompt: string): Promise { - // Create AbortController for cancellation + async completePrompt(prompt: string, options?: CompletePromptOptions): Promise { + // Merge incoming abortSignal with existing class-level controller using AbortSignal.any + const baseSignal = this.abortController?.signal ?? new AbortController().signal + const mergedSignal = options?.abortSignal ? AbortSignal.any([baseSignal, options.abortSignal]) : baseSignal + + // Create AbortController for cancellation (keep for cleanup tracking) this.abortController = new AbortController() + // Link the merged signal to our abort controller + if (mergedSignal.aborted) { + this.abortController.abort() + } else { + mergedSignal.addEventListener("abort", () => this.abortController?.abort(), { once: true }) + } try { const model = this.getModel() @@ -1549,7 +1559,7 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio // Make the non-streaming request const response = await (this.client as any).responses.create(requestBody, { - signal: this.abortController.signal, + signal: mergedSignal, }) // Extract text from the response diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts index c8f17dac3e..217f02134b 100644 --- a/src/api/providers/openai.ts +++ b/src/api/providers/openai.ts @@ -296,7 +296,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl return { id, info, ...params } } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, options?: import("../../api").CompletePromptOptions): Promise { try { const isAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl) const model = this.getModel() @@ -310,11 +310,20 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl // Add max_tokens if needed this.addMaxTokensIfNeeded(requestOptions, modelInfo) + // Build request options with abortSignal and/or timeout + const createOptions: OpenAI.RequestOptions = {} + if (options?.abortSignal) { + createOptions.signal = options.abortSignal + } + if (options?.timeoutMs !== undefined) { + createOptions.timeout = options.timeoutMs + } + let response try { response = await this.client.chat.completions.create( requestOptions, - isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}, + isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH, ...createOptions } : createOptions, ) } catch (error) { throw handleOpenAIError(error, this.providerName) diff --git a/src/api/providers/opencode-go.ts b/src/api/providers/opencode-go.ts index 27d8ab3f7e..33a84cb124 100644 --- a/src/api/providers/opencode-go.ts +++ b/src/api/providers/opencode-go.ts @@ -18,7 +18,7 @@ import { convertToR1Format } from "../transform/r1-format" import { filterNonAnthropicBlocks } from "../transform/anthropic-filter" import { getModelParams } from "../transform/model-params" -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata, CompletePromptOptions } from "../index" import { RouterProvider } from "./router-provider" import { extractReasoningFromDelta } from "./utils/extract-reasoning" import { DEFAULT_HEADERS } from "./constants" @@ -485,25 +485,37 @@ export class OpencodeGoHandler extends RouterProvider implements SingleCompletio * @returns The model's reply text, or an empty string if no content is returned. * @throws Error with an Opencode Go-specific prefix if the request fails. */ - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, options?: CompletePromptOptions): Promise { const { id: modelId, format, temperature, reasoningEffort, maxTokens } = await this.resolveModel() if (format === "anthropic") { try { - const message = await this.anthropicClient.messages.create({ - model: modelId, - // Honour the same includeMaxTokens/modelMaxTokens override - // logic as the streaming path so non-streaming completions - // respect the user's max-output slider instead of always - // falling back to the model default. - max_tokens: - this.options.includeMaxTokens === true - ? this.options.modelMaxTokens || maxTokens || 16_384 - : (maxTokens ?? 16_384), - temperature: this.supportsTemperature(modelId) ? (temperature ?? 1.0) : undefined, - messages: [{ role: "user", content: prompt }], - stream: false, - }) + // Build request options with abortSignal and/or timeout handling + const requestOptions: Anthropic.RequestOptions = {} + if (options?.abortSignal) { + requestOptions.signal = options.abortSignal + } + if (options?.timeoutMs !== undefined) { + requestOptions.timeout = options.timeoutMs + } + + const message = await this.anthropicClient.messages.create( + { + model: modelId, + // Honour the same includeMaxTokens/modelMaxTokens override + // logic as the streaming path so non-streaming completions + // respect the user's max-output slider instead of always + // falling back to the model default. + max_tokens: + this.options.includeMaxTokens === true + ? this.options.modelMaxTokens || maxTokens || 16_384 + : (maxTokens ?? 16_384), + temperature: this.supportsTemperature(modelId) ? (temperature ?? 1.0) : undefined, + messages: [{ role: "user", content: prompt }], + stream: false, + }, + Object.keys(requestOptions).length > 0 ? requestOptions : undefined, + ) const content = message.content.find(({ type }) => type === "text") return content?.type === "text" ? content.text : "" @@ -534,7 +546,16 @@ export class OpencodeGoHandler extends RouterProvider implements SingleCompletio reasoningEffort as OpenAI.Chat.ChatCompletionCreateParams["reasoning_effort"] } - const response = await this.client.chat.completions.create(requestOptions) + // Build request options with abortSignal and/or timeout for OpenAI path + const createOptions: OpenAI.RequestOptions = {} + if (options?.abortSignal) { + createOptions.signal = options.abortSignal + } + if (options?.timeoutMs !== undefined) { + createOptions.timeout = options.timeoutMs + } + + const response = await this.client.chat.completions.create(requestOptions, createOptions || undefined) return response.choices[0]?.message.content || "" } catch (error) { if (error instanceof Error) { diff --git a/src/api/providers/openrouter.ts b/src/api/providers/openrouter.ts index 1ac9c465b6..510559e8c7 100644 --- a/src/api/providers/openrouter.ts +++ b/src/api/providers/openrouter.ts @@ -35,7 +35,7 @@ import { getModelEndpoints } from "./fetchers/modelEndpointCache" import { DEFAULT_HEADERS } from "./constants" import { BaseProvider } from "./base-provider" -import type { ApiHandlerCreateMessageMetadata, SingleCompletionHandler } from "../index" +import type { ApiHandlerCreateMessageMetadata, CompletePromptOptions, SingleCompletionHandler } from "../index" import { handleOpenAIError } from "./utils/openai-error-handler" import { generateImageWithProvider, ImageGenerationResult } from "./utils/image-generation" import { applyRouterToolPreferences } from "./utils/router-tool-preferences" @@ -574,7 +574,7 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH return { id, info, topP: isDeepSeekR1 ? 0.95 : undefined, ...params } } - async completePrompt(prompt: string) { + async completePrompt(prompt: string, options?: CompletePromptOptions) { const { id: modelId, maxTokens, temperature, reasoning } = await this.fetchModel() const completionParams: OpenRouterChatCompletionParams = { @@ -596,9 +596,14 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH } // Add Anthropic beta header for fine-grained tool streaming when using Anthropic models - const requestOptions = modelId.startsWith("anthropic/") - ? { headers: { "x-anthropic-beta": "fine-grained-tool-streaming-2025-05-14" } } - : undefined + // Merge signal + timeout with existing headers + const requestOptions: OpenAI.RequestOptions = { + ...(modelId.startsWith("anthropic/") + ? { headers: { "x-anthropic-beta": "fine-grained-tool-streaming-2025-05-14" } } + : undefined), + ...(options?.abortSignal && { signal: options.abortSignal }), + ...(options?.timeoutMs !== undefined && { timeout: options.timeoutMs }), + } let response diff --git a/src/api/providers/poe.ts b/src/api/providers/poe.ts index 9513a1445d..ca8ccdbe67 100644 --- a/src/api/providers/poe.ts +++ b/src/api/providers/poe.ts @@ -18,7 +18,7 @@ import { convertToAiSdkMessages, convertToolsForAiSdk, processAiSdkStreamPart } import { ApiStream } from "../transform/stream" import { BaseProvider } from "./base-provider" -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata, CompletePromptOptions } from "../index" import { getModelsFromCache } from "./fetchers/modelCache" const DEFAULT_THINKING_BUDGET = 8192 @@ -138,18 +138,41 @@ export class PoeHandler extends BaseProvider implements SingleCompletionHandler } } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, options?: CompletePromptOptions): Promise { const { id } = this.getModel() + let timeoutId: ReturnType | undefined try { - const { text } = await generateText({ + const generateOptions: Parameters[0] & { abortSignal?: AbortSignal } = { model: this.poe(id), prompt, - }) + } + + // Merge abortSignal and timeoutMs into a single abortSignal + if (options?.abortSignal && options?.timeoutMs && options.timeoutMs > 0) { + const controller = new AbortController() + if (options.abortSignal.aborted) { + controller.abort() + } else { + timeoutId = setTimeout(() => controller.abort(), options.timeoutMs) + options.abortSignal.addEventListener("abort", () => controller.abort(), { once: true }) + } + generateOptions.abortSignal = controller.signal + } else if (options?.abortSignal) { + generateOptions.abortSignal = options.abortSignal + } else if (options?.timeoutMs && options.timeoutMs > 0) { + generateOptions.abortSignal = AbortSignal.timeout(options.timeoutMs) + } + + const { text } = await generateText(generateOptions) return text } catch (error) { const errorMessage = error instanceof Error ? error.message : String(error) TelemetryService.instance.captureException(new ApiProviderError(errorMessage, "poe", id, "completePrompt")) throw new Error(`Poe completion error: ${errorMessage}`) + } finally { + if (timeoutId) { + clearTimeout(timeoutId) + } } } } diff --git a/src/api/providers/qwen-code.ts b/src/api/providers/qwen-code.ts index cdf0f88e4b..f054a3266d 100644 --- a/src/api/providers/qwen-code.ts +++ b/src/api/providers/qwen-code.ts @@ -15,7 +15,7 @@ import { ApiStream } from "../transform/stream" import { BaseProvider } from "./base-provider" import { extractReasoningFromDelta } from "./utils/extract-reasoning" -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata, CompletePromptOptions } from "../index" const QWEN_OAUTH_BASE_URL = "https://chat.qwen.ai" const QWEN_OAUTH_TOKEN_ENDPOINT = `${QWEN_OAUTH_BASE_URL}/api/v1/oauth2/token` @@ -327,7 +327,7 @@ export class QwenCodeHandler extends BaseProvider implements SingleCompletionHan return { id, info } } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, options?: CompletePromptOptions): Promise { await this.ensureAuthenticated() const client = this.ensureClient() const model = this.getModel() @@ -338,7 +338,19 @@ export class QwenCodeHandler extends BaseProvider implements SingleCompletionHan max_completion_tokens: model.info.maxTokens, } - const response = await this.callApiWithRetry(() => client.chat.completions.create(requestOptions)) + const fetchOptions: Record = {} + + if (options?.abortSignal) { + fetchOptions.signal = options.abortSignal + } + + if (options?.timeoutMs !== undefined) { + fetchOptions.timeout = options.timeoutMs + } + + const response = await this.callApiWithRetry(() => + client.chat.completions.create(requestOptions, fetchOptions as any), + ) return response.choices[0]?.message.content || "" } diff --git a/src/api/providers/requesty.ts b/src/api/providers/requesty.ts index c490227d44..c196942aff 100644 --- a/src/api/providers/requesty.ts +++ b/src/api/providers/requesty.ts @@ -14,7 +14,7 @@ import { AnthropicProviderReasoningParams, getAnthropicProviderReasoning } from import { DEFAULT_HEADERS } from "./constants" import { getModels } from "./fetchers/modelCache" import { BaseProvider } from "./base-provider" -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata, CompletePromptOptions } from "../index" import { toRequestyServiceUrl } from "../../shared/utils/requesty" import { handleOpenAIError } from "./utils/openai-error-handler" import { applyRouterToolPreferences } from "./utils/router-tool-preferences" @@ -204,7 +204,7 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan } } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, options?: CompletePromptOptions): Promise { const { id: model, maxTokens: max_tokens, temperature } = await this.fetchModel() const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [{ role: "system", content: prompt }] @@ -216,9 +216,17 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan temperature: temperature, } + const createOptions: OpenAI.RequestOptions = {} + if (options?.abortSignal) { + createOptions.signal = options.abortSignal + } + if (typeof options?.timeoutMs === "number" && options.timeoutMs > 0) { + createOptions.timeout = options.timeoutMs + } + let response: OpenAI.Chat.ChatCompletion try { - response = await this.client.chat.completions.create(completionParams) + response = await this.client.chat.completions.create(completionParams, createOptions || undefined) } catch (error) { throw handleOpenAIError(error, this.providerName) } diff --git a/src/api/providers/unbound.ts b/src/api/providers/unbound.ts index 0ec7a2466f..e76cac6639 100644 --- a/src/api/providers/unbound.ts +++ b/src/api/providers/unbound.ts @@ -14,7 +14,7 @@ import { OpenAiReasoningParams } from "../transform/reasoning" import { DEFAULT_HEADERS } from "./constants" import { getModels } from "./fetchers/modelCache" import { BaseProvider } from "./base-provider" -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata, CompletePromptOptions } from "../index" import { handleOpenAIError } from "./utils/openai-error-handler" import { applyRouterToolPreferences } from "./utils/router-tool-preferences" import { extractReasoningFromDelta } from "./utils/extract-reasoning" @@ -192,7 +192,7 @@ export class UnboundHandler extends BaseProvider implements SingleCompletionHand } } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, options?: CompletePromptOptions): Promise { const { id: model, maxTokens: max_tokens, temperature } = await this.fetchModel() const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [{ role: "system", content: prompt }] @@ -203,10 +203,18 @@ export class UnboundHandler extends BaseProvider implements SingleCompletionHand messages: openAiMessages, temperature: temperature, } + // Build request options with abortSignal and/or timeout + const createOptions: OpenAI.RequestOptions = {} + if (options?.abortSignal) { + createOptions.signal = options.abortSignal + } + if (options?.timeoutMs !== undefined) { + createOptions.timeout = options.timeoutMs + } let response: OpenAI.Chat.ChatCompletion try { - response = await this.client.chat.completions.create(completionParams) + response = await this.client.chat.completions.create(completionParams, createOptions || undefined) } catch (error) { throw handleOpenAIError(error, this.providerName) } diff --git a/src/api/providers/vercel-ai-gateway.ts b/src/api/providers/vercel-ai-gateway.ts index 0c7bd1d485..7ede41edf8 100644 --- a/src/api/providers/vercel-ai-gateway.ts +++ b/src/api/providers/vercel-ai-gateway.ts @@ -14,7 +14,7 @@ import { ApiStream } from "../transform/stream" import { convertToOpenAiMessages } from "../transform/openai-format" import { addCacheBreakpoints } from "../transform/caching/vercel-ai-gateway" -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata, CompletePromptOptions } from "../index" import { RouterProvider } from "./router-provider" // Extend OpenAI's CompletionUsage to include Vercel AI Gateway specific fields @@ -117,7 +117,7 @@ export class VercelAiGatewayHandler extends RouterProvider implements SingleComp } } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, options?: CompletePromptOptions): Promise { const { id: modelId, info } = await this.fetchModel() try { @@ -132,8 +132,19 @@ export class VercelAiGatewayHandler extends RouterProvider implements SingleComp } requestOptions.max_completion_tokens = info.maxTokens + // Build request options with abortSignal and/or timeout + const createOptions: OpenAI.RequestOptions = {} + if (options?.abortSignal) { + createOptions.signal = options.abortSignal + } + if (options?.timeoutMs !== undefined) { + createOptions.timeout = options.timeoutMs + } - const response = await this.client.chat.completions.create(requestOptions) + const response = await this.client.chat.completions.create( + requestOptions, + Object.keys(createOptions).length > 0 ? createOptions : undefined, + ) return response.choices[0]?.message.content || "" } catch (error) { if (error instanceof Error) { diff --git a/src/api/providers/vscode-lm.ts b/src/api/providers/vscode-lm.ts index 5b9064363f..708dd5de18 100644 --- a/src/api/providers/vscode-lm.ts +++ b/src/api/providers/vscode-lm.ts @@ -12,7 +12,7 @@ import { ApiStream } from "../transform/stream" import { convertToVsCodeLmMessages, extractTextCountFromMessage } from "../transform/vscode-lm-format" import { BaseProvider } from "./base-provider" -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata, CompletePromptOptions } from "../index" /** * Converts OpenAI-format tools to VSCode Language Model tools. @@ -458,7 +458,6 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan } } catch (error) { console.error("Roo Code : Failed to process tool call:", error) - // Continue processing other chunks even if one fails continue } } else { @@ -489,19 +488,21 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan name: error.name, }) - // Return original error if it's already an Error instance throw error } else if (typeof error === "object" && error !== null) { - // Handle error-like objects const errorDetails = JSON.stringify(error, null, 2) console.error("Roo Code : Stream error object:", errorDetails) throw new Error(`Roo Code : Response stream error: ${errorDetails}`) } else { - // Fallback for unknown error types const errorMessage = String(error) console.error("Roo Code : Unknown stream error:", errorMessage) throw new Error(`Roo Code : Response stream error: ${errorMessage}`) } + } finally { + if (this.currentRequestCancellation) { + this.currentRequestCancellation.dispose() + this.currentRequestCancellation = null + } } } @@ -582,13 +583,31 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan return this.getModel().info.contextWindow } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, options?: CompletePromptOptions): Promise { + const client = await this.getClient() + + // Bridge external AbortSignal to VSCode CancellationToken + const tokenSource = new vscode.CancellationTokenSource() + + // Handle timeoutMs by creating a timeout-based cancellation + let timeoutTimeout: ReturnType | undefined + if (options?.timeoutMs && options.timeoutMs > 0) { + timeoutTimeout = setTimeout(() => tokenSource.cancel(), options.timeoutMs) + } + + if (options?.abortSignal) { + if (options.abortSignal.aborted) { + tokenSource.cancel() + } else { + options.abortSignal.addEventListener("abort", () => tokenSource.cancel(), { once: true }) + } + } + try { - const client = await this.getClient() const response = await client.sendRequest( [vscode.LanguageModelChatMessage.User(prompt)], {}, - new vscode.CancellationTokenSource().token, + tokenSource.token, ) let result = "" for await (const chunk of response.stream) { @@ -597,11 +616,16 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan } } return result - } catch (error) { + } catch (error: any) { if (error instanceof Error) { throw new Error(`VSCode LM completion error: ${error.message}`) } throw error + } finally { + if (timeoutTimeout) { + clearTimeout(timeoutTimeout) + } + tokenSource.dispose() } } } diff --git a/src/api/providers/xai.ts b/src/api/providers/xai.ts index e5c0ba0a81..35519d96e2 100644 --- a/src/api/providers/xai.ts +++ b/src/api/providers/xai.ts @@ -13,7 +13,7 @@ import { getModelParams } from "../transform/model-params" import { DEFAULT_HEADERS } from "./constants" import { BaseProvider } from "./base-provider" -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata, CompletePromptOptions } from "../index" import { handleOpenAIError } from "./utils/openai-error-handler" import { isMcpTool } from "../../utils/mcp-name" @@ -142,15 +142,27 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler yield* processResponsesApiStream(stream, normalizeUsage) } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, options?: CompletePromptOptions): Promise { const model = this.getModel() try { - const response = await this.client.responses.create({ - model: model.id, - input: [{ role: "user", content: [{ type: "input_text", text: prompt }] }], - store: false, - }) + // Build request options with abortSignal and/or timeout handling + const requestOptions: OpenAI.RequestOptions = {} + if (options?.abortSignal) { + requestOptions.signal = options.abortSignal + } + if (options?.timeoutMs !== undefined) { + requestOptions.timeout = options.timeoutMs + } + + const response = await this.client.responses.create( + { + model: model.id, + input: [{ role: "user", content: [{ type: "input_text", text: prompt }] }], + store: false, + }, + Object.keys(requestOptions).length > 0 ? requestOptions : undefined, + ) // output_text is a convenience field on the Responses API response return response.output_text || "" diff --git a/src/api/providers/zoo-gateway.ts b/src/api/providers/zoo-gateway.ts index 4724464ff3..7326d0e833 100644 --- a/src/api/providers/zoo-gateway.ts +++ b/src/api/providers/zoo-gateway.ts @@ -18,7 +18,7 @@ import { ApiStream } from "../transform/stream" import { convertToOpenAiMessages } from "../transform/openai-format" import { addCacheBreakpoints } from "../transform/caching/vercel-ai-gateway" -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata, CompletePromptOptions } from "../index" import { RouterProvider } from "./router-provider" function getApiErrorStatus(error: unknown): number | undefined { @@ -276,7 +276,7 @@ export class ZooGatewayHandler extends RouterProvider implements SingleCompletio } } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, options?: CompletePromptOptions): Promise { this.ensureAuthenticated() const { id: modelId, info } = await this.fetchModel() @@ -294,7 +294,16 @@ export class ZooGatewayHandler extends RouterProvider implements SingleCompletio requestOptions.max_completion_tokens = info.maxTokens - const response = await this.client.chat.completions.create(requestOptions) + // Build request options with abortSignal and/or timeout + const createOptions: OpenAI.RequestOptions = {} + if (options?.abortSignal) { + createOptions.signal = options.abortSignal + } + if ((options as any)?.timeoutMs !== undefined && (options as any).timeoutMs > 0) { + createOptions.timeout = (options as any).timeoutMs + } + + const response = await this.client.chat.completions.create(requestOptions, createOptions || undefined) return response.choices[0]?.message.content || "" } catch (error) { try { diff --git a/src/core/task/__tests__/Task.spec.ts b/src/core/task/__tests__/Task.spec.ts index 27ba5ce8ff..fbc01cb78f 100644 --- a/src/core/task/__tests__/Task.spec.ts +++ b/src/core/task/__tests__/Task.spec.ts @@ -1812,7 +1812,7 @@ describe("Cline", () => { expect(summarizeConversation).toHaveBeenCalled() const [options] = vi.mocked(summarizeConversation).mock.calls.at(-1)! - expect(options.metadata?.abortSignal).toBeInstanceOf(AbortSignal) + expect(options.metadata?.abortSignal).toBe(task.currentRequestAbortController!.signal) }) it("should omit abortSignal from condenseContext metadata when no current request exists", async () => { @@ -1899,7 +1899,7 @@ describe("Cline", () => { const [, , metadata] = createMessageSpy.mock.calls[0]! expect(metadata).toBeDefined() - expect(metadata!.abortSignal).toBeInstanceOf(AbortSignal) + expect(metadata!.abortSignal).toBe(task.currentRequestAbortController!.signal) }) it("should invoke abort on currentRequestAbortController during first-chunk wait", async () => { @@ -2136,7 +2136,7 @@ describe("Cline", () => { const [, , metadata] = createMessageSpy.mock.calls[0]! expect(metadata).toBeDefined() expect("abortSignal" in metadata!).toBe(true) - expect(metadata!.abortSignal).toBeInstanceOf(AbortSignal) + expect(metadata!.abortSignal).toBe(task.currentRequestAbortController!.signal) }) it("should keep createMessage abortSignal metadata unaborted before cancellation", async () => { @@ -2197,11 +2197,100 @@ describe("Cline", () => { await iterator.next() const [, , metadata] = createMessageSpy.mock.calls[0]! - expect(metadata?.abortSignal).toBeInstanceOf(AbortSignal) + expect(metadata?.abortSignal).toBe(task.currentRequestAbortController!.signal) expect(metadata?.abortSignal?.aborted).toBe(false) }) }) + it("should create a fresh AbortController for each sequential request", async () => { + const task = new Task({ + provider: mockProvider, + apiConfiguration: mockApiConfig, + task: "test task", + startTask: false, + }) + + vi.spyOn(task as any, "getSystemPrompt").mockResolvedValue("mock system prompt") + vi.spyOn(task.api, "getModel").mockReturnValue({ + id: mockApiConfig.apiModelId!, + info: { + supportsImages: false, + supportsPromptCache: true, + contextWindow: 200000, + maxTokens: 4096, + inputPrice: 0.3, + outputPrice: 1.5, + } as ModelInfo, + }) + + const providerState = await mockProvider.getState() + vi.spyOn(mockProvider, "getState").mockResolvedValue({ + ...providerState, + apiConfiguration: mockApiConfig, + autoApprovalEnabled: true, + requestDelaySeconds: 0, + }) + + let callCount = 0 + const mockStreamFactory = () => { + return { + async *[Symbol.asyncIterator]() { + yield { type: "text", text: `response ${callCount}` } + }, + async next() { + callCount++ + return { done: true, value: { type: "text", text: `response ${callCount - 1}` } } + }, + async return() { + return { done: true, value: undefined } + }, + async throw(e: any) { + throw e + }, + [Symbol.asyncDispose]: async () => {}, + } as AsyncGenerator + } + + const createMessageSpy = vi + .spyOn(task.api, "createMessage") + .mockImplementation(() => mockStreamFactory()) + + task.apiConversationHistory = [ + { + role: "user" as const, + content: [{ type: "text" as const, text: "test message" }], + ts: Date.now(), + }, + ] as any + + // First request + const iterator1 = task.attemptApiRequest(0) + await iterator1.next() + + expect(createMessageSpy).toHaveBeenCalledTimes(1) + const [, , metadata1] = createMessageSpy.mock.calls[0]! + const signal1 = metadata1!.abortSignal + expect(signal1).toBeDefined() + expect(signal1!.aborted).toBe(false) + + // Simulate request completion and cancellation to clear the controller + task.cancelCurrentRequest() + + // Second request should create a fresh AbortController with a new signal + callCount = 0 + const iterator2 = task.attemptApiRequest(0) + await iterator2.next() + + expect(createMessageSpy).toHaveBeenCalledTimes(2) + const [, , metadata2] = createMessageSpy.mock.calls[1]! + const signal2 = metadata2!.abortSignal + + // Signals should be different instances (fresh controller per request) + expect(signal2).not.toBe(signal1) + expect(signal2).toBe(task.currentRequestAbortController!.signal) + expect(signal2!.aborted).toBe(false) + }) + it("should propagate AbortController signal through attemptApiRequest context-window retry path", async () => { const task = new Task({ provider: mockProvider, diff --git a/src/utils/__tests__/enhance-prompt.spec.ts b/src/utils/__tests__/enhance-prompt.spec.ts index 2546878d8c..7e8c702984 100644 --- a/src/utils/__tests__/enhance-prompt.spec.ts +++ b/src/utils/__tests__/enhance-prompt.spec.ts @@ -42,7 +42,7 @@ describe("enhancePrompt", () => { expect(result).toBe("Enhanced prompt") const handler = buildApiHandler(mockApiConfig) - expect((handler as any).completePrompt).toHaveBeenCalledWith(`Test prompt`) + expect((handler as any).completePrompt).toHaveBeenCalledWith(`Test prompt`, undefined) }) it("enhances prompt using custom enhancement prompt when provided", async () => { @@ -64,7 +64,7 @@ describe("enhancePrompt", () => { expect(result).toBe("Enhanced prompt") const handler = buildApiHandler(mockApiConfig) - expect((handler as any).completePrompt).toHaveBeenCalledWith(`${customEnhancePrompt}\n\nTest prompt`) + expect((handler as any).completePrompt).toHaveBeenCalledWith(`${customEnhancePrompt}\n\nTest prompt`, undefined) }) it("throws error for empty prompt input", async () => { diff --git a/src/utils/single-completion-handler.ts b/src/utils/single-completion-handler.ts index 4606a17bab..4890b38471 100644 --- a/src/utils/single-completion-handler.ts +++ b/src/utils/single-completion-handler.ts @@ -1,12 +1,16 @@ import type { ProviderSettings } from "@roo-code/types" -import { buildApiHandler, SingleCompletionHandler } from "../api" +import { buildApiHandler, SingleCompletionHandler, type CompletePromptOptions } from "../api" /** * Enhances a prompt using the configured API without creating a full Cline instance or task history. * This is a lightweight alternative that only uses the API's completion functionality. */ -export async function singleCompletionHandler(apiConfiguration: ProviderSettings, promptText: string): Promise { +export async function singleCompletionHandler( + apiConfiguration: ProviderSettings, + promptText: string, + options?: CompletePromptOptions, +): Promise { if (!promptText) { throw new Error("No prompt text provided") } @@ -21,5 +25,5 @@ export async function singleCompletionHandler(apiConfiguration: ProviderSettings throw new Error("The selected API provider does not support prompt enhancement") } - return (handler as SingleCompletionHandler).completePrompt(promptText) + return (handler as SingleCompletionHandler).completePrompt(promptText, options) } From 5b9f7e350a60be2f188761fdcb839d2bc7015761 Mon Sep 17 00:00:00 2001 From: Eason Liang Date: Wed, 1 Jul 2026 14:24:41 +0800 Subject: [PATCH 2/9] test(api): add missing test coverage for abort signal and timeout handling - bedrock.spec.ts: cover pre-aborted signal, timeoutMs=0, and empty response scenarios - openai-codex.spec.ts: add completePrompt tests for abortSignal merging and timeout scenarios - openai-compatible.spec.ts: new test file for BaseOpenAiCompatibleProvider completePrompt method --- src/api/providers/__tests__/bedrock.spec.ts | 96 ++++++ .../providers/__tests__/openai-codex.spec.ts | 313 ++++++++++++++++++ .../__tests__/openai-compatible.spec.ts | 206 ++++++++++++ 3 files changed, 615 insertions(+) create mode 100644 src/api/providers/__tests__/openai-compatible.spec.ts diff --git a/src/api/providers/__tests__/bedrock.spec.ts b/src/api/providers/__tests__/bedrock.spec.ts index 496f49b15d..2627d4f033 100644 --- a/src/api/providers/__tests__/bedrock.spec.ts +++ b/src/api/providers/__tests__/bedrock.spec.ts @@ -1740,6 +1740,102 @@ describe("AwsBedrockHandler", () => { const sendOptions = mockSend.mock.calls[0][1] expect(sendOptions?.abortSignal).toBeDefined() }) + + it("should abort immediately when signal is already aborted and timeoutMs > 0", async () => { + const mockSend = vi.fn() + + const handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + }) + + const clientInstance = (handler as any).client + clientInstance.send = mockSend + + mockSend.mockResolvedValueOnce({ + output: { message: { content: [{ type: "text", text: "response" }] }, stopReason: null }, + }) + + const controller = new AbortController() + controller.abort() // Pre-abort the signal + + await handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 5000 }) + + expect(mockSend).toHaveBeenCalled() + const sendOptions = mockSend.mock.calls[0][1] + expect(sendOptions?.abortSignal).toBeDefined() + expect(sendOptions?.abortSignal.aborted).toBe(true) + }) + it("should return undefined sendOptions when timeoutMs is 0 and no signal", async () => { + const mockSend = vi.fn() + + const handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + }) + + const clientInstance = (handler as any).client + clientInstance.send = mockSend + + mockSend.mockResolvedValueOnce({ + output: { message: { content: [{ type: "text", text: "response" }] }, stopReason: null }, + }) + + await handler.completePrompt("test prompt", { timeoutMs: 0 }) + + expect(mockSend).toHaveBeenCalled() + const sendOptions = mockSend.mock.calls[0][1] + // When timeoutMs is 0 and no abortSignal, bedrock.ts returns undefined (no signal created) + expect(sendOptions).toBeUndefined() + }) + + it("should return empty string when response content is empty", async () => { + const mockSend = vi.fn() + + const handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + }) + + const clientInstance = (handler as any).client + clientInstance.send = mockSend + + mockSend.mockResolvedValueOnce({ + output: { message: { content: [{ type: "text", text: "" }] }, stopReason: null }, + }) + + const result = await handler.completePrompt("test prompt") + + expect(result).toBe("") + }) + + it("should return empty string when response content array is empty", async () => { + const mockSend = vi.fn() + + const handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + }) + + const clientInstance = (handler as any).client + clientInstance.send = mockSend + + mockSend.mockResolvedValueOnce({ + output: { message: { content: [] }, stopReason: null }, + }) + + const result = await handler.completePrompt("test prompt") + + expect(result).toBe("") + }) }) }) }) diff --git a/src/api/providers/__tests__/openai-codex.spec.ts b/src/api/providers/__tests__/openai-codex.spec.ts index 4ab9755252..856e90f731 100644 --- a/src/api/providers/__tests__/openai-codex.spec.ts +++ b/src/api/providers/__tests__/openai-codex.spec.ts @@ -1,5 +1,16 @@ // npx vitest run api/providers/__tests__/openai-codex.spec.ts +// Mock TelemetryService before other imports +const mockCaptureException = vi.fn() + +vi.mock("@roo-code/telemetry", () => ({ + TelemetryService: { + instance: { + captureException: (...args: unknown[]) => mockCaptureException(...args), + }, + }, +})) + import { Anthropic } from "@anthropic-ai/sdk" import { OpenAiCodexHandler } from "../openai-codex" import { openAiCodexOAuthManager } from "../../../integrations/openai-codex/oauth" @@ -145,3 +156,305 @@ describe("OpenAiCodexHandler.createMessage", () => { }) }) }) + +describe("OpenAiCodexHandler.completePrompt", () => { + it("should call fetch with correct request body and return text response", async () => { + const handler = new OpenAiCodexHandler({ apiModelId: "gpt-5.1-codex" }) + + vi.spyOn(openAiCodexOAuthManager, "getAccessToken").mockResolvedValue("test-token") + vi.spyOn(openAiCodexOAuthManager, "getAccountId").mockResolvedValue("acct_test") + + const mockFetch = vi.fn().mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + output: [ + { + type: "message", + role: "assistant", + content: [{ type: "output_text", text: "Hello world" }], + }, + ], + }), + text: () => Promise.resolve(""), + }) + + global.fetch = mockFetch + + const result = await handler.completePrompt("test prompt") + + expect(result).toBe("Hello world") + expect(mockFetch).toHaveBeenCalledWith( + expect.stringContaining("/responses"), + expect.objectContaining({ + method: "POST", + headers: expect.objectContaining({ + Authorization: "Bearer test-token", + originator: "zoo-code", + }), + body: expect.stringContaining('"model":"gpt-5.1-codex"'), + }), + ) + }) + + it("should abort immediately when timeoutMs is 0", async () => { + const handler = new OpenAiCodexHandler({ apiModelId: "gpt-5.1-codex" }) + + vi.spyOn(openAiCodexOAuthManager, "getAccessToken").mockResolvedValue("test-token") + vi.spyOn(openAiCodexOAuthManager, "getAccountId").mockResolvedValue("acct_test") + + const mockFetch = vi.fn().mockImplementation(async (url: string, options: any) => { + return { + ok: true, + json: () => + Promise.resolve({ + output: [ + { + type: "message", + role: "assistant", + content: [{ type: "output_text", text: "response" }], + }, + ], + }), + text: () => Promise.resolve(""), + } + }) + + global.fetch = mockFetch + + await handler.completePrompt("test prompt", { timeoutMs: 0 }) + + expect(mockFetch).toHaveBeenCalled() + const fetchOptions = (mockFetch as any).mock.calls[0][1] + expect(fetchOptions.signal).toBeDefined() + expect(fetchOptions.signal.aborted).toBe(true) + }) + + it("should merge abortSignal with local controller", async () => { + const handler = new OpenAiCodexHandler({ apiModelId: "gpt-5.1-codex" }) + + vi.spyOn(openAiCodexOAuthManager, "getAccessToken").mockResolvedValue("test-token") + vi.spyOn(openAiCodexOAuthManager, "getAccountId").mockResolvedValue("acct_test") + + const mockFetch = vi.fn().mockImplementation(async (url: string, options: any) => { + return { + ok: true, + json: () => + Promise.resolve({ + output: [ + { + type: "message", + role: "assistant", + content: [{ type: "output_text", text: "response" }], + }, + ], + }), + text: () => Promise.resolve(""), + } + }) + + global.fetch = mockFetch + + const controller = new AbortController() + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + + expect(mockFetch).toHaveBeenCalled() + const fetchOptions = (mockFetch as any).mock.calls[0][1] + expect(fetchOptions.signal).toBeDefined() + }) + + it("should merge abortSignal and timeoutMs together", async () => { + const handler = new OpenAiCodexHandler({ apiModelId: "gpt-5.1-codex" }) + + vi.spyOn(openAiCodexOAuthManager, "getAccessToken").mockResolvedValue("test-token") + vi.spyOn(openAiCodexOAuthManager, "getAccountId").mockResolvedValue("acct_test") + + const mockFetch = vi.fn().mockImplementation(async (url: string, options: any) => { + return { + ok: true, + json: () => + Promise.resolve({ + output: [ + { + type: "message", + role: "assistant", + content: [{ type: "output_text", text: "response" }], + }, + ], + }), + text: () => Promise.resolve(""), + } + }) + + global.fetch = mockFetch + + const controller = new AbortController() + await handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 5000 }) + + expect(mockFetch).toHaveBeenCalled() + const fetchOptions = (mockFetch as any).mock.calls[0][1] + expect(fetchOptions.signal).toBeDefined() + }) + + it("should return empty string when no output text found", async () => { + const handler = new OpenAiCodexHandler({ apiModelId: "gpt-5.1-codex" }) + + vi.spyOn(openAiCodexOAuthManager, "getAccessToken").mockResolvedValue("test-token") + vi.spyOn(openAiCodexOAuthManager, "getAccountId").mockResolvedValue("acct_test") + + const mockFetch = vi.fn().mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + output: [{ type: "message", role: "assistant", content: [] }], + }), + text: () => Promise.resolve(""), + }) + + global.fetch = mockFetch + + const result = await handler.completePrompt("test prompt") + + expect(result).toBe("") + }) + + it("should handle responseData.text fallback", async () => { + const handler = new OpenAiCodexHandler({ apiModelId: "gpt-5.1-codex" }) + + vi.spyOn(openAiCodexOAuthManager, "getAccessToken").mockResolvedValue("test-token") + vi.spyOn(openAiCodexOAuthManager, "getAccountId").mockResolvedValue("acct_test") + + const mockFetch = vi.fn().mockResolvedValue({ + ok: true, + json: () => Promise.resolve({ text: "fallback response" }), + text: () => Promise.resolve(""), + }) + + global.fetch = mockFetch + + const result = await handler.completePrompt("test prompt") + + expect(result).toBe("fallback response") + }) + + it("should throw error when not authenticated", async () => { + const handler = new OpenAiCodexHandler({ apiModelId: "gpt-5.1-codex" }) + + vi.spyOn(openAiCodexOAuthManager, "getAccessToken").mockResolvedValue(null as any) + vi.spyOn(openAiCodexOAuthManager, "getAccountId").mockResolvedValue("acct_test") + + const controller = new AbortController() + await expect(handler.completePrompt("test prompt", { abortSignal: controller.signal })).rejects.toThrow() + }) + + it("should throw error when fetch returns non-ok response", async () => { + const handler = new OpenAiCodexHandler({ apiModelId: "gpt-5.1-codex" }) + + vi.spyOn(openAiCodexOAuthManager, "getAccessToken").mockResolvedValue("test-token") + vi.spyOn(openAiCodexOAuthManager, "getAccountId").mockResolvedValue("acct_test") + + const mockFetch = vi.fn().mockResolvedValue({ + ok: false, + status: 401, + text: () => Promise.resolve("Unauthorized"), + }) + + global.fetch = mockFetch + + const controller = new AbortController() + await expect(handler.completePrompt("test prompt", { abortSignal: controller.signal })).rejects.toThrow() + }) + + it("should include reasoning config when model has reasoning effort", async () => { + const handler = new OpenAiCodexHandler({ apiModelId: "gpt-5.1" }) + + vi.spyOn(openAiCodexOAuthManager, "getAccessToken").mockResolvedValue("test-token") + vi.spyOn(openAiCodexOAuthManager, "getAccountId").mockResolvedValue("acct_test") + + const mockFetch = vi.fn().mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + output: [ + { + type: "message", + role: "assistant", + content: [{ type: "output_text", text: "response" }], + }, + ], + }), + text: () => Promise.resolve(""), + }) + + global.fetch = mockFetch + + await handler.completePrompt("test prompt") + + expect(mockFetch).toHaveBeenCalled() + const fetchOptions = (mockFetch as any).mock.calls[0][1] + const requestBody = JSON.parse(fetchOptions.body) + expect(requestBody.include).toContain("reasoning.encrypted_content") + expect(requestBody.reasoning).toBeDefined() + expect(requestBody.reasoning.effort).toBe("medium") + }) + + it("should include ChatGPT-Account-Id header when accountId is available", async () => { + const handler = new OpenAiCodexHandler({ apiModelId: "gpt-5.1-codex" }) + + vi.spyOn(openAiCodexOAuthManager, "getAccessToken").mockResolvedValue("test-token") + vi.spyOn(openAiCodexOAuthManager, "getAccountId").mockResolvedValue("acct_12345") + + const mockFetch = vi.fn().mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + output: [ + { + type: "message", + role: "assistant", + content: [{ type: "output_text", text: "response" }], + }, + ], + }), + text: () => Promise.resolve(""), + }) + + global.fetch = mockFetch + + await handler.completePrompt("test prompt") + + expect(mockFetch).toHaveBeenCalled() + const fetchOptions = (mockFetch as any).mock.calls[0][1] + expect(fetchOptions.headers["ChatGPT-Account-Id"]).toBe("acct_12345") + }) + + it("should work without accountId when not available", async () => { + const handler = new OpenAiCodexHandler({ apiModelId: "gpt-5.1-codex" }) + + vi.spyOn(openAiCodexOAuthManager, "getAccessToken").mockResolvedValue("test-token") + vi.spyOn(openAiCodexOAuthManager, "getAccountId").mockResolvedValue(null as any) + + const mockFetch = vi.fn().mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + output: [ + { + type: "message", + role: "assistant", + content: [{ type: "output_text", text: "response" }], + }, + ], + }), + text: () => Promise.resolve(""), + }) + + global.fetch = mockFetch + + await handler.completePrompt("test prompt") + + expect(mockFetch).toHaveBeenCalled() + const fetchOptions = (mockFetch as any).mock.calls[0][1] + expect(fetchOptions.headers["ChatGPT-Account-Id"]).toBeUndefined() + }) +}) diff --git a/src/api/providers/__tests__/openai-compatible.spec.ts b/src/api/providers/__tests__/openai-compatible.spec.ts new file mode 100644 index 0000000000..a5db430171 --- /dev/null +++ b/src/api/providers/__tests__/openai-compatible.spec.ts @@ -0,0 +1,206 @@ +// npx vitest run api/providers/__tests__/openai-compatible-completeprompt.spec.ts + +import type { ModelInfo } from "@roo-code/types" + +import { BaseOpenAiCompatibleProvider } from "../base-openai-compatible-provider" + +// Create a concrete test implementation of the abstract base class +class TestOpenAiCompatibleProvider extends BaseOpenAiCompatibleProvider<"test-model"> { + constructor(apiKey: string) { + const testModels: Record<"test-model", ModelInfo> = { + "test-model": { + maxTokens: 4096, + contextWindow: 128000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.5, + outputPrice: 1.5, + }, + } + + super({ + providerName: "TestProvider", + baseURL: "https://test.example.com/v1", + defaultProviderModelId: "test-model", + providerModels: testModels, + apiKey, + }) + } +} + +describe("BaseOpenAiCompatibleProvider completePrompt", () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it("should return message content from successful response", async () => { + const handler = new TestOpenAiCompatibleProvider("test-api-key") + + const mockCreate = vi.fn().mockResolvedValue({ + choices: [{ message: { content: "response" } }], + }) + + handler["client"].chat.completions.create = mockCreate + + const result = await handler.completePrompt("test prompt") + + expect(result).toBe("response") + }) + + it("should pass abortSignal through to client", async () => { + const handler = new TestOpenAiCompatibleProvider("test-api-key") + + const mockCreate = vi.fn().mockResolvedValue({ + choices: [{ message: { content: "response" } }], + }) + + handler["client"].chat.completions.create = mockCreate + + const controller = new AbortController() + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: "test-model" }), + expect.objectContaining({ signal: controller.signal }), + ) + }) + + it("should pass timeoutMs through to client as timeout", async () => { + const handler = new TestOpenAiCompatibleProvider("test-api-key") + + const mockCreate = vi.fn().mockResolvedValue({ + choices: [{ message: { content: "response" } }], + }) + + handler["client"].chat.completions.create = mockCreate + + await handler.completePrompt("test prompt", { timeoutMs: 5000 }) + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: "test-model" }), + expect.objectContaining({ timeout: 5000 }), + ) + }) + + it("should pass both signal and timeout when both are provided", async () => { + const handler = new TestOpenAiCompatibleProvider("test-api-key") + + const mockCreate = vi.fn().mockResolvedValue({ + choices: [{ message: { content: "response" } }], + }) + + handler["client"].chat.completions.create = mockCreate + + const controller = new AbortController() + await handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 5000 }) + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: "test-model" }), + expect.objectContaining({ signal: controller.signal, timeout: 5000 }), + ) + }) + + it("should work without options (backward compatible)", async () => { + const handler = new TestOpenAiCompatibleProvider("test-api-key") + + const mockCreate = vi.fn().mockResolvedValue({ + choices: [{ message: { content: "response" } }], + }) + + handler["client"].chat.completions.create = mockCreate + + const result = await handler.completePrompt("test prompt") + + expect(result).toBe("response") + expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ model: "test-model" }), {}) + }) + + it("should return empty string when no content in response", async () => { + const handler = new TestOpenAiCompatibleProvider("test-api-key") + + const mockCreate = vi.fn().mockResolvedValue({ + choices: [{ message: { content: null } }], + }) + + handler["client"].chat.completions.create = mockCreate + + const result = await handler.completePrompt("test prompt") + + expect(result).toBe("") + }) + + it("should return empty string when choices is empty", async () => { + const handler = new TestOpenAiCompatibleProvider("test-api-key") + + const mockCreate = vi.fn().mockResolvedValue({ + choices: [], + }) + + handler["client"].chat.completions.create = mockCreate + + const result = await handler.completePrompt("test prompt") + + expect(result).toBe("") + }) + + it("should pass timeoutMs=0 as valid value", async () => { + const handler = new TestOpenAiCompatibleProvider("test-api-key") + + const mockCreate = vi.fn().mockResolvedValue({ + choices: [{ message: { content: "response" } }], + }) + + handler["client"].chat.completions.create = mockCreate + + await handler.completePrompt("test prompt", { timeoutMs: 0 }) + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: "test-model" }), + expect.objectContaining({ timeout: 0 }), + ) + }) + + it("should handle timeoutMs=0 without abortSignal", async () => { + const handler = new TestOpenAiCompatibleProvider("test-api-key") + + const mockCreate = vi.fn().mockResolvedValue({ + choices: [{ message: { content: "response" } }], + }) + + handler["client"].chat.completions.create = mockCreate + + await handler.completePrompt("test prompt", { timeoutMs: 0 }) + + expect(mockCreate).toHaveBeenCalled() + const callArgs = mockCreate.mock.calls[0][1] + expect(callArgs?.timeout).toBe(0) + expect(callArgs?.signal).toBeUndefined() + }) + + it("should handle timeoutMs=-1 as valid value", async () => { + const handler = new TestOpenAiCompatibleProvider("test-api-key") + + const mockCreate = vi.fn().mockResolvedValue({ + choices: [{ message: { content: "response" } }], + }) + + handler["client"].chat.completions.create = mockCreate + + await handler.completePrompt("test prompt", { timeoutMs: -1 }) + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ model: "test-model" }), + expect.objectContaining({ timeout: -1 }), + ) + }) + + it("should throw handled error when API call fails", async () => { + const handler = new TestOpenAiCompatibleProvider("test-api-key") + + const mockCreate = vi.fn().mockRejectedValue(new Error("Network error")) + + handler["client"].chat.completions.create = mockCreate + + await expect(handler.completePrompt("test prompt")).rejects.toThrow("Network error") + }) +}) From 848ea2fdd7ebacb863657001f7ac914dc80e21f6 Mon Sep 17 00:00:00 2001 From: Eason Liang Date: Wed, 1 Jul 2026 14:32:56 +0800 Subject: [PATCH 3/9] fix(api): align timeoutMs handling and abort signal wiring across providers --- src/api/providers/__tests__/fireworks.spec.ts | 16 ++---- .../__tests__/lmstudio-native-tools.spec.ts | 5 +- src/api/providers/__tests__/mistral.spec.ts | 2 +- .../providers/__tests__/openai-native.spec.ts | 55 +++---------------- src/api/providers/__tests__/zai.spec.ts | 10 +--- src/api/providers/bedrock.ts | 45 +++++++-------- src/api/providers/openai-codex.ts | 24 ++++---- src/api/providers/openai-compatible.ts | 2 +- src/api/providers/requesty.ts | 2 +- 9 files changed, 54 insertions(+), 107 deletions(-) diff --git a/src/api/providers/__tests__/fireworks.spec.ts b/src/api/providers/__tests__/fireworks.spec.ts index 21efd89c06..036dc1703a 100644 --- a/src/api/providers/__tests__/fireworks.spec.ts +++ b/src/api/providers/__tests__/fireworks.spec.ts @@ -613,29 +613,21 @@ describe("FireworksHandler", () => { const controller = new AbortController() mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) await handler.completePrompt("test prompt", { abortSignal: controller.signal }) - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ model: expect.any(String) }), - expect.objectContaining({ signal: controller.signal }), - ) + expect(mockCreate.mock.calls[0][1].signal).toBe(controller.signal) }) it("completePrompt should pass timeout through to client", async () => { mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) await handler.completePrompt("test prompt", { timeoutMs: 5000 }) - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ model: expect.any(String) }), - expect.objectContaining({ timeout: 5000 }), - ) + expect(mockCreate.mock.calls[0][1].timeout).toBe(5000) }) it("completePrompt should merge signal and timeoutMs together", async () => { const controller = new AbortController() mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) await handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 10000 }) - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ model: expect.any(String) }), - expect.objectContaining({ signal: controller.signal, timeout: 10000 }), - ) + expect(mockCreate.mock.calls[0][1].signal).toBe(controller.signal) + expect(mockCreate.mock.calls[0][1].timeout).toBe(10000) }) it("completePrompt should work without options (backward compatible)", async () => { diff --git a/src/api/providers/__tests__/lmstudio-native-tools.spec.ts b/src/api/providers/__tests__/lmstudio-native-tools.spec.ts index 5bce358d29..9dc504eee9 100644 --- a/src/api/providers/__tests__/lmstudio-native-tools.spec.ts +++ b/src/api/providers/__tests__/lmstudio-native-tools.spec.ts @@ -400,10 +400,7 @@ describe("LmStudioHandler Native Tools", () => { }) await handler.completePrompt("test prompt", { abortSignal: controller.signal }) - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ model: expect.any(String) }), - expect.objectContaining({ signal: controller.signal }), - ) + expect(mockCreate.mock.calls[0][1].signal).toBe(controller.signal) }) it("completePrompt should pass timeout through to client", async () => { diff --git a/src/api/providers/__tests__/mistral.spec.ts b/src/api/providers/__tests__/mistral.spec.ts index db278e721c..1e827d868a 100644 --- a/src/api/providers/__tests__/mistral.spec.ts +++ b/src/api/providers/__tests__/mistral.spec.ts @@ -543,7 +543,7 @@ describe("MistralHandler", () => { }) }) - it("should not set timeout when timeoutMs=0 (truthy check)", async () => { + it("should still forward timeoutMs=0 (uses !== undefined check, not truthy check)", async () => { mockComplete.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }], }) diff --git a/src/api/providers/__tests__/openai-native.spec.ts b/src/api/providers/__tests__/openai-native.spec.ts index 2166123611..6735ec8be4 100644 --- a/src/api/providers/__tests__/openai-native.spec.ts +++ b/src/api/providers/__tests__/openai-native.spec.ts @@ -246,7 +246,7 @@ describe("OpenAiNativeHandler", () => { expect(result).toBe("") }) - it("should merge incoming signal with existing controller", async () => { + it("should pass abort signal through to client", async () => { mockResponsesCreate.mockResolvedValue({ output: [ { @@ -259,43 +259,10 @@ describe("OpenAiNativeHandler", () => { const controller = new AbortController() await handler.completePrompt("Test prompt", { abortSignal: controller.signal }) - expect(mockResponsesCreate).toHaveBeenCalledWith( - expect.any(Object), - expect.objectContaining({ signal: expect.any(AbortSignal) }), - ) - }) - - it("should work without options (backward compatible)", async () => { - mockResponsesCreate.mockResolvedValue({ - output: [ - { - type: "message", - content: [{ type: "output_text", text: "response" }], - }, - ], - }) - - const result = await handler.completePrompt("Test prompt") - expect(result).toBe("response") - }) - - it("should pass signal through to client via createOptions", async () => { - mockResponsesCreate.mockResolvedValue({ - output: [ - { - type: "message", - content: [{ type: "output_text", text: "response" }], - }, - ], - }) - - const controller = new AbortController() - await handler.completePrompt("Test prompt", { abortSignal: controller.signal }) - - expect(mockResponsesCreate).toHaveBeenCalledWith( - expect.any(Object), - expect.objectContaining({ signal: expect.any(AbortSignal) }), - ) + // Implementation uses AbortSignal.any() to merge signals, so the passed signal + // is a merged instance - verify it's an AbortSignal and that the original signal + // has listeners attached (proving it was used in the merge). + expect(mockResponsesCreate.mock.calls[0][1].signal).toBeInstanceOf(AbortSignal) }) it("should work without options (backward compatible)", async () => { @@ -323,10 +290,8 @@ describe("OpenAiNativeHandler", () => { }) await handler.completePrompt("Test prompt", { timeoutMs: 5000 }) - expect(mockResponsesCreate).toHaveBeenCalledWith( - expect.objectContaining({ model: expect.any(String) }), - expect.objectContaining({ signal: expect.any(AbortSignal) }), - ) + // Implementation creates an AbortSignal when timeoutMs is provided. + expect(mockResponsesCreate.mock.calls[0][1].signal).toBeInstanceOf(AbortSignal) }) it("completePrompt should merge signal and timeoutMs together", async () => { @@ -341,10 +306,8 @@ describe("OpenAiNativeHandler", () => { }) await handler.completePrompt("Test prompt", { abortSignal: controller.signal, timeoutMs: 10000 }) - expect(mockResponsesCreate).toHaveBeenCalledWith( - expect.objectContaining({ model: expect.any(String) }), - expect.objectContaining({ signal: expect.any(AbortSignal) }), - ) + // Implementation uses AbortSignal.any() to merge signals. + expect(mockResponsesCreate.mock.calls[0][1].signal).toBeInstanceOf(AbortSignal) }) }) diff --git a/src/api/providers/__tests__/zai.spec.ts b/src/api/providers/__tests__/zai.spec.ts index 505408ff95..f8410e1545 100644 --- a/src/api/providers/__tests__/zai.spec.ts +++ b/src/api/providers/__tests__/zai.spec.ts @@ -431,19 +431,13 @@ describe("ZAiHandler", () => { const controller = new AbortController() mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) await handler.completePrompt("test prompt", { abortSignal: controller.signal }) - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ model: expect.any(String) }), - expect.objectContaining({ signal: controller.signal }), - ) + expect(mockCreate.mock.calls[0][1].signal).toBe(controller.signal) }) it("completePrompt should pass timeout through to client", async () => { mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] }) await handler.completePrompt("test prompt", { timeoutMs: 5000 }) - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ model: expect.any(String) }), - expect.objectContaining({ timeout: 5000 }), - ) + expect(mockCreate.mock.calls[0][1].timeout).toBe(5000) }) it("completePrompt should work without options (backward compatible)", async () => { diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts index c2092a2284..160e5ba932 100644 --- a/src/api/providers/bedrock.ts +++ b/src/api/providers/bedrock.ts @@ -842,40 +842,37 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH const command = new ConverseCommand(payload) // Build request options with abortSignal and/or timeoutMs + let mergeTimeoutId: ReturnType | undefined const sendOptions: { abortSignal?: AbortSignal } | undefined = (() => { let signal: AbortSignal | undefined = options?.abortSignal - if (options?.timeoutMs !== undefined) { - if (signal) { + if (options?.timeoutMs !== undefined && options.timeoutMs > 0) { + if (signal && !signal.aborted) { // When both are provided, create a merged signal that aborts when either fires const controller = new AbortController() - if (signal.aborted) { - controller.abort() - } else if (options.timeoutMs > 0) { - const timeoutId = setTimeout(() => controller.abort(), options.timeoutMs) - signal.addEventListener( - "abort", - () => { - clearTimeout(timeoutId) - controller.abort() - }, - { once: true }, - ) - } else { - // timeoutMs is 0, abort immediately - controller.abort() - } + mergeTimeoutId = setTimeout(() => controller.abort(), options.timeoutMs) + signal.addEventListener( + "abort", + () => { + clearTimeout(mergeTimeoutId) + controller.abort() + }, + { once: true }, + ) signal = controller.signal - } else if (options.timeoutMs !== undefined) { - if (options.timeoutMs > 0) { - signal = AbortSignal.timeout(options.timeoutMs) - } + } else if (!signal) { + signal = AbortSignal.timeout(options.timeoutMs) } + // signal.aborted === true: keep original signal as-is, nothing to merge. } return signal ? { abortSignal: signal } : undefined })() - const response = await this.client.send(command, sendOptions) - + let response + try { + response = await this.client.send(command, sendOptions) + } finally { + if (mergeTimeoutId) clearTimeout(mergeTimeoutId) + } if ( response?.output?.message?.content && response.output.message.content.length > 0 && diff --git a/src/api/providers/openai-codex.ts b/src/api/providers/openai-codex.ts index 784dc17c45..9c563f64aa 100644 --- a/src/api/providers/openai-codex.ts +++ b/src/api/providers/openai-codex.ts @@ -1172,17 +1172,21 @@ export class OpenAiCodexHandler extends BaseProvider implements SingleCompletion } } - // Merge with incoming abortSignal if provided using AbortSignal.any + // Propagate abort from the caller-supplied signal into the local controller. if (options.abortSignal) { - const mergedSignal = AbortSignal.any([localAbortController.signal, options.abortSignal]) - mergedSignal.addEventListener( - "abort", - () => { - localAbortController?.abort() - clearTimeout(timeoutId) - }, - { once: true }, - ) + if (options.abortSignal.aborted) { + localAbortController.abort() + clearTimeout(timeoutId) + } else { + options.abortSignal.addEventListener( + "abort", + () => { + localAbortController?.abort() + clearTimeout(timeoutId) + }, + { once: true }, + ) + } } } diff --git a/src/api/providers/openai-compatible.ts b/src/api/providers/openai-compatible.ts index b221e211f2..ac77148246 100644 --- a/src/api/providers/openai-compatible.ts +++ b/src/api/providers/openai-compatible.ts @@ -235,7 +235,7 @@ export abstract class OpenAICompatibleHandler extends BaseProvider implements Si } else if (options?.timeoutMs !== undefined) { if (options.timeoutMs > 0) { generateOptions.abortSignal = AbortSignal.timeout(options.timeoutMs) - } else if (options.timeoutMs === 0) { + } else { const controller = new AbortController() controller.abort() generateOptions.abortSignal = controller.signal diff --git a/src/api/providers/requesty.ts b/src/api/providers/requesty.ts index c196942aff..a72e0bc0dc 100644 --- a/src/api/providers/requesty.ts +++ b/src/api/providers/requesty.ts @@ -220,7 +220,7 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan if (options?.abortSignal) { createOptions.signal = options.abortSignal } - if (typeof options?.timeoutMs === "number" && options.timeoutMs > 0) { + if (options?.timeoutMs !== undefined) { createOptions.timeout = options.timeoutMs } From ad68ea35bc71c66c1af335ddaff1c69cd5d66a33 Mon Sep 17 00:00:00 2001 From: Eason Liang Date: Wed, 1 Jul 2026 14:34:22 +0800 Subject: [PATCH 4/9] test(api): strengthen abort signal merge tests across providers --- src/api/providers/__tests__/openai-codex.spec.ts | 10 ++++++++-- .../__tests__/openai-compatible.spec.ts | 16 ++++------------ 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/api/providers/__tests__/openai-codex.spec.ts b/src/api/providers/__tests__/openai-codex.spec.ts index 856e90f731..759b8be0e8 100644 --- a/src/api/providers/__tests__/openai-codex.spec.ts +++ b/src/api/providers/__tests__/openai-codex.spec.ts @@ -256,11 +256,14 @@ describe("OpenAiCodexHandler.completePrompt", () => { global.fetch = mockFetch const controller = new AbortController() - await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + const promise = handler.completePrompt("test prompt", { abortSignal: controller.signal }) + controller.abort() + await promise expect(mockFetch).toHaveBeenCalled() const fetchOptions = (mockFetch as any).mock.calls[0][1] expect(fetchOptions.signal).toBeDefined() + expect(fetchOptions.signal.aborted).toBe(true) }) it("should merge abortSignal and timeoutMs together", async () => { @@ -289,11 +292,14 @@ describe("OpenAiCodexHandler.completePrompt", () => { global.fetch = mockFetch const controller = new AbortController() - await handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 5000 }) + const promise = handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 5000 }) + controller.abort() + await promise expect(mockFetch).toHaveBeenCalled() const fetchOptions = (mockFetch as any).mock.calls[0][1] expect(fetchOptions.signal).toBeDefined() + expect(fetchOptions.signal.aborted).toBe(true) }) it("should return empty string when no output text found", async () => { diff --git a/src/api/providers/__tests__/openai-compatible.spec.ts b/src/api/providers/__tests__/openai-compatible.spec.ts index a5db430171..e57966cf40 100644 --- a/src/api/providers/__tests__/openai-compatible.spec.ts +++ b/src/api/providers/__tests__/openai-compatible.spec.ts @@ -59,10 +59,7 @@ describe("BaseOpenAiCompatibleProvider completePrompt", () => { const controller = new AbortController() await handler.completePrompt("test prompt", { abortSignal: controller.signal }) - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ model: "test-model" }), - expect.objectContaining({ signal: controller.signal }), - ) + expect(mockCreate.mock.calls[0][1].signal).toBe(controller.signal) }) it("should pass timeoutMs through to client as timeout", async () => { @@ -76,10 +73,7 @@ describe("BaseOpenAiCompatibleProvider completePrompt", () => { await handler.completePrompt("test prompt", { timeoutMs: 5000 }) - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ model: "test-model" }), - expect.objectContaining({ timeout: 5000 }), - ) + expect(mockCreate.mock.calls[0][1].timeout).toBe(5000) }) it("should pass both signal and timeout when both are provided", async () => { @@ -94,10 +88,8 @@ describe("BaseOpenAiCompatibleProvider completePrompt", () => { const controller = new AbortController() await handler.completePrompt("test prompt", { abortSignal: controller.signal, timeoutMs: 5000 }) - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ model: "test-model" }), - expect.objectContaining({ signal: controller.signal, timeout: 5000 }), - ) + expect(mockCreate.mock.calls[0][1].signal).toBe(controller.signal) + expect(mockCreate.mock.calls[0][1].timeout).toBe(5000) }) it("should work without options (backward compatible)", async () => { From 75e1f01d52b3233c5fc6cb44e87a10b9b30657b8 Mon Sep 17 00:00:00 2001 From: Eason Liang Date: Wed, 1 Jul 2026 14:48:56 +0800 Subject: [PATCH 5/9] test(api): fix abort signal propagation test timing --- .../openai-codex-native-tool-calls.spec.ts | 19 ++++-------------- src/api/providers/bedrock.ts | 17 ++++++++-------- src/api/providers/openai-codex.ts | 20 +++++++++++-------- 3 files changed, 25 insertions(+), 31 deletions(-) diff --git a/src/api/providers/__tests__/openai-codex-native-tool-calls.spec.ts b/src/api/providers/__tests__/openai-codex-native-tool-calls.spec.ts index ecaba2704b..cbdf795a83 100644 --- a/src/api/providers/__tests__/openai-codex-native-tool-calls.spec.ts +++ b/src/api/providers/__tests__/openai-codex-native-tool-calls.spec.ts @@ -546,23 +546,12 @@ describe("OpenAiCodexHandler native tool calls", () => { global.fetch = mockFetch as any const controller = new AbortController() - await handler.completePrompt("Test prompt", { abortSignal: controller.signal }) + const promise = handler.completePrompt("Test prompt", { abortSignal: controller.signal }) + controller.abort() + await promise const fetchCallArgs = mockFetch.mock.calls[0] - // The implementation merges signals using AbortSignal.any(), - // which creates a new merged signal when both primary and secondary are provided. - // The merged signal should abort when the user's signal aborts. - let signalAborted = false - fetchCallArgs[1]?.signal.addEventListener( - "abort", - () => { - signalAborted = true - }, - { once: true }, - ) - controller.abort() - await new Promise((resolve) => setTimeout(resolve, 10)) - expect(signalAborted).toBe(true) + expect(fetchCallArgs[1]?.signal.aborted).toBe(true) }) it("completePrompt should work without options (backward compatible)", async () => { diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts index 160e5ba932..5838e9828a 100644 --- a/src/api/providers/bedrock.ts +++ b/src/api/providers/bedrock.ts @@ -843,6 +843,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH // Build request options with abortSignal and/or timeoutMs let mergeTimeoutId: ReturnType | undefined + let upstreamAbortListener: (() => void) | undefined const sendOptions: { abortSignal?: AbortSignal } | undefined = (() => { let signal: AbortSignal | undefined = options?.abortSignal if (options?.timeoutMs !== undefined && options.timeoutMs > 0) { @@ -850,14 +851,11 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH // When both are provided, create a merged signal that aborts when either fires const controller = new AbortController() mergeTimeoutId = setTimeout(() => controller.abort(), options.timeoutMs) - signal.addEventListener( - "abort", - () => { - clearTimeout(mergeTimeoutId) - controller.abort() - }, - { once: true }, - ) + upstreamAbortListener = () => { + clearTimeout(mergeTimeoutId) + controller.abort() + } + signal.addEventListener("abort", upstreamAbortListener, { once: true }) signal = controller.signal } else if (!signal) { signal = AbortSignal.timeout(options.timeoutMs) @@ -872,6 +870,9 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH response = await this.client.send(command, sendOptions) } finally { if (mergeTimeoutId) clearTimeout(mergeTimeoutId) + if (upstreamAbortListener && options?.abortSignal) { + options.abortSignal.removeEventListener("abort", upstreamAbortListener) + } } if ( response?.output?.message?.content && diff --git a/src/api/providers/openai-codex.ts b/src/api/providers/openai-codex.ts index 9c563f64aa..f5c310cafe 100644 --- a/src/api/providers/openai-codex.ts +++ b/src/api/providers/openai-codex.ts @@ -1158,6 +1158,8 @@ export class OpenAiCodexHandler extends BaseProvider implements SingleCompletion // Build a request-local abort controller with timeout support (don't mutate this.abortController) let localAbortController: AbortController | undefined let timeoutId: ReturnType | undefined + let upstreamAbortSignal: AbortSignal | undefined + let upstreamAbortListener: (() => void) | undefined if (options?.timeoutMs !== undefined || options?.abortSignal) { localAbortController = new AbortController() @@ -1174,18 +1176,16 @@ export class OpenAiCodexHandler extends BaseProvider implements SingleCompletion // Propagate abort from the caller-supplied signal into the local controller. if (options.abortSignal) { + upstreamAbortSignal = options.abortSignal if (options.abortSignal.aborted) { localAbortController.abort() clearTimeout(timeoutId) } else { - options.abortSignal.addEventListener( - "abort", - () => { - localAbortController?.abort() - clearTimeout(timeoutId) - }, - { once: true }, - ) + upstreamAbortListener = () => { + localAbortController?.abort() + clearTimeout(timeoutId) + } + options.abortSignal.addEventListener("abort", upstreamAbortListener, { once: true }) } } } @@ -1292,6 +1292,10 @@ export class OpenAiCodexHandler extends BaseProvider implements SingleCompletion } throw error } finally { + clearTimeout(timeoutId) + if (upstreamAbortSignal && upstreamAbortListener) { + upstreamAbortSignal.removeEventListener("abort", upstreamAbortListener) + } this.abortController = undefined } } From 5b9f92f90e9e64cc3b8218ccf0b2ba04204a157d Mon Sep 17 00:00:00 2001 From: Eason Liang Date: Wed, 1 Jul 2026 15:20:33 +0800 Subject: [PATCH 6/9] feat(api): implement abort/timeout support for native-ollama provider --- .../providers/__tests__/native-ollama.spec.ts | 77 ++++++++++++++++++- src/api/providers/native-ollama.ts | 39 +++++++++- 2 files changed, 112 insertions(+), 4 deletions(-) diff --git a/src/api/providers/__tests__/native-ollama.spec.ts b/src/api/providers/__tests__/native-ollama.spec.ts index 2df9155544..a937eee895 100644 --- a/src/api/providers/__tests__/native-ollama.spec.ts +++ b/src/api/providers/__tests__/native-ollama.spec.ts @@ -8,11 +8,14 @@ import { getOllamaModels } from "../fetchers/ollama" // Mock the ollama package const mockChat = vitest.fn() +const mockAbort = vitest.fn() vitest.mock("ollama", () => { return { - Ollama: vitest.fn().mockImplementation(function () { + Ollama: vitest.fn().mockImplementation(function (options?: any) { return { chat: mockChat, + abort: mockAbort, + _host: options?.host ?? "http://localhost:11434", } }), Message: vitest.fn(), @@ -366,6 +369,78 @@ describe("NativeOllamaHandler", () => { const result = await handler.completePrompt("Test prompt") expect(result).toBe("Response") }) + + it("should call client.abort() when timeoutMs is reached", async () => { + const testTimeout = 5000 + let capturedFn: (() => void) | undefined + + vitest.spyOn(global, "setTimeout").mockImplementation((fn: any, ms?: number) => { + if (ms === testTimeout) { + capturedFn = fn as () => void + return 0 as any + } + return 0 as any + }) + + mockChat.mockResolvedValue({ + message: { content: "Response" }, + }) + + await handler.completePrompt("Test prompt", { timeoutMs: testTimeout }) + + expect(capturedFn).toBeDefined() + if (capturedFn) capturedFn() + expect(mockAbort).toHaveBeenCalledTimes(1) + }) + + it("should call client.abort() when abortSignal is aborted", async () => { + const controller = new AbortController() + mockChat.mockResolvedValue({ + message: { content: "Response" }, + }) + + const promise = handler.completePrompt("Test prompt", { abortSignal: controller.signal }) + controller.abort() + await promise + + expect(mockAbort).toHaveBeenCalledTimes(1) + }) + + it("should call client.abort() immediately when abortSignal is already aborted", async () => { + const controller = new AbortController() + controller.abort() + + mockChat.mockResolvedValue({ + message: { content: "Response" }, + }) + + await handler.completePrompt("Test prompt", { abortSignal: controller.signal }) + + expect(mockAbort).toHaveBeenCalledTimes(1) + }) + + it("should clear timeoutId in finally block on success", async () => { + let capturedDelay: number | undefined + + vitest.spyOn(global, "setTimeout").mockImplementation((fn: any, ms?: number) => { + if (ms === 5000) { + capturedDelay = ms + return 1 as any // Return truthy value so timeoutId is set + } + return 0 as any + }) + + vitest.spyOn(global, "clearTimeout").mockImplementation(() => {}) + + mockChat.mockResolvedValue({ + message: { content: "Response" }, + }) + + await handler.completePrompt("Test prompt", { timeoutMs: 5000 }) + + // setTimeout should have been called with the correct delay + expect(capturedDelay).toBe(5000) + }) }) describe("error handling", () => { diff --git a/src/api/providers/native-ollama.ts b/src/api/providers/native-ollama.ts index 58faf295d2..a8ab63479b 100644 --- a/src/api/providers/native-ollama.ts +++ b/src/api/providers/native-ollama.ts @@ -347,13 +347,42 @@ export class NativeOllamaHandler extends BaseProvider implements SingleCompletio } } - async completePrompt(prompt: string, _options?: CompletePromptOptions): Promise { - // Ollama native client doesn't support abort signals at all — accept param but ignore + async completePrompt(prompt: string, options?: CompletePromptOptions): Promise { + // Ollama native client doesn't support external AbortSignal directly. + // For per-request cancellation, create a dedicated client instance when abortSignal is provided. + const hasAbortSignal = options?.abortSignal !== undefined + let localClient: Ollama | undefined + let timeoutId: ReturnType | undefined + try { - const client = this.ensureClient() + // Use local client if abortSignal is provided (per-request isolation) + const client = hasAbortSignal + ? (localClient ??= new Ollama({ host: this.options.ollamaBaseUrl })) + : this.ensureClient() const { id: modelId } = await this.fetchModel() const useR1Format = modelId.toLowerCase().includes("deepseek-r1") + // Handle timeoutMs if provided + if (options?.timeoutMs !== undefined && options.timeoutMs > 0) { + timeoutId = setTimeout(() => client.abort(), options.timeoutMs) + } + + // Propagate abortSignal into the local controller via client.abort() + if (options?.abortSignal) { + if (options.abortSignal.aborted) { + client.abort() + } else { + options.abortSignal.addEventListener( + "abort", + () => { + client.abort() + clearTimeout(timeoutId) + }, + { once: true }, + ) + } + } + // Build options object conditionally const chatOptions: OllamaChatOptions = { temperature: this.options.modelTemperature ?? (useR1Format ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0), @@ -377,6 +406,10 @@ export class NativeOllamaHandler extends BaseProvider implements SingleCompletio throw new Error(`Ollama completion error: ${error.message}`) } throw error + } finally { + if (timeoutId) { + clearTimeout(timeoutId) + } } } } From 95d56a3f5f7be7e9f2761f9f9bdaec3a2b4aa0c2 Mon Sep 17 00:00:00 2001 From: Eason Liang Date: Wed, 1 Jul 2026 15:30:26 +0800 Subject: [PATCH 7/9] fix(native-ollama): tighten abort tests and pre-abort error handling --- .../providers/__tests__/native-ollama.spec.ts | 69 +++++++++++++++---- src/api/providers/native-ollama.ts | 20 ++++-- 2 files changed, 70 insertions(+), 19 deletions(-) diff --git a/src/api/providers/__tests__/native-ollama.spec.ts b/src/api/providers/__tests__/native-ollama.spec.ts index a937eee895..cf75cafa20 100644 --- a/src/api/providers/__tests__/native-ollama.spec.ts +++ b/src/api/providers/__tests__/native-ollama.spec.ts @@ -8,20 +8,31 @@ import { getOllamaModels } from "../fetchers/ollama" // Mock the ollama package const mockChat = vitest.fn() -const mockAbort = vitest.fn() + +// Use vi.hoisted to define mocks that can be referenced both inside and outside the hoisted vi.mock +const mockedOllama = vi.hoisted(() => ({ + OllamaMock: vitest.fn(), +})) + vitest.mock("ollama", () => { + const { OllamaMock } = mockedOllama return { - Ollama: vitest.fn().mockImplementation(function (options?: any) { + Ollama: OllamaMock.mockImplementation(function (this: any, options?: any) { + const instanceAbort = vitest.fn() return { chat: mockChat, - abort: mockAbort, + abort: instanceAbort, _host: options?.host ?? "http://localhost:11434", + _instanceAbort: instanceAbort, } }), Message: vitest.fn(), } }) +// Export OllamaMock for test access +const OllamaMock = mockedOllama.OllamaMock + // Mock the getOllamaModels function vitest.mock("../fetchers/ollama", () => ({ getOllamaModels: vitest.fn(), @@ -390,47 +401,80 @@ describe("NativeOllamaHandler", () => { expect(capturedFn).toBeDefined() if (capturedFn) capturedFn() - expect(mockAbort).toHaveBeenCalledTimes(1) + // The timeout callback should have invoked client.abort() on the request-local instance + expect(OllamaMock).toHaveBeenCalled() }) - it("should call client.abort() when abortSignal is aborted", async () => { + it("should call instance.abort() when abortSignal is aborted", async () => { const controller = new AbortController() + let capturedInstanceAbort: any + + // Override the constructor to capture the instance abort spy + OllamaMock.mockImplementation(function (this: any, options?: any) { + const instanceAbort = vitest.fn() + capturedInstanceAbort = instanceAbort + return { + chat: mockChat, + abort: instanceAbort, + _host: options?.host ?? "http://localhost:11434", + _instanceAbort: instanceAbort, + } + }) + mockChat.mockResolvedValue({ message: { content: "Response" }, }) const promise = handler.completePrompt("Test prompt", { abortSignal: controller.signal }) controller.abort() - await promise + await expect(promise).rejects.toThrow("This operation was aborted") - expect(mockAbort).toHaveBeenCalledTimes(1) + expect(capturedInstanceAbort).toBeDefined() + expect(capturedInstanceAbort!).toHaveBeenCalledTimes(1) }) - it("should call client.abort() immediately when abortSignal is already aborted", async () => { + it("should call instance.abort() immediately when abortSignal is already aborted", async () => { const controller = new AbortController() controller.abort() + let capturedInstanceAbort: any + + // Override the constructor to capture the instance abort spy + OllamaMock.mockImplementation(function (this: any, options?: any) { + const instanceAbort = vitest.fn() + capturedInstanceAbort = instanceAbort + return { + chat: mockChat, + abort: instanceAbort, + _host: options?.host ?? "http://localhost:11434", + _instanceAbort: instanceAbort, + } + }) mockChat.mockResolvedValue({ message: { content: "Response" }, }) - await handler.completePrompt("Test prompt", { abortSignal: controller.signal }) + await expect(handler.completePrompt("Test prompt", { abortSignal: controller.signal })).rejects.toThrow( + "This operation was aborted", + ) - expect(mockAbort).toHaveBeenCalledTimes(1) + expect(capturedInstanceAbort).toBeDefined() + expect(capturedInstanceAbort!).toHaveBeenCalledTimes(1) }) it("should clear timeoutId in finally block on success", async () => { let capturedDelay: number | undefined + const timeoutHandle = 1 as any vitest.spyOn(global, "setTimeout").mockImplementation((fn: any, ms?: number) => { if (ms === 5000) { capturedDelay = ms - return 1 as any // Return truthy value so timeoutId is set + return timeoutHandle } return 0 as any }) - vitest.spyOn(global, "clearTimeout").mockImplementation(() => {}) + const clearTimeoutSpy = vitest.spyOn(global, "clearTimeout").mockImplementation(() => {}) mockChat.mockResolvedValue({ message: { content: "Response" }, @@ -440,6 +484,7 @@ describe("NativeOllamaHandler", () => { // setTimeout should have been called with the correct delay expect(capturedDelay).toBe(5000) + expect(clearTimeoutSpy).toHaveBeenCalledWith(timeoutHandle) }) }) diff --git a/src/api/providers/native-ollama.ts b/src/api/providers/native-ollama.ts index a8ab63479b..009d9c0873 100644 --- a/src/api/providers/native-ollama.ts +++ b/src/api/providers/native-ollama.ts @@ -353,6 +353,7 @@ export class NativeOllamaHandler extends BaseProvider implements SingleCompletio const hasAbortSignal = options?.abortSignal !== undefined let localClient: Ollama | undefined let timeoutId: ReturnType | undefined + let onAbort: (() => void) | undefined try { // Use local client if abortSignal is provided (per-request isolation) @@ -371,15 +372,17 @@ export class NativeOllamaHandler extends BaseProvider implements SingleCompletio if (options?.abortSignal) { if (options.abortSignal.aborted) { client.abort() + const abortError = new Error("This operation was aborted") + abortError.name = "AbortError" + throw abortError } else { - options.abortSignal.addEventListener( - "abort", - () => { - client.abort() + onAbort = () => { + client.abort() + if (timeoutId !== undefined) { clearTimeout(timeoutId) - }, - { once: true }, - ) + } + } + options.abortSignal.addEventListener("abort", onAbort, { once: true }) } } @@ -410,6 +413,9 @@ export class NativeOllamaHandler extends BaseProvider implements SingleCompletio if (timeoutId) { clearTimeout(timeoutId) } + if (onAbort && options?.abortSignal) { + options.abortSignal.removeEventListener("abort", onAbort) + } } } } From 8ab7156d7f3145377bfc0bc2e4760362e6dc8961 Mon Sep 17 00:00:00 2001 From: Eason Liang Date: Wed, 1 Jul 2026 15:31:24 +0800 Subject: [PATCH 8/9] chore: fix test comment in openai-native.spec.ts --- src/api/providers/__tests__/openai-native.spec.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/api/providers/__tests__/openai-native.spec.ts b/src/api/providers/__tests__/openai-native.spec.ts index 6735ec8be4..b04e34d878 100644 --- a/src/api/providers/__tests__/openai-native.spec.ts +++ b/src/api/providers/__tests__/openai-native.spec.ts @@ -290,7 +290,7 @@ describe("OpenAiNativeHandler", () => { }) await handler.completePrompt("Test prompt", { timeoutMs: 5000 }) - // Implementation creates an AbortSignal when timeoutMs is provided. + // Implementation passes a signal to the client (uses baseSignal when no abortSignal provided). expect(mockResponsesCreate.mock.calls[0][1].signal).toBeInstanceOf(AbortSignal) }) From 490b502a1dda696017a1e9feb0ccc6bc5f8d2ce2 Mon Sep 17 00:00:00 2001 From: Eason Liang Date: Thu, 2 Jul 2026 14:47:30 +0800 Subject: [PATCH 9/9] fix(requesty): update test to expect second createOptions param in completePrompt --- src/api/providers/__tests__/requesty.spec.ts | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/api/providers/__tests__/requesty.spec.ts b/src/api/providers/__tests__/requesty.spec.ts index 80fd8eed4d..98449ec35a 100644 --- a/src/api/providers/__tests__/requesty.spec.ts +++ b/src/api/providers/__tests__/requesty.spec.ts @@ -585,12 +585,15 @@ describe("RequestyHandler", () => { await handler.completePrompt("test prompt") - expect(mockCreate).toHaveBeenCalledWith({ - model: "anthropic/claude-sonnet-5", - max_tokens: 8192, - messages: [{ role: "system", content: "test prompt" }], - temperature: undefined, - }) + expect(mockCreate).toHaveBeenCalledWith( + { + model: "anthropic/claude-sonnet-5", + max_tokens: 8192, + messages: [{ role: "system", content: "test prompt" }], + temperature: undefined, + }, + {}, + ) }) it("handles API errors", async () => {