litellm Integration #13

Open · wants to merge 8 commits into master
24 changes: 14 additions & 10 deletions os_computer_use/config.py
@@ -5,15 +5,19 @@
grounding_model = providers.OSAtlasProvider()
# grounding_model = providers.ShowUIProvider()

-# vision_model = providers.FireworksProvider("llama3.2")
-# vision_model = providers.OpenAIProvider("gpt-4o")
-# vision_model = providers.AnthropicProvider("claude-3.5-sonnet")
-vision_model = providers.GroqProvider("llama3.2")
-# vision_model = providers.MistralProvider("pixtral") # pixtral-large-latest has vision capabilities
+# Vision models using LiteLLM:
+vision_model = providers.LiteLLMProvider("pixtral") # Mistral
+# vision_model = providers.LiteLLMProvider("llama3.2", provider="fireworks") # Fireworks
+# vision_model = providers.LiteLLMProvider("gpt-4-vision") # OpenAI
+# vision_model = providers.LiteLLMProvider("llama3.2", provider="groq") # Groq
+# vision_model = providers.LiteLLMProvider("claude-3-5-sonnet") # Anthropic
+# vision_model = providers.LiteLLMProvider("gemini-2.0-flash", provider="gemini") # Gemini

-# action_model = providers.FireworksProvider("llama3.3")
-# action_model = providers.OpenAIProvider("gpt-4o")
-# action_model = providers.AnthropicProvider("claude-3.5-sonnet")
-action_model = providers.GroqProvider("llama3.3")
-# action_model = providers.MistralProvider("large") # mistral-large-latest for non-vision tasks
+# Action models using LiteLLM:
+action_model = providers.LiteLLMProvider("large") # Mistral
+# action_model = providers.LiteLLMProvider("llama3.3", provider="fireworks") # Fireworks
+# action_model = providers.LiteLLMProvider("llama3.3", provider="groq") # Groq
+# action_model = providers.LiteLLMProvider("gpt-4") # OpenAI
+# action_model = providers.LiteLLMProvider("claude-3-5-sonnet") # Anthropic
+# action_model = providers.LiteLLMProvider("gemini-2.0-flash", provider="gemini") # Gemini
147 changes: 80 additions & 67 deletions os_computer_use/llm_provider.py
@@ -4,6 +4,7 @@
import json
import re
import base64
+import imghdr


def Message(content, role="assistant"):
@@ -22,6 +23,29 @@ def parse_json(s):
        return None


+def extract_json_objects(s):
+    """Extract all balanced JSON objects from a string."""
+    objects = []
+    brace_level = 0
+    start_index = None
+    for i, char in enumerate(s):
+        if char == "{":
+            if brace_level == 0:
+                start_index = i
+            brace_level += 1
+        elif char == "}":
+            brace_level -= 1
+            if brace_level == 0 and start_index is not None:
+                candidate = s[start_index : i + 1]
+                try:
+                    obj = json.loads(candidate)
+                    objects.append(obj)
+                except json.JSONDecodeError:
+                    pass
+                start_index = None
+    return objects
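
As a quick illustration (not part of the diff), the brace-balancing scan tolerates surrounding prose and multiple or nested objects, which the old re.search(r"\{.*\}") approach could not:

    reply = 'Sure! {"name": "click", "parameters": {"x": "10"}} and {"note": "extra"}'
    print(extract_json_objects(reply))
    # -> [{'name': 'click', 'parameters': {'x': '10'}}, {'note': 'extra'}]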


class LLMProvider:
    """
    The LLM provider is used to make calls to an LLM given a provider and model name, with optional tool use support

@@ -52,6 +76,13 @@ def create_function_schema(self, definitions):
                properties[param_name] = {"type": "string", "description": param_desc}
                required.append(param_name)

+            # Add a dummy property if no parameters are provided, because providers like Gemini require a non-empty "properties" object.
+            if not properties:
+                properties["noop"] = {
+                    "type": "string",
+                    "description": "Dummy parameter for function with no parameters.",
+                }
+
            function_def = self.create_function_def(name, details, properties, required)
            functions.append(function_def)
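
With this guard, a tool that takes no parameters (a hypothetical take_screenshot, assuming the OpenAI-style wrapper that create_function_def produces) yields a schema Gemini will accept instead of an empty "properties" object:

    {
        "type": "function",
        "function": {
            "name": "take_screenshot",
            "description": "Capture the current screen.",
            "parameters": {
                "type": "object",
                "properties": {
                    "noop": {
                        "type": "string",
                        "description": "Dummy parameter for function with no parameters.",
                    }
                },
                "required": [],
            },
        },
    }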

@@ -68,8 +99,7 @@ def create_tool_call(self, name, parameters):
    # Wrap a content block in a text or an image object
    def wrap_block(self, block):
        if isinstance(block, bytes):
-            encoded_image = base64.b64encode(block).decode("utf-8")
-            return self.create_image_block(encoded_image)
+            return self.create_image_block(block)
        else:
            return Text(block)
@@ -117,10 +147,17 @@ def create_function_def(self, name, details, properties, required):
            },
        }

-    def create_image_block(self, base64_image):
+    def create_image_block(self, image_data):
+        # Detect the image type using imghdr.
+        image_type = imghdr.what(None, image_data)
+        if image_type is None:
+            image_type = "png"  # fallback if type cannot be detected
+
+        # Base64-encode the raw image bytes.
+        encoded = base64.b64encode(image_data).decode("utf-8")
        return {
            "type": "image_url",
-            "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
+            "image_url": {"url": f"data:image/{image_type};base64,{encoded}"},
        }

    def call(self, messages, functions=None):
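
A minimal check of the detection path (a sketch; note that imghdr is deprecated since Python 3.11 and removed in 3.13, so this helper may eventually need an explicit magic-byte check instead):

    import imghdr

    png_magic = b"\x89PNG\r\n\x1a\n" + b"\x00" * 8  # PNG signature plus padding
    print(imghdr.what(None, png_magic))  # -> "png", so the URL becomes data:image/png;base64,...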
@@ -140,18 +177,17 @@ def call(self, messages, functions=None):
                if parse_json(tool_call.function.arguments) is not None
            ]

-            # Sometimes, function calls are returned unparsed by the inference provider. This code parses them manually.
+            # Sometimes, function calls are returned unparsed by the inference provider.
            if message.content and not tool_calls:
-                tool_call_matches = re.search(r"\{.*\}", message.content)
-                if tool_call_matches:
-                    tool_call = parse_json(tool_call_matches.group(0))
-                    # Some models use "arguments" as the key instead of "parameters"
-                    parameters = tool_call.get("parameters", tool_call.get("arguments"))
-                    if tool_call.get("name") and parameters:
+                json_objs = extract_json_objects(message.content)
+                for obj in json_objs:
+                    parameters = obj.get("parameters", obj.get("arguments"))
+                    if obj.get("name") and parameters is not None:
                        combined_tool_calls.append(
-                            self.create_tool_call(tool_call.get("name"), parameters)
+                            self.create_tool_call(obj.get("name"), parameters)
                        )
-                    return None, combined_tool_calls
+                if combined_tool_calls:
+                    return None, combined_tool_calls

            return message.content, combined_tool_calls
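
End to end, the fallback now turns a chatty reply into tool calls and normalizes the "arguments"/"parameters" key difference; a sketch with a hypothetical type_text tool:

    reply = 'Typing now: {"name": "type_text", "arguments": {"text": "hi"}}'
    for obj in extract_json_objects(reply):
        parameters = obj.get("parameters", obj.get("arguments"))
        if obj.get("name") and parameters is not None:
            print(obj["name"], parameters)  # -> type_text {'text': 'hi'}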

@@ -160,75 +196,52 @@ def call(self, messages, functions=None):
            return message.content


-class AnthropicBaseProvider(LLMProvider):
+class LiteLLMBaseProvider(OpenAIBaseProvider):
+    """Base provider using LiteLLM"""
+
    def create_client(self):
-        return Anthropic(api_key=self.api_key).messages
+        from litellm import completion

-    def create_function_def(self, name, details, properties, required):
-        return {
-            "name": name,
-            "description": details["description"],
-            "input_schema": {
-                "type": "object",
-                "properties": properties,
-                "required": required,
-            },
-        }
+        import litellm

-    def create_image_block(self, base64_image):
-        return {
-            "type": "image",
-            "source": {
-                "type": "base64",
-                "media_type": "image/png",
-                "data": base64_image,
-            },
-        }
-
-    def call(self, messages, functions=None):
-        tools = self.create_function_schema(functions) if functions else None
-
-        # Move all messages with the system role to a system parameter
-        system = "\n".join(
-            msg.get("content") for msg in messages if msg.get("role") == "system"
-        )
-        messages = [msg for msg in messages if msg.get("role") != "system"]
-
-        # Call the Anthropic API
-        completion = self.completion(
-            messages, system=system, tools=tools, max_tokens=4096
-        )
-        text = "".join(getattr(block, "text", "") for block in completion.content)
+        # Enable dropping unsupported params globally
+        litellm.drop_params = True
+        litellm.modify_params = True
+        # Enable debug mode for better error messages
+        # litellm._turn_on_debug()
+        return completion

-        # Return response text and tool calls separately
-        if functions:
-            tool_calls = [
-                self.create_tool_call(block.name, block.input)
-                for block in completion.content
-                if block.type == "tool_use"
-            ]
-            return text, tool_calls
+    def completion(self, messages, **kwargs):
+        # Skip the tools parameter if it's None
+        filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}

-        # Only return response text
-        else:
-            return text
+        # No need to remove tools; pass tools so that function calling works with Claude.

+        # Wrap content blocks in image or text objects if necessary
+        new_messages = [self.transform_message(message) for message in messages]

-class MistralBaseProvider(OpenAIBaseProvider):
-    def create_function_def(self, name, details, properties, required):
-        # If description is wrapped in a dict, extract the inner string
-        if isinstance(details.get("description"), dict):
-            details["description"] = details["description"].get("description", "")
-        return super().create_function_def(name, details, properties, required)
+        # Call LiteLLM completion
+        completion_response = self.client(
+            model=self.model,
+            messages=new_messages,
+            api_key=self.api_key,
+            **filtered_kwargs,
+        )
+        return completion_response

+    # Added method to adjust the final message role for Mistral-based models only
    def call(self, messages, functions=None):
-        if messages and messages[-1].get("role") == "assistant":
+        if (
+            "mistral" in self.model.lower()
+            and messages
+            and messages[-1].get("role") == "assistant"
+        ):
            prefix = messages.pop()["content"]
            if messages and messages[-1].get("role") == "user":
                messages[-1]["content"] = (
                    prefix + "\n" + messages[-1].get("content", "")
                )
            else:
                messages.append({"role": "user", "content": prefix})

        return super().call(messages, functions)
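
The Mistral-only guard reflects Mistral's requirement that a conversation not end on an assistant turn: the trailing assistant content is folded into the last user message instead. For example:

    messages = [
        {"role": "user", "content": "Click OK."},
        {"role": "assistant", "content": "I will use the click tool."},
    ]
    # After the adjustment (model name containing "mistral"):
    # [{"role": "user", "content": "I will use the click tool.\nClick OK."}]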