Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions ai_scientist/copilot_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""
GitHub Copilot API 认证管理器
- OAuth token → Copilot JWT 交换
- JWT 自动刷新(30分钟有效期)
"""

import os
import time
import threading
import requests
import logging

logger = logging.getLogger("copilot-auth")

_lock = threading.Lock()
_cached_jwt: str = ""
_jwt_expires_at: float = 0.0

COPILOT_API = "https://api.individual.githubcopilot.com"


def get_copilot_jwt() -> str:
    """Return a valid Copilot JWT, refreshing it automatically before expiry.

    Exchanges the long-lived GitHub OAuth token (``COPILOT_OAUTH_TOKEN`` env
    var) for a short-lived Copilot JWT via GitHub's ``copilot_internal``
    token endpoint. The JWT is cached module-wide and reused while it still
    has at least two minutes of validity left.

    Returns:
        The current Copilot JWT string.

    Raises:
        RuntimeError: if ``COPILOT_OAUTH_TOKEN`` is not set.
        requests.HTTPError: if the token-exchange request fails.
        KeyError: if the response JSON has no ``token`` field.
    """
    global _cached_jwt, _jwt_expires_at

    with _lock:
        # Reuse the cached JWT while it has >= 2 minutes of validity left.
        if _cached_jwt and time.time() < _jwt_expires_at - 120:
            return _cached_jwt

        oauth_token = os.environ.get("COPILOT_OAUTH_TOKEN", "")
        if not oauth_token:
            raise RuntimeError("COPILOT_OAUTH_TOKEN 未设置,无法获取 Copilot JWT")

        # Editor-Version / User-Agent headers are required for the internal
        # endpoint to issue a token.
        headers = {
            "Authorization": "token " + oauth_token,
            "Accept": "application/json",
            "Editor-Version": "vscode/1.100.0",
            "Editor-Plugin-Version": "copilot/1.388.0",
            "User-Agent": "GithubCopilot/1.388.0",
        }

        r = requests.get(
            "https://api.github.com/copilot_internal/v2/token",
            headers=headers,
            timeout=15,
        )
        r.raise_for_status()
        data = r.json()

        _cached_jwt = data["token"]
        # "expires_at" is an absolute unix timestamp; fall back to a
        # 30-minute lifetime when the field is missing.
        _jwt_expires_at = float(data.get("expires_at", time.time() + 1800))
        # Lazy %-style args: the message is only built if INFO is enabled.
        logger.info("Copilot JWT refreshed, expires at %s", _jwt_expires_at)
        return _cached_jwt


def copilot_default_headers() -> dict:
    """Extra HTTP request headers the Copilot API expects on every call."""
    headers = {}
    headers["Editor-Version"] = "vscode/1.100.0"
    headers["Copilot-Integration-Id"] = "vscode-chat"
    headers["User-Agent"] = "GithubCopilot/1.388.0"
    return headers
9 changes: 9 additions & 0 deletions ai_scientist/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,15 @@ def extract_json_between_markers(llm_output: str) -> dict | None:


def create_client(model) -> tuple[Any, str]:
if os.environ.get("COPILOT_OAUTH_TOKEN"):
from ai_scientist.copilot_auth import get_copilot_jwt, copilot_default_headers
jwt = get_copilot_jwt()
print(f"Using GitHub Copilot API with model {model}.")
return openai.OpenAI(
api_key=jwt,
base_url="https://api.individual.githubcopilot.com",
default_headers=copilot_default_headers(),
), model
if model.startswith("claude-"):
print(f"Using Anthropic API with model {model}.")
return anthropic.Anthropic(), model
Expand Down
27 changes: 14 additions & 13 deletions ai_scientist/treesearch/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import os
from . import backend_anthropic, backend_openai
from .utils import FunctionSpec, OutputType, PromptType, compile_prompt_to_md

def get_ai_client(model: str, **model_kwargs):
"""
Get the appropriate AI client based on the model string.
def _use_openai_backend(model: str) -> bool:
    """Decide whether *model* should go through the OpenAI-compatible backend.

    Copilot mode forces the OpenAI backend for every model; otherwise only
    Claude models without an Anthropic proxy use the Anthropic backend.
    """
    copilot_enabled = bool(os.environ.get("COPILOT_OAUTH_TOKEN"))
    is_claude = "claude-" in model
    has_anthropic_proxy = bool(os.environ.get("ANTHROPIC_BASE_URL"))
    return copilot_enabled or not is_claude or has_anthropic_proxy

