From a19f44faa51986ab7a2627ff012c146ab56266fc Mon Sep 17 00:00:00 2001
From: Sheing Ng
Date: Thu, 10 Jul 2025 01:46:37 -0500
Subject: [PATCH] feature: xAI support.

---
 api/config.py             |   9 +-
 api/config/generator.json |  19 +++
 api/requirements.txt      |   1 +
 api/websocket_wiki.py     |  34 +++++
 api/xai_client.py         | 294 ++++++++++++++++++++++++++++++++++++++
 src/messages/en.json      |   1 +
 src/messages/es.json      |   1 +
 src/messages/ja.json      |   1 +
 src/messages/kr.json      |   1 +
 src/messages/pt-br.json   |   1 +
 src/messages/vi.json      |   1 +
 src/messages/zh-tw.json   |   1 +
 src/messages/zh.json      |   1 +
 13 files changed, 362 insertions(+), 3 deletions(-)
 create mode 100644 api/xai_client.py

diff --git a/api/config.py b/api/config.py
index 58bc5e32..662c20ec 100644
--- a/api/config.py
+++ b/api/config.py
@@ -11,6 +11,7 @@
 from api.openrouter_client import OpenRouterClient
 from api.bedrock_client import BedrockClient
 from api.azureai_client import AzureAIClient
+from api.xai_client import XAIClient
 from adalflow import GoogleGenAIClient, OllamaClient
 
 # Get API keys from environment variables
@@ -53,7 +54,8 @@
     "OpenRouterClient": OpenRouterClient,
     "OllamaClient": OllamaClient,
     "BedrockClient": BedrockClient,
-    "AzureAIClient": AzureAIClient
+    "AzureAIClient": AzureAIClient,
+    "XAIClient": XAIClient
 }
 
 def replace_env_placeholders(config: Union[Dict[str, Any], List[Any], str, Any]) -> Union[Dict[str, Any], List[Any], str, Any]:
@@ -121,14 +123,15 @@ def load_generator_config():
        if provider_config.get("client_class") in CLIENT_CLASSES:
            provider_config["model_client"] = CLIENT_CLASSES[provider_config["client_class"]]
        # Fall back to default mapping based on provider_id
-       elif provider_id in ["google", "openai", "openrouter", "ollama", "bedrock", "azure"]:
+       elif provider_id in ["google", "openai", "openrouter", "ollama", "bedrock", "azure", "xai"]:
            default_map = {
                "google": GoogleGenAIClient,
                "openai": OpenAIClient,
                "openrouter": OpenRouterClient,
                "ollama": OllamaClient,
                "bedrock": BedrockClient,
-               "azure": AzureAIClient
+               "azure": AzureAIClient,
+               "xai": XAIClient
            }
            provider_config["model_client"] = default_map[provider_id]
        else:
diff --git a/api/config/generator.json b/api/config/generator.json
index 9047d09b..81e50a6a 100644
--- a/api/config/generator.json
+++ b/api/config/generator.json
@@ -164,6 +164,25 @@
         "top_p": 0.8
       }
     }
+  },
+  "xai": {
+    "client_class": "XAIClient",
+    "default_model": "grok-4-0709",
+    "supportsCustomModel": true,
+    "models": {
+      "grok-4-0709": {
+        "temperature": 0.7
+      },
+      "grok-3": {
+        "temperature": 0.7
+      },
+      "grok-2": {
+        "temperature": 0.7
+      },
+      "grok-beta": {
+        "temperature": 0.7
+      }
+    }
   }
 }
diff --git a/api/requirements.txt b/api/requirements.txt
index 2a069561..13a632d6 100644
--- a/api/requirements.txt
+++ b/api/requirements.txt
@@ -17,4 +17,5 @@
 boto3>=1.34.0
 websockets>=11.0.3
 azure-identity>=1.12.0
 azure-core>=1.24.0
+xai-sdk>=0.1.0
diff --git a/api/websocket_wiki.py b/api/websocket_wiki.py
index c8292996..8fb6834f 100644
--- a/api/websocket_wiki.py
+++ b/api/websocket_wiki.py
@@ -14,6 +14,7 @@
 from api.openai_client import OpenAIClient
 from api.openrouter_client import OpenRouterClient
 from api.azureai_client import AzureAIClient
+from api.xai_client import XAIClient
 from api.rag import RAG
 
 # Configure logging
@@ -505,6 +506,22 @@
                 "top_p": model_config["top_p"]
             }
+
+            api_kwargs = model.convert_inputs_to_api_kwargs(
+                input=prompt,
+                model_kwargs=model_kwargs,
+                model_type=ModelType.LLM
+            )
+        elif request.provider == "xai":
+            logger.info(f"Using xAI with model: {request.model}")
+
+            # Initialize xAI client
+            model = XAIClient()
+            model_kwargs = {
+                "model": request.model,
+                "stream": True,
+                "temperature": model_config["temperature"]
+            }
+
             api_kwargs = model.convert_inputs_to_api_kwargs(
                 input=prompt,
                 model_kwargs=model_kwargs,
                 model_type=ModelType.LLM
@@ -594,6 +611,23 @@
                 await websocket.send_text(error_msg)
                 # Close the WebSocket connection after sending the error message
                 await websocket.close()
+            elif request.provider == "xai":
+                try:
+                    # Get the response using the previously created api_kwargs
+                    logger.info("Making xAI API call")
+                    response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
+                    # Handle the streaming response from xAI
+                    async for chunk in response:
+                        if chunk:  # xAI yields text chunks directly
+                            await websocket.send_text(chunk)
+                    # Explicitly close the WebSocket connection once the response is complete
+                    await websocket.close()
+                except Exception as e_xai:
+                    logger.error(f"Error with xAI API: {str(e_xai)}")
+                    error_msg = f"\nError with xAI API: {str(e_xai)}\n\nPlease check that you have set the XAI_API_KEY environment variable with a valid API key."
+                    await websocket.send_text(error_msg)
+                    # Close the WebSocket connection after sending the error message
+                    await websocket.close()
             else:
                 # Generate streaming response
                 response = model.generate_content(prompt, stream=True)
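For reference, the request path through the new `xai` branch reduces to the following minimal sketch (assuming this patch is applied and `XAI_API_KEY` is exported; the prompt text and temperature are illustrative placeholders):

```python
from adalflow.core.types import ModelType

from api.xai_client import XAIClient

# Mirrors the new `elif request.provider == "xai"` branch above.
model = XAIClient()
api_kwargs = model.convert_inputs_to_api_kwargs(
    input="Describe the architecture of this repository.",  # placeholder prompt
    model_kwargs={"model": "grok-4-0709", "stream": True, "temperature": 0.7},
    model_type=ModelType.LLM,
)
# api_kwargs now holds {"messages": [{"role": "user", "content": ...}],
# "model": "grok-4-0709", "stream": True, "temperature": 0.7}
```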
diff --git a/api/xai_client.py b/api/xai_client.py
new file mode 100644
index 00000000..a68e75bf
--- /dev/null
+++ b/api/xai_client.py
@@ -0,0 +1,294 @@
+"""xAI ModelClient integration."""
+
+import os
+import logging
+import asyncio
+from typing import Dict, Optional, Any, Callable, Literal
+import backoff
+
+from adalflow.core.model_client import ModelClient
+from adalflow.core.types import ModelType, GeneratorOutput
+
+log = logging.getLogger(__name__)
+
+
+def get_first_message_content(response) -> str:
+    """Extract content from an xAI response."""
+    if hasattr(response, 'content'):
+        return response.content
+    return str(response)
+
+
+def handle_streaming_response(response):
+    """Yield content chunks from a streaming xAI API response."""
+    try:
+        for chunk in response:
+            if hasattr(chunk, 'content') and chunk.content:
+                yield chunk.content
+    except Exception as e:
+        log.error(f"Error handling streaming response: {e}")
+        yield f"Error: {str(e)}"
+
+
+class XAIClient(ModelClient):
+    __doc__ = r"""A component wrapper for the xAI API client.
+
+    Supports chat completion APIs using xAI's Grok models.
+
+    Users can:
+    1. Simplify use of ``Generator`` components by passing `XAIClient()` as the `model_client`.
+    2. Use this as a reference to create their own API client, or extend this class by copying and modifying the code.
+
+    To use the xAI API, you need to set the XAI_API_KEY environment variable.
+
+    Example:
+    ```python
+    from api.xai_client import XAIClient
+    import adalflow as adal
+
+    client = XAIClient()
+    generator = adal.Generator(
+        model_client=client,
+        model_kwargs={"model": "grok-4-0709"}
+    )
+    ```
+
+    References:
+    - xAI API Documentation: https://docs.x.ai/
+    - xAI SDK: https://github.com/xai-org/xai-sdk
+    """
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        api_host: Optional[str] = None,
+        chat_completion_parser: Optional[Callable] = None,
+        input_type: Literal["text", "messages"] = "text",
+        env_api_key_name: str = "XAI_API_KEY",
+        env_api_host_name: str = "XAI_API_HOST",
+    ):
+        r"""Initialize the xAI client.
+
+        Args:
+            api_key (Optional[str], optional): xAI API key. Defaults to None.
+            api_host (Optional[str], optional): xAI API host. Defaults to "api.x.ai".
+            chat_completion_parser: Function used to parse chat completions.
+            input_type: Input format, either "text" or "messages".
+            env_api_key_name (str): The environment variable name for the API key.
+            env_api_host_name (str): The environment variable name for the API host.
+        """
+        super().__init__()
+        self._api_key = api_key
+        self._env_api_key_name = env_api_key_name
+        self._env_api_host_name = env_api_host_name
+        self.api_host = api_host or os.getenv(self._env_api_host_name, "api.x.ai")
+        self.sync_client = self.init_sync_client()
+        self.async_client = None  # only initialized when an async call is made
+        self.chat_completion_parser = (
+            chat_completion_parser or get_first_message_content
+        )
+        self._input_type = input_type
+
+    def init_sync_client(self):
+        """Initialize the synchronous xAI client."""
+        try:
+            from xai_sdk import Client
+        except ImportError:
+            raise ImportError(
+                "xai_sdk is required to use XAIClient. Install it with: pip install xai-sdk"
+            )
+
+        api_key = self._api_key or os.getenv(self._env_api_key_name)
+        if not api_key:
+            log.warning("XAI_API_KEY not configured")
+            # Return None so the client fails gracefully, with a clear error, when used
+            return None
+
+        return Client(
+            api_host=self.api_host,
+            api_key=api_key
+        )
+
+    def init_async_client(self):
+        """Initialize the asynchronous xAI client."""
+        # For now, reuse the sync client for async operations.
+        # This can be improved once the xAI SDK provides native async support.
+        return self.init_sync_client()
+
+    def convert_inputs_to_api_kwargs(
+        self, input: Any, model_kwargs: Optional[Dict] = None, model_type: Optional[ModelType] = None
+    ) -> Dict:
+        """Convert AdalFlow inputs to the xAI API format."""
+        model_kwargs = model_kwargs or {}
+
+        if model_type == ModelType.LLM:
+            # Normalize the input into a list of chat messages
+            if self._input_type == "messages" and isinstance(input, list):
+                messages = input
+            elif isinstance(input, str):
+                messages = [{"role": "user", "content": input}]
+            else:
+                messages = [{"role": "user", "content": str(input)}]
+
+            # Prepare API kwargs
+            return {
+                "messages": messages,
+                **model_kwargs
+            }
+        else:
+            raise ValueError(f"model_type {model_type} is not supported by XAIClient")
+
+    def parse_chat_completion(self, response) -> GeneratorOutput:
+        """Parse a chat completion response into a GeneratorOutput."""
+        try:
+            if hasattr(response, 'content'):
+                # Direct response with content
+                return GeneratorOutput(
+                    data=response.content,
+                    raw_response=str(response),
+                )
+            else:
+                # Handle other response formats
+                return GeneratorOutput(
+                    data=str(response),
+                    raw_response=str(response),
+                )
+        except Exception as e:
+            log.error(f"Error parsing chat completion response: {e}")
+            return GeneratorOutput(data=None, error=str(e), raw_response=str(response))
+
+    @backoff.on_exception(
+        backoff.expo,
+        (Exception,),  # the xAI SDK may define narrower exceptions; catch all for now
+        max_time=5,
+    )
+    def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
+        """Make a synchronous call to the xAI API."""
+        log.info(f"api_kwargs: {api_kwargs}")
+        self._api_kwargs = api_kwargs
+
+        if model_type == ModelType.LLM:
+            # Check that the client was properly initialized
+            if not self.sync_client:
+                raise ValueError("XAI client not properly initialized. Please set the XAI_API_KEY environment variable.")
+
+            try:
+                from xai_sdk.chat import user, system
+
+                # Create a new chat instance
+                chat = self.sync_client.chat.create(
+                    model=api_kwargs.get("model", "grok-4-0709"),
+                    temperature=api_kwargs.get("temperature", 0.7)
+                )
+
+                # Add the messages to the chat
+                messages = api_kwargs.get("messages", [])
+                for message in messages:
+                    role = message.get("role", "user")
+                    content = message.get("content", "")
+
+                    if role == "system":
+                        chat.append(system(content))
+                    else:  # treat user and assistant turns as user content
+                        chat.append(user(content))
+
+                # Get the response
+                response = chat.sample()
+
+                # Handle streaming if requested
+                if api_kwargs.get("stream", False):
+                    # Simulate streaming by yielding the full content as one chunk
+                    async def async_stream_generator():
+                        yield response.content
+                    return async_stream_generator()
+                else:
+                    return response
+
+            except Exception as e:
+                log.error(f"Error in xAI API call: {e}")
+                raise
+        else:
+            raise ValueError(f"model_type {model_type} is not supported by XAIClient")
+
+    async def acall(
+        self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED
+    ):
+        """Make an asynchronous call to the xAI API."""
+        # Check that the client was properly initialized
+        if not self.sync_client:
+            raise ValueError("XAI client not properly initialized. Please set the XAI_API_KEY environment variable.")
+
+        if model_type == ModelType.LLM:
+            try:
+                from xai_sdk.chat import user, system
+
+                # Run the blocking SDK calls in a thread pool
+                loop = asyncio.get_running_loop()
+
+                def create_chat_and_get_response():
+                    chat = self.sync_client.chat.create(
+                        model=api_kwargs.get("model", "grok-4-0709"),
+                        temperature=api_kwargs.get("temperature", 0.7)
+                    )
+
+                    # Add the messages to the chat
+                    messages = api_kwargs.get("messages", [])
+                    for message in messages:
+                        role = message.get("role", "user")
+                        content = message.get("content", "")
+
+                        if role == "system":
+                            chat.append(system(content))
+                        else:  # treat user and assistant turns as user content
+                            chat.append(user(content))
+
+                    # Get the response
+                    return chat.sample()
+
+                response = await loop.run_in_executor(None, create_chat_and_get_response)
+
+                # Handle streaming if requested
+                if api_kwargs.get("stream", False):
+                    # Simulate streaming by yielding the full content as one chunk
+                    async def async_stream_generator():
+                        yield response.content
+                    return async_stream_generator()
+                else:
+                    return response
+
+            except Exception as e:
+                log.error(f"Error in xAI API call: {e}")
+                raise
+        else:
+            raise ValueError(f"model_type {model_type} is not supported by XAIClient")
+
+
+# Example usage:
+if __name__ == "__main__":
+    from adalflow.core import Generator
+    from adalflow.utils import setup_env
+
+    setup_env()
+    prompt_kwargs = {"input_str": "What is the meaning of life?"}
+
+    gen = Generator(
+        model_client=XAIClient(),
+        model_kwargs={"model": "grok-4-0709", "stream": False},
+    )
+    gen_response = gen(prompt_kwargs)
+    print(f"gen_response: {gen_response}")
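A minimal end-to-end sketch of the streaming path that `websocket_wiki.py` exercises (assumes this patch is applied and `XAI_API_KEY` is set; the question is a placeholder):

```python
import asyncio

from adalflow.core.types import ModelType

from api.xai_client import XAIClient

async def stream_answer() -> None:
    model = XAIClient()
    api_kwargs = model.convert_inputs_to_api_kwargs(
        input="What does this client do?",  # placeholder prompt
        model_kwargs={"model": "grok-4-0709", "stream": True, "temperature": 0.7},
        model_type=ModelType.LLM,
    )
    # With stream=True, acall() returns an async generator of text chunks
    # (currently a single chunk, since streaming is simulated).
    response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
    async for chunk in response:
        print(chunk, end="", flush=True)

asyncio.run(stream_answer())
```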
diff --git a/src/messages/en.json b/src/messages/en.json
index a62ee892..ba87ca03 100644
--- a/src/messages/en.json
+++ b/src/messages/en.json
@@ -45,6 +45,7 @@
   "providerOpenAI": "OpenAI",
   "providerOpenRouter": "OpenRouter",
   "providerOllama": "Ollama (Local)",
+  "providerXai": "xAI",
   "localOllama": "Local Ollama Model",
   "experimental": "Experimental",
   "useOpenRouter": "Use OpenRouter API",
diff --git a/src/messages/es.json b/src/messages/es.json
index cace0b54..d2254178 100644
--- a/src/messages/es.json
+++ b/src/messages/es.json
@@ -45,6 +45,7 @@
   "providerOpenAI": "OpenAI",
   "providerOpenRouter": "OpenRouter",
   "providerOllama": "Ollama (Local)",
+  "providerXai": "xAI",
   "localOllama": "Modelo Ollama Local",
   "experimental": "Experimental",
   "useOpenRouter": "Usar API de OpenRouter",
diff --git a/src/messages/ja.json b/src/messages/ja.json
index 8692635f..0f6ca046 100644
--- a/src/messages/ja.json
+++ b/src/messages/ja.json
@@ -45,6 +45,7 @@
   "providerOpenAI": "OpenAI",
   "providerOpenRouter": "OpenRouter",
   "providerOllama": "Ollama(ローカル)",
+  "providerXai": "xAI",
   "localOllama": "ローカルOllamaモデル",
   "experimental": "実験的",
   "useOpenRouter": "OpenRouter APIを使用",
diff --git a/src/messages/kr.json b/src/messages/kr.json
index 68666f3d..f52b3413 100644
--- a/src/messages/kr.json
+++ b/src/messages/kr.json
@@ -45,6 +45,7 @@
   "providerOpenAI": "OpenAI",
   "providerOpenRouter": "OpenRouter",
   "providerOllama": "Ollama (로컬)",
+  "providerXai": "xAI",
   "localOllama": "로컬 Ollama 모델",
   "experimental": "실험적",
   "useOpenRouter": "OpenRouter API 사용",
diff --git a/src/messages/pt-br.json b/src/messages/pt-br.json
index 3bb05575..39ecae11 100644
--- a/src/messages/pt-br.json
+++ b/src/messages/pt-br.json
@@ -45,6 +45,7 @@
   "providerOpenAI": "OpenAI",
   "providerOpenRouter": "OpenRouter",
   "providerOllama": "Ollama (Local)",
+  "providerXai": "xAI",
   "localOllama": "Modelo Ollama Local",
   "experimental": "Experimental",
   "useOpenRouter": "Usar API OpenRouter",
diff --git a/src/messages/vi.json b/src/messages/vi.json
index 7ac1933c..2450bb13 100644
--- a/src/messages/vi.json
+++ b/src/messages/vi.json
@@ -45,6 +45,7 @@
   "providerOpenAI": "OpenAI",
   "providerOpenRouter": "OpenRouter",
   "providerOllama": "Ollama (Cục bộ)",
+  "providerXai": "xAI",
   "localOllama": "Mô hình Ollama cục bộ",
   "experimental": "Thử nghiệm",
   "useOpenRouter": "Sử dụng API OpenRouter",
diff --git a/src/messages/zh-tw.json b/src/messages/zh-tw.json
index 67a1bd2d..f8822900 100644
--- a/src/messages/zh-tw.json
+++ b/src/messages/zh-tw.json
@@ -45,6 +45,7 @@
   "providerOpenAI": "OpenAI",
   "providerOpenRouter": "OpenRouter",
   "providerOllama": "Ollama(本機)",
+  "providerXai": "xAI",
   "localOllama": "本機 Ollama 模型",
   "experimental": "實驗性",
   "useOpenRouter": "使用 OpenRouter API",
diff --git a/src/messages/zh.json b/src/messages/zh.json
index 10fc7b22..59024db4 100644
--- a/src/messages/zh.json
+++ b/src/messages/zh.json
@@ -45,6 +45,7 @@
   "providerOpenAI": "OpenAI",
   "providerOpenRouter": "OpenRouter",
   "providerOllama": "Ollama (本地)",
+  "providerXai": "xAI",
   "localOllama": "本地Ollama模型",
   "experimental": "实验性",
   "useOpenRouter": "使用OpenRouter API",
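One behavioural note worth keeping in mind when testing: `init_sync_client()` deliberately returns `None` instead of raising when the key is absent, so failures surface at call time with a clear message. A minimal sketch of that path (assumes the patch is applied; `call()` is retried by `backoff` for up to 5 seconds before the error propagates):

```python
import os

from adalflow.core.types import ModelType

from api.xai_client import XAIClient

os.environ.pop("XAI_API_KEY", None)  # simulate a missing key

client = XAIClient()  # logs "XAI_API_KEY not configured"
assert client.sync_client is None

try:
    client.call(api_kwargs={"messages": [], "model": "grok-4-0709"},
                model_type=ModelType.LLM)
except ValueError as err:
    print(err)  # "XAI client not properly initialized. Please set the XAI_API_KEY ..."
```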