
Conversation

eliasjudin

Add support for Google Gemini API to `aisuite` and provide a guide.

* **New Provider Implementation**
  - Add `GoogleGenaiProvider` class in `aisuite/providers/google_genai_provider.py` to handle Gemini API calls.
  - Implement `chat_completions_create`, `generate_content`, `list_models`, and `normalize_response` methods.
  - Handle authentication and API key management.

* **Provider Factory Update**
  - Update `ProviderFactory` in `aisuite/provider.py` to include `GoogleGenaiProvider`.

* **Documentation**
  - Add `guides/google_genai.md` with instructions for setting up and using the Gemini API with `aisuite`.
  - Update `README.md` to include the Gemini API as a supported provider and provide a brief example of how to use it (see the sketch after this list).

* **Dependencies**
  - Add `google-genai` to dependencies in `pyproject.toml`.
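
Once merged, the README example might look roughly like the following. This is a sketch only: it assumes aisuite's `provider:model` string convention and this PR's `google_genai` provider key, and the model name is illustrative.

```python
import aisuite as ai

# Assumes GEMINI_API_KEY is set in the environment.
client = ai.Client()

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Say hello in one sentence."},
]

# aisuite model strings are "<provider>:<model>"; the provider key follows
# this PR's naming and the model name is only an example.
response = client.chat.completions.create(
    model="google_genai:gemini-1.5-flash",
    messages=messages,
)
print(response.choices[0].message.content)
```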


I had some issues with the parsing of the name, so I just changed the name to Ggenai for the class and the file.
Also, the genai API doesn't accept the temperature as a top-level keyword argument, so you must change `**kwargs` to `config=types.GenerateContentConfig(**kwargs)` in `generate_content` and `chat_completions_create`.
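
For clarity, a minimal sketch of that change against the provider code reviewed below; it assumes the same `self.client` and `normalize_response`, and routes generation parameters such as `temperature` through `types.GenerateContentConfig`:

```python
from google.genai import types

def chat_completions_create(self, model, messages, **kwargs):
    try:
        response = self.client.models.generate_content(
            model=model,
            contents=[message["content"] for message in messages],
            # Generation parameters (e.g. temperature) are not accepted as
            # top-level kwargs; wrap them in GenerateContentConfig instead.
            config=types.GenerateContentConfig(**kwargs),
        )
        return self.normalize_response(response)
    except Exception as e:
        raise LLMError(f"Error in chat_completions_create: {str(e)}")
```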

eliasjudin (Author)

I’ll take a look. I appreciate any edits, as I’m not too experienced integrating APIs.


Could I help on this?


I'll get this tomorrow


@vargacypher Feb 10, 2025


I think we should keep the logic in the same provider, but make it dynamic and configurable so it can run with the main Vertex SDK or with this new genai SDK.

What do you think?


I think mixing the Vertex and genai clients in one provider complicates the class a little. Why not have separate clients, just like Google has separate clients?

Hoho, just as I was writing this I saw that Google released a unified API for GenAI and Vertex for JavaScript: https://github.com/googleapis/js-genai/tree/main.

🤷🏼‍♂️ I still like the idea of seeing both implemented side by side before deciding whether it makes sense to unify them.


@jackbravo The changes I suggest cover usage of the new GenAI SDK with both the genai and vertexai endpoints. You just need to set `vertexai_option: True` in the config.
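
For reference, a rough sketch of what that would look like from the caller's side, based on the suggested provider code further down; the provider key and the project/region/credentials values are placeholders:

```python
import aisuite as ai

# Gemini Developer API: only an API key is needed (or GEMINI_API_KEY in the env).
client = ai.Client({"googlegenai": {"api_key": "..."}})

# Vertex AI endpoint: setting vertexai_option switches the same google-genai
# client over to Vertex, which then requires project, region, and credentials.
client = ai.Client(
    {
        "googlegenai": {
            "vertexai_option": True,
            "project_id": "my-project",
            "region": "us-central1",
            "application_credentials": "/path/to/creds.json",
        }
    }
)
```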



Is this branch still an active concern?

Comment on lines +1 to +49
```python
import os
from google import genai
from google.genai import types
from aisuite.provider import Provider, LLMError
from aisuite.framework import ChatCompletionResponse


class GoogleGenaiProvider(Provider):
    def __init__(self, **config):
        self.api_key = config.get("api_key") or os.getenv("GEMINI_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Gemini API key is missing. Please provide it in the config or set the GEMINI_API_KEY environment variable."
            )
        self.client = genai.Client(api_key=self.api_key)

    def chat_completions_create(self, model, messages, **kwargs):
        try:
            response = self.client.models.generate_content(
                model=model,
                contents=[message["content"] for message in messages],
                **kwargs,
            )
            return self.normalize_response(response)
        except Exception as e:
            raise LLMError(f"Error in chat_completions_create: {str(e)}")

    def generate_content(self, model, contents, **kwargs):
        try:
            response = self.client.models.generate_content(
                model=model, contents=contents, **kwargs
            )
            return self.normalize_response(response)
        except Exception as e:
            raise LLMError(f"Error in generate_content: {str(e)}")

    def list_models(self):
        try:
            response = self.client.models.list()
            return [model.name for model in response]
        except Exception as e:
            raise LLMError(f"Error in list_models: {str(e)}")

    def normalize_response(self, response):
        normalized_response = ChatCompletionResponse()
        normalized_response.choices[0].message.content = response.text
        return normalized_response
```

@vargacypher Mar 11, 2025


Suggested change
"""The interface to Google's Genai."""
import os
import json
from typing import List, Dict, Any, Optional
from google import genai
from google.genai import types
from google.genai.types import Content, Part, Tool, FunctionDeclaration
import pprint
from aisuite.framework import ChatCompletionResponse, Message
from aisuite.provider import LLMError, Provider
DEFAULT_TEMPERATURE = 0.7
ENABLE_DEBUG_MESSAGES = False
class GoogleMessageConverter:
@staticmethod
def convert_user_role_message(message: Dict[str, Any]) -> Content:
"""Convert user or system messages to Google Vertex AI format."""
parts = [Part.from_text(text=message["content"])]
return Content(role="user", parts=parts)
@staticmethod
def convert_assistant_role_message(message: Dict[str, Any]) -> Content:
"""Convert assistant messages to Google Vertex AI format."""
if "tool_calls" in message and message["tool_calls"]:
# Handle function calls
tool_call = message["tool_calls"][
0
] # Assuming single function call for now
function_call = tool_call["function"]
# Create a Part from the function call
parts = [
Part.from_function_call(
name=function_call["name"],
args={},
# arguments = json.loads(function_call["arguments"])
)
]
# return Content(role="function", parts=parts)
else:
# Handle regular text messages
parts = [Part.from_text(text=message["content"])]
# return Content(role="model", parts=parts)
return Content(role="model", parts=parts)
@staticmethod
def convert_tool_role_message(message: Dict[str, Any]) -> Part:
"""Convert tool messages to Google Vertex AI format."""
if "content" not in message:
raise ValueError("Tool result message must have a content field")
try:
content_json = json.loads(message["content"])
part = Part.from_function_response(
name=message["name"], response=content_json
)
# TODO: Return Content instead of Part. But returning Content is not working.
return part
except json.JSONDecodeError:
raise ValueError("Tool result message must be valid JSON")
@staticmethod
def convert_request(messages: List[Dict[str, Any]]) -> List[Content]:
"""Convert messages to Google Vertex AI format."""
# Convert all messages to dicts if they're Message objects
messages = [
message.model_dump() if hasattr(message, "model_dump") else message
for message in messages
]
formatted_messages = []
for message in messages:
if message["role"] == "tool":
vertex_message = GoogleMessageConverter.convert_tool_role_message(
message
)
if vertex_message:
formatted_messages.append(vertex_message)
elif message["role"] == "assistant":
formatted_messages.append(
GoogleMessageConverter.convert_assistant_role_message(message)
)
else: # user or system role
formatted_messages.append(
GoogleMessageConverter.convert_user_role_message(message)
)
return formatted_messages
@staticmethod
def convert_response(response) -> ChatCompletionResponse:
"""Normalize the response from Vertex AI to match OpenAI's response format."""
openai_response = ChatCompletionResponse()
if ENABLE_DEBUG_MESSAGES:
print("Dumping the response")
pprint.pprint(response)
# TODO: We need to go through each part, because function call may not be the first part.
# Currently, we are only handling the first part, but this is not enough.
#
# This is a valid response:
# candidates {
# content {
# role: "model"
# parts {
# text: "The current temperature in San Francisco is 72 degrees Celsius. \n\n"
# }
# parts {
# function_call {
# name: "is_it_raining"
# args {
# fields {
# key: "location"
# value {
# string_value: "San Francisco"
# }
# }
# }
# }
# }
# }
# finish_reason: STOP
# Check if the response contains function calls
# Note: Just checking if the function_call attribute exists is not enough,
# it is important to check if the function_call is not None.
if (
hasattr(response.candidates[0].content.parts[0], "function_call")
and response.candidates[0].content.parts[0].function_call
):
function_call = response.candidates[0].content.parts[0].function_call
# args is a MapComposite.
# Convert the MapComposite to a dictionary
args_dict = {}
# Another way to try is: args_dict = dict(function_call.args)
for key, value in function_call.args.items():
args_dict[key] = value
if ENABLE_DEBUG_MESSAGES:
print("Dumping the args_dict")
pprint.pprint(args_dict)
openai_response.choices[0].message = {
"role": "assistant",
"content": None,
"tool_calls": [
{
"type": "function",
"id": f"call_{hash(function_call.name)}", # Generate a unique ID
"function": {
"name": function_call.name,
"arguments": json.dumps(args_dict),
},
}
],
"refusal": None,
}
openai_response.choices[0].message = Message(
**openai_response.choices[0].message
)
openai_response.choices[0].finish_reason = "tool_calls"
else:
# Handle regular text response
openai_response.choices[0].message.content = (
response.candidates[0].content.parts[0].text
)
openai_response.choices[0].finish_reason = "stop"
return openai_response
class GooglegenaiProvider(Provider):
"""Implements the Provider Interface for interacting with Google's Generative AI."""
def __init__(self, **config):
"""Set up the Google AI client with a project ID."""
self.project_id = config.get("project_id") or os.getenv("GOOGLE_PROJECT_ID")
self.location = config.get("region") or os.getenv("GOOGLE_REGION")
self.app_creds_path = config.get("application_credentials") or os.getenv(
"GOOGLE_APPLICATION_CREDENTIALS"
)
self.api_key = config.get("api_key") or os.getenv("GEMINI_API_KEY")
self.vertexai_option = config.get("vertexai_option")
if self.vertexai_option and (
not self.project_id or not self.location or not self.app_creds_path
):
raise EnvironmentError(
"Missing one or more required Google environment variables: "
"GOOGLE_PROJECT_ID, GOOGLE_REGION, GOOGLE_APPLICATION_CREDENTIALS. "
"Please refer to the setup guide: /guides/google.md."
)
elif not self.vertexai_option and not self.api_key:
raise EnvironmentError(
"Missing required Google environment variable: GEMINI_API_KEY. "
"Please refer to the setup guide: /guides/google.md."
)
self.client = genai.Client(
project=self.project_id,
location=self.location,
vertexai=self.vertexai_option,
api_key=self.api_key,
)
self.transformer = GoogleMessageConverter()
def chat_completions_create(self, model, messages, **kwargs):
"""Request chat completions from the Google AI API.
Args:
----
model (str): Identifies the specific provider/model to use.
messages (list of dict): A list of message objects in chat history.
kwargs (dict): Optional arguments for the Google AI API.
Returns:
-------
The ChatCompletionResponse with the completion result.
"""
# Set the temperature if provided, otherwise use the default
temperature = kwargs.get("temperature", DEFAULT_TEMPERATURE)
# Set safety_settings if provided
safety_settings = kwargs.get("safety_settings")
# Convert messages to Vertex AI format
message_history = self.transformer.convert_request(messages)
# Handle tools if provided
tools = None
if "tools" in kwargs:
tools = [
Tool(
function_declarations=[
FunctionDeclaration(
name=tool["function"]["name"],
description=tool["function"].get("description", ""),
parameters={
"type": "object",
"properties": {
param_name: {
"type": param_info.get("type", "string"),
"description": param_info.get(
"description", ""
),
**(
{"enum": param_info["enum"]}
if "enum" in param_info
else {}
),
}
for param_name, param_info in tool["function"][
"parameters"
]["properties"].items()
},
"required": tool["function"]["parameters"].get(
"required", []
),
},
)
for tool in kwargs["tools"]
]
)
]
if ENABLE_DEBUG_MESSAGES:
print("Dumping the message_history")
pprint.pprint(message_history)
# Start chat and get response
chat = self.client.chats.create(
model=model,
history=message_history[:-1],
config=types.GenerateContentConfig(
tools=tools, temperature=temperature, safety_settings=safety_settings
),
)
last_message = message_history[-1]
# If the last message is a function response, send the Part object directly
# Otherwise, send just the text content
message_to_send = (
Content(role="function", parts=[last_message])
if isinstance(last_message, Part)
else last_message.parts[0].text
)
# response = chat.send_message(message_to_send)
response = chat.send_message(message_to_send)
# Convert and return the response
return self.transformer.convert_response(response)
def list_models(self):
try:
response = self.client.models.list()
return [model.name for model in response]
except Exception as e:
raise LLMError(f"Error in list_models: {str(e)}")


vargacypher commented Mar 11, 2025

I suggest some changes @eliasjudin.

You should change the file name from `google_genai_provider.py` to `googlegenai_provider.py`.

```python
import aisuite as ai

client = ai.Client()

provider = "google_genai"
```


Suggested change
```diff
-provider = "google_genai"
+provider = "googlegenai"
```

@rohitprasad15 (Collaborator) left a comment


Thanks for raising this PR. The Gemini API will be an important addition, as it's easier to onboard.

I don't see handling of tool calls in the code. Please add that.
Please test the tool-calling part, and paste a screenshot showing a working example.
Before re-using the tool transformation code from the Vertex provider as-is, please check whether it will work with the google-genai package.
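
To make the requested test concrete, here is a rough sketch of a tool-calling check, assuming aisuite's OpenAI-style `tools` parameter and this PR's provider naming; the function and model names are made up for illustration:

```python
import aisuite as ai

client = ai.Client()

tools = [
    {
        "type": "function",
        "function": {
            "name": "is_it_raining",
            "description": "Check whether it is raining in a city.",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "City name"},
                },
                "required": ["location"],
            },
        },
    }
]

response = client.chat.completions.create(
    model="googlegenai:gemini-1.5-flash",  # illustrative model name
    messages=[{"role": "user", "content": "Is it raining in San Francisco?"}],
    tools=tools,
)

# A successful tool call should normalize to OpenAI-style tool_calls.
print(response.choices[0].message.tool_calls)
```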


### Prerequisites

1. **Google Cloud Account**: Ensure you have a Google Cloud account. If not, create one at [Google Cloud](https://cloud.google.com/).


Do you need a Google Cloud account to use the Gemini API?
https://ai.google.dev/gemini-api/docs/migrate-to-cloud#gemini-developer-api
