diff --git a/README.md b/README.md index 525e8f51..b086ed7d 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,7 @@ pip install -r requirements.txt #### OpenAI Models By default, the system uses the `OPENAI_API_KEY` environment variable for OpenAI models. +Additionally, if you are using a custom OpenAI-compatible API endpoint (e.g., a local LLM server), you can specify its URL using the `BASE_URL` environment variable. #### Gemini Models @@ -85,6 +86,10 @@ Our code can optionally use a Semantic Scholar API Key (`S2_API_KEY`) for higher Ensure you provide the necessary API keys as environment variables for the models you intend to use. For example: ```bash export OPENAI_API_KEY="YOUR_OPENAI_KEY_HERE" +# Optional: For custom OpenAI-compatible endpoints +# export BASE_URL="YOUR_CUSTOM_API_BASE_URL" + +# For Semantic Scholar API (Optional) export S2_API_KEY="YOUR_S2_KEY_HERE" # Set AWS credentials if using Bedrock # export AWS_ACCESS_KEY_ID="YOUR_AWS_ACCESS_KEY_ID" diff --git a/ai_scientist/llm.py b/ai_scientist/llm.py index afd85a41..ba1fde90 100644 --- a/ai_scientist/llm.py +++ b/ai_scientist/llm.py @@ -7,6 +7,7 @@ import anthropic import backoff import openai +import os MAX_NUM_TOKENS = 4096 @@ -426,12 +427,10 @@ def create_client(model) -> tuple[Any, str]: client_model = model.split("/")[-1] print(f"Using Vertex AI with model {client_model}.") return anthropic.AnthropicVertex(), client_model - elif "gpt" in model: - print(f"Using OpenAI API with model {model}.") - return openai.OpenAI(), model - elif "o1" in model or "o3" in model: + elif "gpt" in model or "o1" in model or "o3" in model: print(f"Using OpenAI API with model {model}.") - return openai.OpenAI(), model + base_url = os.getenv('BASE_URL') + return openai.OpenAI(base_url=base_url if base_url else None), model elif model == "deepseek-coder-v2-0724": print(f"Using OpenAI API with {model}.") return ( diff --git a/ai_scientist/treesearch/backend/backend_openai.py b/ai_scientist/treesearch/backend/backend_openai.py index ae318ec4..6ac0721c 100644 --- a/ai_scientist/treesearch/backend/backend_openai.py +++ b/ai_scientist/treesearch/backend/backend_openai.py @@ -22,7 +22,9 @@ @once def _setup_openai_client(): global _client - _client = openai.OpenAI(max_retries=0) + import os + base_url = os.getenv('BASE_URL') + _client = openai.OpenAI(max_retries=0, base_url=base_url if base_url else None) def query( diff --git a/ai_scientist/treesearch/log_summarization.py b/ai_scientist/treesearch/log_summarization.py index 436cdd18..2aa3773b 100644 --- a/ai_scientist/treesearch/log_summarization.py +++ b/ai_scientist/treesearch/log_summarization.py @@ -8,9 +8,11 @@ parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) sys.path.insert(0, parent_dir) +import os from ai_scientist.llm import get_response_from_llm, extract_json_between_markers -client = openai.OpenAI() +base_url = os.getenv('BASE_URL') +client = openai.OpenAI(base_url=base_url if base_url else None) model = "gpt-4o-2024-08-06" report_summarizer_sys_msg = """You are an expert machine learning researcher. diff --git a/ai_scientist/vlm.py b/ai_scientist/vlm.py index 240015eb..5410cc2a 100644 --- a/ai_scientist/vlm.py +++ b/ai_scientist/vlm.py @@ -165,7 +165,9 @@ def create_client(model: str) -> tuple[Any, str]: "o3-mini", ]: print(f"Using OpenAI API with model {model}.") - return openai.OpenAI(), model + import os + base_url = os.getenv('BASE_URL') + return openai.OpenAI(base_url=base_url if base_url else None), model else: raise ValueError(f"Model {model} not supported.")