Args:
model (str): string identifier for the model to use (e.g. "gpt-4-turbo")
**model_kwargs: Additional keyword arguments for model configuration.
Returns:
An instance of the appropriate AI client.
"""
if "claude-" in model:
return backend_anthropic.get_ai_client(model=model, **model_kwargs)
else:
def get_ai_client(model: str, **model_kwargs):
    """Return a client from whichever backend handles *model*.

    Dispatches to the OpenAI-compatible backend or the Anthropic backend
    based on `_use_openai_backend`; extra kwargs are forwarded unchanged.
    """
    backend = backend_openai if _use_openai_backend(model) else backend_anthropic
    return backend.get_ai_client(model=model, **model_kwargs)

def query(
system_message: PromptType | None,
Expand Down Expand Up @@ -66,7 +67,7 @@ def query(
else:
model_kwargs["max_tokens"] = max_tokens

query_func = backend_anthropic.query if "claude-" in model else backend_openai.query
query_func = backend_openai.query if _use_openai_backend(model) else backend_anthropic.query
output, req_time, in_tok_count, out_tok_count, info = query_func(
system_message=compile_prompt_to_md(system_message) if system_message else None,
user_message=compile_prompt_to_md(user_message) if user_message else None,
Expand Down
9 changes: 7 additions & 2 deletions ai_scientist/treesearch/backend/backend_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@
anthropic.APIStatusError,
)

def get_ai_client(model : str, max_retries=2) -> anthropic.AnthropicBedrock:
client = anthropic.AnthropicBedrock(max_retries=max_retries)
def get_ai_client(model: str, max_retries=2):
    """Build an Anthropic client for *model*.

    Uses the direct Anthropic API when both ANTHROPIC_BASE_URL and
    ANTHROPIC_API_KEY are set (e.g. a third-party proxy); otherwise falls
    back to AnthropicBedrock with its default AWS credential chain.
    """
    # NOTE(review): relies on `os` being imported at module level in this
    # file — the diff hunk does not show the import block; confirm.
    proxy_url = os.environ.get("ANTHROPIC_BASE_URL", None)
    key = os.environ.get("ANTHROPIC_API_KEY", None)
    if proxy_url and key:
        return anthropic.Anthropic(api_key=key, base_url=proxy_url, max_retries=max_retries)
    return anthropic.AnthropicBedrock(max_retries=max_retries)

def query(
Expand Down
63 changes: 61 additions & 2 deletions ai_scientist/treesearch/backend/backend_openai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
import logging
import os
import time
import threading

from .utils import FunctionSpec, OutputType, opt_messages_to_list, backoff_create
from funcy import notnone, once, select_values
Expand All @@ -9,16 +11,36 @@

logger = logging.getLogger("ai-scientist")

# Global rate limiter for Copilot API requests: enforces a minimum gap
# between successive calls across all threads in this process.
_rate_lock = threading.Lock()  # guards _last_request_time
_last_request_time = 0.0  # unix time of the most recent Copilot request
_COPILOT_MIN_INTERVAL = 3.0  # minimum seconds between requests


# Exception classes that trigger a retry with exponential backoff.
OPENAI_TIMEOUT_EXCEPTIONS = (
    openai.RateLimitError,
    openai.APIConnectionError,
    openai.APITimeoutError,
    openai.InternalServerError,
    openai.PermissionDeniedError,  # Some APIs return 403 on quota exhaustion
)

def _is_copilot_mode() -> bool:
    """True when a non-empty COPILOT_OAUTH_TOKEN is present in the environment."""
    token = os.environ.get("COPILOT_OAUTH_TOKEN", "")
    return token != ""


def get_ai_client(model: str, max_retries=2) -> openai.OpenAI:
if model.startswith("ollama/"):
if _is_copilot_mode():
from ai_scientist.copilot_auth import get_copilot_jwt, copilot_default_headers, COPILOT_API
jwt = get_copilot_jwt()
client = openai.OpenAI(
api_key=jwt,
base_url=COPILOT_API,
default_headers=copilot_default_headers(),
max_retries=max_retries,
)
elif model.startswith("ollama/"):
client = openai.OpenAI(
base_url="http://localhost:11434/v1",
max_retries=max_retries
Expand All @@ -39,15 +61,38 @@ def query(

messages = opt_messages_to_list(system_message, user_message)

# Claude via OpenAI-compatible proxy requires at least one user message.
# The proxy extracts system role into Anthropic's system parameter,
# leaving no messages if only system was provided.
if messages and not any(m.get("role") == "user" for m in messages):
messages.append({"role": "user", "content": "Please proceed with the task described above."})

# Claude requires max_tokens to be set explicitly
model_name = filtered_kwargs.get("model", "")
if "claude" in model_name and "max_tokens" not in filtered_kwargs:
filtered_kwargs["max_tokens"] = 8192

if func_spec is not None:
filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
# force the model to use the function
filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict

if filtered_kwargs.get("model", "").startswith("ollama/"):
filtered_kwargs["model"] = filtered_kwargs["model"].replace("ollama/", "")
filtered_kwargs["model"] = filtered_kwargs["model"].replace("ollama/", "", 1)

t0 = time.time()

# Copilot API rate limiting: space out requests
if _is_copilot_mode():
global _last_request_time
with _rate_lock:
elapsed = time.time() - _last_request_time
if elapsed < _COPILOT_MIN_INTERVAL:
wait = _COPILOT_MIN_INTERVAL - elapsed
logger.debug(f"Rate limit sleep {wait:.1f}s")
time.sleep(wait)
_last_request_time = time.time()

completion = backoff_create(
client.chat.completions.create,
OPENAI_TIMEOUT_EXCEPTIONS,
Expand All @@ -56,9 +101,23 @@ def query(
)
req_time = time.time() - t0

# Copilot Claude may split content and tool_calls into separate choices.
# Find the right choice depending on whether we expect a function call.
choice = completion.choices[0]
if func_spec is not None and not getattr(choice.message, 'tool_calls', None):
# Look for a choice that has tool_calls
for c in completion.choices:
if getattr(c.message, 'tool_calls', None):
choice = c
break

if func_spec is None:
# For plain text, prefer the choice with content
if not choice.message.content:
for c in completion.choices:
if c.message.content:
choice = c
break
output = choice.message.content
else:
assert (
Expand Down
5 changes: 3 additions & 2 deletions ai_scientist/treesearch/backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

@backoff.on_predicate(
wait_gen=backoff.expo,
max_value=60,
factor=1.5,
max_value=120,
factor=2.0,
)
def backoff_create(
create_fn: Callable, retry_exceptions: list[Exception], *args, **kwargs
Expand All @@ -27,6 +27,7 @@ def backoff_create(
return create_fn(*args, **kwargs)
except retry_exceptions as e:
logger.info(f"Backoff exception: {e}")
print(f"[BACKOFF ERROR] {type(e).__name__}: {e}")
return False


Expand Down
8 changes: 4 additions & 4 deletions ai_scientist/utils/token_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ def add_tokens(
reasoning_tokens: int,
cached_tokens: int,
):
self.token_counts[model]["prompt"] += prompt_tokens
self.token_counts[model]["completion"] += completion_tokens
self.token_counts[model]["reasoning"] += reasoning_tokens
self.token_counts[model]["cached"] += cached_tokens
self.token_counts[model]["prompt"] += prompt_tokens or 0
self.token_counts[model]["completion"] += completion_tokens or 0
self.token_counts[model]["reasoning"] += reasoning_tokens or 0
self.token_counts[model]["cached"] += cached_tokens or 0

def add_interaction(
self,
Expand Down
18 changes: 18 additions & 0 deletions ai_scientist/vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
MAX_NUM_TOKENS = 4096

AVAILABLE_VLMS = [
"gpt-4o",
"gpt-4o-mini",
"gpt-4o-2024-05-13",
"gpt-4o-2024-08-06",
"gpt-4o-2024-11-20",
Expand Down Expand Up @@ -144,6 +146,9 @@ def get_response_from_vlm(
if msg_history is None:
msg_history = []

if model.startswith("gpt-4.1"):
model = "gpt-4o"

if model in AVAILABLE_VLMS:
# Convert single image path to list for consistent handling
if isinstance(image_paths, str):
Expand Down Expand Up @@ -194,14 +199,27 @@ def get_response_from_vlm(

def create_client(model: str) -> tuple[Any, str]:
"""Create client for vision-language model."""
if model.startswith("gpt-4.1"):
fallback_model = "gpt-4o"
print(f"Model {model} not supported for VLM. Falling back to {fallback_model}.")
model = fallback_model
if model in [
"gpt-4o",
"gpt-4o-mini",
"gpt-4o-2024-05-13",
"gpt-4o-2024-08-06",
"gpt-4o-2024-11-20",
"gpt-4o-mini-2024-07-18",
"o3-mini",
]:
print(f"Using OpenAI API with model {model}.")
if os.environ.get("COPILOT_OAUTH_TOKEN"):
from ai_scientist.copilot_auth import get_copilot_jwt, copilot_default_headers
return openai.OpenAI(
api_key=get_copilot_jwt(),
base_url="https://api.individual.githubcopilot.com",
default_headers=copilot_default_headers(),
), model
return openai.OpenAI(), model
elif model.startswith("ollama/"):
print(f"Using Ollama API with model {model}.")
Expand Down
Loading