-
Notifications
You must be signed in to change notification settings - Fork 735
Passing Functions as Tools #321
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4383603
afe7db6
8fee892
0e5a940
1ef75a7
93c7a63
e5dc2b8
aa20015
d79538e
97aa167
8ec5123
efb775b
2efa54a
1f089f7
fe8d143
67321a8
2cc0b40
e68700c
f452fab
ca16670
7dcb598
7c5c294
16c868a
718412a
e7bb55f
7396ab6
0d9eec0
ed3ba8a
a4ec34a
6d9c156
c5c61a3
b0e0409
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from __future__ import annotations | ||
from collections import defaultdict | ||
import inspect | ||
from typing import Callable, Union | ||
import re | ||
|
||
import pydantic | ||
from ollama._types import Tool | ||
|
||
|
||
def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]: | ||
parsed_docstring = defaultdict(str) | ||
if not doc_string: | ||
return parsed_docstring | ||
|
||
key = hash(doc_string) | ||
for line in doc_string.splitlines(): | ||
lowered_line = line.lower().strip() | ||
if lowered_line.startswith('args:'): | ||
key = 'args' | ||
elif lowered_line.startswith('returns:') or lowered_line.startswith('yields:') or lowered_line.startswith('raises:'): | ||
key = '_' | ||
|
||
else: | ||
# maybe change to a list and join later | ||
parsed_docstring[key] += f'{line.strip()}\n' | ||
|
||
last_key = None | ||
for line in parsed_docstring['args'].splitlines(): | ||
line = line.strip() | ||
if ':' in line: | ||
# Split the line on either: | ||
# 1. A parenthetical expression like (integer) - captured in group 1 | ||
# 2. A colon : | ||
# Followed by optional whitespace. Only split on first occurrence. | ||
parts = re.split(r'(?:\(([^)]*)\)|:)\s*', line, maxsplit=1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tend to avoid regexp when possible since it's hard to grok. In this scenario, a simpler solution would be to split on the mandatory for line in parsed_docstring['args'].splitlines():
pre, _, post = line.partition(':')
if not pre.strip():
continue
if not post.strip() and last_key:
parsed_docstring[last_key] += ' ' + pre
continue
arg_name, _, _ = pre.replace('(', ' ').partition(' ')
last_key = arg_name.strip()
parsed_docstring[last_key] = post.strip() There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TODO @ParthSareen to spin out issue |
||
|
||
arg_name = parts[0].strip() | ||
last_key = arg_name | ||
|
||
# Get the description - will be in parts[1] if parenthetical or parts[-1] if after colon | ||
arg_description = parts[-1].strip() | ||
if len(parts) > 2 and parts[1]: # Has parenthetical content | ||
arg_description = parts[-1].split(':', 1)[-1].strip() | ||
|
||
parsed_docstring[last_key] = arg_description | ||
|
||
elif last_key and line: | ||
parsed_docstring[last_key] += ' ' + line | ||
|
||
return parsed_docstring | ||
|
||
|
||
def convert_function_to_tool(func: Callable) -> Tool: | ||
doc_string_hash = hash(inspect.getdoc(func)) | ||
parsed_docstring = _parse_docstring(inspect.getdoc(func)) | ||
schema = type( | ||
func.__name__, | ||
(pydantic.BaseModel,), | ||
{ | ||
'__annotations__': {k: v.annotation if v.annotation != inspect._empty else str for k, v in inspect.signature(func).parameters.items()}, | ||
'__signature__': inspect.signature(func), | ||
'__doc__': parsed_docstring[doc_string_hash], | ||
}, | ||
).model_json_schema() | ||
|
||
for k, v in schema.get('properties', {}).items(): | ||
ParthSareen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# If type is missing, the default is string | ||
types = {t.get('type', 'string') for t in v.get('anyOf')} if 'anyOf' in v else {v.get('type', 'string')} | ||
if 'null' in types: | ||
schema['required'].remove(k) | ||
types.discard('null') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is okay, IMO something like: def (a:None, b:type(None)):
... is extremely unlikely |
||
|
||
schema['properties'][k] = { | ||
'description': parsed_docstring[k], | ||
'type': ', '.join(types), | ||
} | ||
|
||
tool = Tool( | ||
function=Tool.Function( | ||
name=func.__name__, | ||
description=schema.get('description', ''), | ||
parameters=Tool.Function.Parameters(**schema), | ||
) | ||
) | ||
|
||
return Tool.model_validate(tool) |
Uh oh!
There was an error while loading. Please reload this page.