Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 62 additions & 9 deletions agent_core/core/models/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,10 @@
"""

import logging
from typing import Optional
import urllib.request
import json as _json

from openai import OpenAI
from anthropic import Anthropic
from typing import Optional

try:
import boto3 # type: ignore[import]
except ImportError: # pragma: no cover — boto3 is an optional extra
Expand Down Expand Up @@ -53,6 +50,51 @@
_OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"


_PROVIDER_DISPLAY = {
"openai": "OpenAI",
"deepseek": "DeepSeek",
"grok": "Grok",
"moonshot": "Moonshot",
"minimax": "MiniMax",
"openrouter": "OpenRouter",
}


def _create_openai_client(
*,
provider: str,
api_key: str,
base_url: Optional[str] = None,
):
"""Create an OpenAI SDK client for OpenAI-compatible providers."""
try:
from openai import OpenAI
except ImportError as exc:
display = _PROVIDER_DISPLAY.get(provider, provider)
raise ImportError(
f"The openai package is required for {display} because CraftBot "
"uses the OpenAI-compatible SDK client for this provider. "
"Install it with the Python that launches CraftBot: "
"`python -m pip install 'openai>=2.0.0'`."
) from exc

if base_url:
return OpenAI(api_key=api_key, base_url=base_url)
return OpenAI(api_key=api_key)


def _create_anthropic_client(*, api_key: str):
try:
from anthropic import Anthropic
except ImportError as exc:
raise ImportError(
"The anthropic package is required for the Anthropic provider. "
"Install it with the Python that launches CraftBot: "
"`python -m pip install 'anthropic>=0.97.0'`."
) from exc
return Anthropic(api_key=api_key)


def _to_openrouter_slug(provider: str, model: str) -> str:
"""Convert a provider-native model ID to its OpenRouter slug."""
if "/" in model:
Expand Down Expand Up @@ -120,7 +162,7 @@ def create(
Returns:
Dictionary with provider context including client instances
"""
# OpenAI-compatible providers that use OpenAI client with a custom base_url
# OpenAI-compatible providers that use chat-completions with a custom base_url.
_OPENAI_COMPAT = {"minimax", "deepseek", "moonshot", "grok", "openrouter"}

if provider not in PROVIDER_CONFIG:
Expand Down Expand Up @@ -175,7 +217,10 @@ def create(
return {
"provider": provider,
"model": model,
"client": OpenAI(api_key=api_key),
"client": _create_openai_client(
provider=provider,
api_key=api_key,
),
"gemini_client": None,
"remote_url": None,
"byteplus": None,
Expand Down Expand Up @@ -215,7 +260,7 @@ def create(
"gemini_client": None,
"remote_url": None,
"byteplus": None,
"anthropic_client": Anthropic(api_key=api_key),
"anthropic_client": _create_anthropic_client(api_key=api_key),
"bedrock_client": None,
"initialized": True,
}
Expand Down Expand Up @@ -273,7 +318,11 @@ def create(
return {
"provider": "openrouter",
"model": or_model,
"client": OpenAI(api_key=or_key, base_url=_OPENROUTER_BASE_URL),
"client": _create_openai_client(
provider="openrouter",
api_key=or_key,
base_url=_OPENROUTER_BASE_URL,
),
"gemini_client": None,
"remote_url": None,
"byteplus": None,
Expand All @@ -290,7 +339,11 @@ def create(
return {
"provider": provider,
"model": model,
"client": OpenAI(api_key=api_key, base_url=resolved_base_url),
"client": _create_openai_client(
provider=provider,
api_key=api_key,
base_url=resolved_base_url,
),
"gemini_client": None,
"remote_url": None,
"byteplus": None,
Expand Down
111 changes: 111 additions & 0 deletions tests/test_model_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# -*- coding: utf-8 -*-
from pathlib import Path
import subprocess
import sys
import textwrap

PROJECT_ROOT = Path(__file__).resolve().parents[1]


def _run_with_blocked_sdks(code: str) -> subprocess.CompletedProcess:
script = textwrap.dedent(
f"""
import importlib.abc
import sys

class BlockProviderSdks(importlib.abc.MetaPathFinder):
def find_spec(self, fullname, path=None, target=None):
if (
fullname == "openai"
or fullname.startswith("openai.")
or fullname == "anthropic"
or fullname.startswith("anthropic.")
):
raise ImportError(f"{{fullname}} intentionally blocked")
return None

for name in list(sys.modules):
if (
name == "openai"
or name.startswith("openai.")
or name == "anthropic"
or name.startswith("anthropic.")
):
del sys.modules[name]
sys.meta_path.insert(0, BlockProviderSdks())

{textwrap.indent(textwrap.dedent(code), " ")}
"""
)
return subprocess.run(
[sys.executable, "-c", script],
cwd=PROJECT_ROOT,
text=True,
capture_output=True,
)


def test_importing_model_factory_does_not_require_provider_sdks():
result = _run_with_blocked_sdks(
"""
from agent_core.core.models.factory import ModelFactory
assert ModelFactory is not None
"""
)

assert result.returncode == 0, result.stderr


def test_deferred_openai_compatible_providers_do_not_require_openai_sdk():
result = _run_with_blocked_sdks(
"""
from agent_core.core.models.factory import ModelFactory
from agent_core.core.models.types import InterfaceType

for provider in ("deepseek", "grok", "moonshot", "minimax", "openrouter"):
ctx = ModelFactory.create(
provider=provider,
interface=InterfaceType.LLM,
deferred=True,
)
assert ctx["initialized"] is False
assert ctx["client"] is None
"""
)

assert result.returncode == 0, result.stderr


def test_openai_compatible_providers_report_missing_openai_sdk():
result = _run_with_blocked_sdks(
"""
from agent_core.core.models.factory import ModelFactory
from agent_core.core.models.types import InterfaceType

providers = {
"deepseek": "DeepSeek",
"grok": "Grok",
"moonshot": "Moonshot",
"minimax": "MiniMax",
"openrouter": "OpenRouter",
}

for provider, display in providers.items():
try:
ModelFactory.create(
provider=provider,
interface=InterfaceType.LLM,
api_key=f"{provider}-key",
)
except ImportError as exc:
message = str(exc)
assert "openai package is required" in message
assert display in message
else:
raise AssertionError(
f"expected missing OpenAI SDK to raise ImportError for {provider}"
)
"""
)

assert result.returncode == 0, result.stderr