-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Add Google Gemini API support to aisuite #181
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Add support for Google Gemini API to `aisuite` and provide a guide. * **New Provider Implementation** - Add `GoogleGenaiProvider` class in `aisuite/providers/google_genai_provider.py` to handle Gemini API calls. - Implement `chat_completions_create`, `generate_content`, `list_models`, and `normalize_response` methods. - Handle authentication and API key management. * **Provider Factory Update** - Update `ProviderFactory` in `aisuite/provider.py` to include `GoogleGenaiProvider`. * **Documentation** - Add `guides/google_genai.md` with instructions for setting up and using the Gemini API with `aisuite`. - Update `README.md` to include the Gemini API as a supported provider and provide a brief example of how to use it. * **Dependencies** - Add `google-genai` to dependencies in `pyproject.toml`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had some issues with the parsing of the name so I just changed the name to Ggenai for the class and the file.
Also the genai api doesn't accept the temperature so you must change
``**kwargs to config=types.GenerateContentConfig(**kwargs)```
in generate_content and chat_completions_create.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I’ll take a look. Appreciate any edits as I’m not too experienced integrating apis.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could i help on this ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll get this tomorrow
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that we should keep the logic in the same provider, but making it dynamic and configurable to run using main vertex SDK OR with this new genai SDK.
What do u think ??
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think mixing Vertex and genai clients in one provider complicates the class a little bit. Why not have separate clients, just like google has separate clients?
Hoho, just I was writing this I saw that google just released a unified API for GenAI and Vertex for javascript: https://github.com/googleapis/js-genai/tree/main.
🤷🏼♂️ . I still like the idea of seeing both implemented side by side before deciding if it makes sense unifying them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jackbravo The changes that i suggest convered the usage of new GenAI sdk both using genai or vertexai endpoints.
Just need to set on config vertex_options:True
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this branch still an active concern?
import os | ||
from google import genai | ||
from google.genai import types | ||
from aisuite.provider import Provider, LLMError | ||
from aisuite.framework import ChatCompletionResponse | ||
|
||
|
||
class GoogleGenaiProvider(Provider): | ||
def __init__(self, **config): | ||
self.api_key = config.get("api_key") or os.getenv("GEMINI_API_KEY") | ||
if not self.api_key: | ||
raise ValueError( | ||
"Gemini API key is missing. Please provide it in the config or set the GEMINI_API_KEY environment variable." | ||
) | ||
self.client = genai.Client(api_key=self.api_key) | ||
|
||
def chat_completions_create(self, model, messages, **kwargs): | ||
try: | ||
response = self.client.models.generate_content( | ||
model=model, | ||
contents=[message["content"] for message in messages], | ||
**kwargs | ||
) | ||
return self.normalize_response(response) | ||
except Exception as e: | ||
raise LLMError(f"Error in chat_completions_create: {str(e)}") | ||
|
||
def generate_content(self, model, contents, **kwargs): | ||
try: | ||
response = self.client.models.generate_content( | ||
model=model, | ||
contents=contents, | ||
**kwargs | ||
) | ||
return self.normalize_response(response) | ||
except Exception as e: | ||
raise LLMError(f"Error in generate_content: {str(e)}") | ||
|
||
def list_models(self): | ||
try: | ||
response = self.client.models.list() | ||
return [model.name for model in response] | ||
except Exception as e: | ||
raise LLMError(f"Error in list_models: {str(e)}") | ||
|
||
def normalize_response(self, response): | ||
normalized_response = ChatCompletionResponse() | ||
normalized_response.choices[0].message.content = response.text | ||
return normalized_response |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
import os | |
from google import genai | |
from google.genai import types | |
from aisuite.provider import Provider, LLMError | |
from aisuite.framework import ChatCompletionResponse | |
class GoogleGenaiProvider(Provider): | |
def __init__(self, **config): | |
self.api_key = config.get("api_key") or os.getenv("GEMINI_API_KEY") | |
if not self.api_key: | |
raise ValueError( | |
"Gemini API key is missing. Please provide it in the config or set the GEMINI_API_KEY environment variable." | |
) | |
self.client = genai.Client(api_key=self.api_key) | |
def chat_completions_create(self, model, messages, **kwargs): | |
try: | |
response = self.client.models.generate_content( | |
model=model, | |
contents=[message["content"] for message in messages], | |
**kwargs | |
) | |
return self.normalize_response(response) | |
except Exception as e: | |
raise LLMError(f"Error in chat_completions_create: {str(e)}") | |
def generate_content(self, model, contents, **kwargs): | |
try: | |
response = self.client.models.generate_content( | |
model=model, | |
contents=contents, | |
**kwargs | |
) | |
return self.normalize_response(response) | |
except Exception as e: | |
raise LLMError(f"Error in generate_content: {str(e)}") | |
def list_models(self): | |
try: | |
response = self.client.models.list() | |
return [model.name for model in response] | |
except Exception as e: | |
raise LLMError(f"Error in list_models: {str(e)}") | |
def normalize_response(self, response): | |
normalized_response = ChatCompletionResponse() | |
normalized_response.choices[0].message.content = response.text | |
return normalized_response | |
"""The interface to Google's Genai.""" | |
import os | |
import json | |
from typing import List, Dict, Any, Optional | |
from google import genai | |
from google.genai import types | |
from google.genai.types import Content, Part, Tool, FunctionDeclaration | |
import pprint | |
from aisuite.framework import ChatCompletionResponse, Message | |
from aisuite.provider import LLMError, Provider | |
DEFAULT_TEMPERATURE = 0.7 | |
ENABLE_DEBUG_MESSAGES = False | |
class GoogleMessageConverter: | |
@staticmethod | |
def convert_user_role_message(message: Dict[str, Any]) -> Content: | |
"""Convert user or system messages to Google Vertex AI format.""" | |
parts = [Part.from_text(text=message["content"])] | |
return Content(role="user", parts=parts) | |
@staticmethod | |
def convert_assistant_role_message(message: Dict[str, Any]) -> Content: | |
"""Convert assistant messages to Google Vertex AI format.""" | |
if "tool_calls" in message and message["tool_calls"]: | |
# Handle function calls | |
tool_call = message["tool_calls"][ | |
0 | |
] # Assuming single function call for now | |
function_call = tool_call["function"] | |
# Create a Part from the function call | |
parts = [ | |
Part.from_function_call( | |
name=function_call["name"], | |
args={}, | |
# arguments = json.loads(function_call["arguments"]) | |
) | |
] | |
# return Content(role="function", parts=parts) | |
else: | |
# Handle regular text messages | |
parts = [Part.from_text(text=message["content"])] | |
# return Content(role="model", parts=parts) | |
return Content(role="model", parts=parts) | |
@staticmethod | |
def convert_tool_role_message(message: Dict[str, Any]) -> Part: | |
"""Convert tool messages to Google Vertex AI format.""" | |
if "content" not in message: | |
raise ValueError("Tool result message must have a content field") | |
try: | |
content_json = json.loads(message["content"]) | |
part = Part.from_function_response( | |
name=message["name"], response=content_json | |
) | |
# TODO: Return Content instead of Part. But returning Content is not working. | |
return part | |
except json.JSONDecodeError: | |
raise ValueError("Tool result message must be valid JSON") | |
@staticmethod | |
def convert_request(messages: List[Dict[str, Any]]) -> List[Content]: | |
"""Convert messages to Google Vertex AI format.""" | |
# Convert all messages to dicts if they're Message objects | |
messages = [ | |
message.model_dump() if hasattr(message, "model_dump") else message | |
for message in messages | |
] | |
formatted_messages = [] | |
for message in messages: | |
if message["role"] == "tool": | |
vertex_message = GoogleMessageConverter.convert_tool_role_message( | |
message | |
) | |
if vertex_message: | |
formatted_messages.append(vertex_message) | |
elif message["role"] == "assistant": | |
formatted_messages.append( | |
GoogleMessageConverter.convert_assistant_role_message(message) | |
) | |
else: # user or system role | |
formatted_messages.append( | |
GoogleMessageConverter.convert_user_role_message(message) | |
) | |
return formatted_messages | |
@staticmethod | |
def convert_response(response) -> ChatCompletionResponse: | |
"""Normalize the response from Vertex AI to match OpenAI's response format.""" | |
openai_response = ChatCompletionResponse() | |
if ENABLE_DEBUG_MESSAGES: | |
print("Dumping the response") | |
pprint.pprint(response) | |
# TODO: We need to go through each part, because function call may not be the first part. | |
# Currently, we are only handling the first part, but this is not enough. | |
# | |
# This is a valid response: | |
# candidates { | |
# content { | |
# role: "model" | |
# parts { | |
# text: "The current temperature in San Francisco is 72 degrees Celsius. \n\n" | |
# } | |
# parts { | |
# function_call { | |
# name: "is_it_raining" | |
# args { | |
# fields { | |
# key: "location" | |
# value { | |
# string_value: "San Francisco" | |
# } | |
# } | |
# } | |
# } | |
# } | |
# } | |
# finish_reason: STOP | |
# Check if the response contains function calls | |
# Note: Just checking if the function_call attribute exists is not enough, | |
# it is important to check if the function_call is not None. | |
if ( | |
hasattr(response.candidates[0].content.parts[0], "function_call") | |
and response.candidates[0].content.parts[0].function_call | |
): | |
function_call = response.candidates[0].content.parts[0].function_call | |
# args is a MapComposite. | |
# Convert the MapComposite to a dictionary | |
args_dict = {} | |
# Another way to try is: args_dict = dict(function_call.args) | |
for key, value in function_call.args.items(): | |
args_dict[key] = value | |
if ENABLE_DEBUG_MESSAGES: | |
print("Dumping the args_dict") | |
pprint.pprint(args_dict) | |
openai_response.choices[0].message = { | |
"role": "assistant", | |
"content": None, | |
"tool_calls": [ | |
{ | |
"type": "function", | |
"id": f"call_{hash(function_call.name)}", # Generate a unique ID | |
"function": { | |
"name": function_call.name, | |
"arguments": json.dumps(args_dict), | |
}, | |
} | |
], | |
"refusal": None, | |
} | |
openai_response.choices[0].message = Message( | |
**openai_response.choices[0].message | |
) | |
openai_response.choices[0].finish_reason = "tool_calls" | |
else: | |
# Handle regular text response | |
openai_response.choices[0].message.content = ( | |
response.candidates[0].content.parts[0].text | |
) | |
openai_response.choices[0].finish_reason = "stop" | |
return openai_response | |
class GooglegenaiProvider(Provider): | |
"""Implements the Provider Interface for interacting with Google's Generative AI.""" | |
def __init__(self, **config): | |
"""Set up the Google AI client with a project ID.""" | |
self.project_id = config.get("project_id") or os.getenv("GOOGLE_PROJECT_ID") | |
self.location = config.get("region") or os.getenv("GOOGLE_REGION") | |
self.app_creds_path = config.get("application_credentials") or os.getenv( | |
"GOOGLE_APPLICATION_CREDENTIALS" | |
) | |
self.api_key = config.get("api_key") or os.getenv("GEMINI_API_KEY") | |
self.vertexai_option = config.get("vertexai_option") | |
if self.vertexai_option and ( | |
not self.project_id or not self.location or not self.app_creds_path | |
): | |
raise EnvironmentError( | |
"Missing one or more required Google environment variables: " | |
"GOOGLE_PROJECT_ID, GOOGLE_REGION, GOOGLE_APPLICATION_CREDENTIALS. " | |
"Please refer to the setup guide: /guides/google.md." | |
) | |
elif not self.vertexai_option and not self.api_key: | |
raise EnvironmentError( | |
"Missing required Google environment variable: GEMINI_API_KEY. " | |
"Please refer to the setup guide: /guides/google.md." | |
) | |
self.client = genai.Client( | |
project=self.project_id, | |
location=self.location, | |
vertexai=self.vertexai_option, | |
api_key=self.api_key, | |
) | |
self.transformer = GoogleMessageConverter() | |
def chat_completions_create(self, model, messages, **kwargs): | |
"""Request chat completions from the Google AI API. | |
Args: | |
---- | |
model (str): Identifies the specific provider/model to use. | |
messages (list of dict): A list of message objects in chat history. | |
kwargs (dict): Optional arguments for the Google AI API. | |
Returns: | |
------- | |
The ChatCompletionResponse with the completion result. | |
""" | |
# Set the temperature if provided, otherwise use the default | |
temperature = kwargs.get("temperature", DEFAULT_TEMPERATURE) | |
# Set safety_settings if provided | |
safety_settings = kwargs.get("safety_settings") | |
# Convert messages to Vertex AI format | |
message_history = self.transformer.convert_request(messages) | |
# Handle tools if provided | |
tools = None | |
if "tools" in kwargs: | |
tools = [ | |
Tool( | |
function_declarations=[ | |
FunctionDeclaration( | |
name=tool["function"]["name"], | |
description=tool["function"].get("description", ""), | |
parameters={ | |
"type": "object", | |
"properties": { | |
param_name: { | |
"type": param_info.get("type", "string"), | |
"description": param_info.get( | |
"description", "" | |
), | |
**( | |
{"enum": param_info["enum"]} | |
if "enum" in param_info | |
else {} | |
), | |
} | |
for param_name, param_info in tool["function"][ | |
"parameters" | |
]["properties"].items() | |
}, | |
"required": tool["function"]["parameters"].get( | |
"required", [] | |
), | |
}, | |
) | |
for tool in kwargs["tools"] | |
] | |
) | |
] | |
if ENABLE_DEBUG_MESSAGES: | |
print("Dumping the message_history") | |
pprint.pprint(message_history) | |
# Start chat and get response | |
chat = self.client.chats.create( | |
model=model, | |
history=message_history[:-1], | |
config=types.GenerateContentConfig( | |
tools=tools, temperature=temperature, safety_settings=safety_settings | |
), | |
) | |
last_message = message_history[-1] | |
# If the last message is a function response, send the Part object directly | |
# Otherwise, send just the text content | |
message_to_send = ( | |
Content(role="function", parts=[last_message]) | |
if isinstance(last_message, Part) | |
else last_message.parts[0].text | |
) | |
# response = chat.send_message(message_to_send) | |
response = chat.send_message(message_to_send) | |
# Convert and return the response | |
return self.transformer.convert_response(response) | |
def list_models(self): | |
try: | |
response = self.client.models.list() | |
return [model.name for model in response] | |
except Exception as e: | |
raise LLMError(f"Error in list_models: {str(e)}") | |
I suggest some changes @eliasjudin. You should change the file name from google_genai_provider.py to googlegenai_provider.py |
import aisuite as ai | ||
client = ai.Client() | ||
|
||
provider = "google_genai" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
provider = "google_genai" | |
provider = "googlegenai" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for raising this PR. Gemini-API will be an important addition as its easier to onboard.
I don't see handling of tool calls in the code. Please add that.
Please test the tool calling part, and paste screenshot that an example is working.
Before re-using the tool transformation code from Vertex provider as-is, please check if it will work for the google-genai pacakge.
|
||
### Prerequisites | ||
|
||
1. **Google Cloud Account**: Ensure you have a Google Cloud account. If not, create one at [Google Cloud](https://cloud.google.com/). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you need a Google Cloud Account to use the Gemini API ?
https://ai.google.dev/gemini-api/docs/migrate-to-cloud#gemini-developer-api
Add support for Google Gemini API to
aisuite
and provide a guide.New Provider Implementation
GoogleGenaiProvider
class inaisuite/providers/google_genai_provider.py
to handle Gemini API calls.chat_completions_create
,generate_content
,list_models
, andnormalize_response
methods.Provider Factory Update
ProviderFactory
inaisuite/provider.py
to includeGoogleGenaiProvider
.Documentation
guides/google_genai.md
with instructions for setting up and using the Gemini API withaisuite
.README.md
to include the Gemini API as a supported provider and provide a brief example of how to use it.Dependencies
google-genai
to dependencies inpyproject.toml
.