|
1 | 1 | from __future__ import annotations
|
| 2 | +from collections import defaultdict |
2 | 3 | import inspect
|
3 | 4 | from typing import Callable, Union
|
4 | 5 |
|
|
7 | 8 |
|
8 | 9 |
|
9 | 10 | def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]:
|
10 |
| - parsed_docstring = {'description': ''} |
| 11 | + parsed_docstring = defaultdict(str) |
11 | 12 | if not doc_string:
|
12 | 13 | return parsed_docstring
|
13 | 14 |
|
14 | 15 | lowered_doc_string = doc_string.lower()
|
15 | 16 |
|
16 |
| - if 'args:' not in lowered_doc_string: |
17 |
| - parsed_docstring['description'] = lowered_doc_string.strip() |
18 |
| - return parsed_docstring |
19 |
| - |
20 |
| - else: |
21 |
| - parsed_docstring['description'] = lowered_doc_string.split('args:')[0].strip() |
22 |
| - args_section = lowered_doc_string.split('args:')[1] |
23 |
| - |
24 |
| - if 'returns:' in lowered_doc_string: |
25 |
| - # Return section can be captured and used |
26 |
| - args_section = args_section.split('returns:')[0] |
| 17 | + # change name |
| 18 | + key = 'func_description' |
| 19 | + parsed_docstring[key] = '' |
| 20 | + for line in lowered_doc_string.splitlines(): |
| 21 | + if line.startswith('args:'): |
| 22 | + key = 'args' |
| 23 | + elif line.startswith('returns:') or line.startswith('yields:') or line.startswith('raises:'): |
| 24 | + key = '_' |
27 | 25 |
|
28 |
| - if 'yields:' in lowered_doc_string: |
29 |
| - args_section = args_section.split('yields:')[0] |
| 26 | + else: |
| 27 | + # maybe change to a list and join later |
| 28 | + parsed_docstring[key] += f'{line.strip()}\n' |
30 | 29 |
|
31 |
| - cur_var = None |
32 |
| - for line in args_section.split('\n'): |
| 30 | + last_key = None |
| 31 | + for line in parsed_docstring['args'].splitlines(): |
33 | 32 | line = line.strip()
|
34 |
| - if not line: |
35 |
| - continue |
36 |
| - if ':' not in line: |
37 |
| - # Continuation of the previous parameter's description |
38 |
| - if cur_var: |
39 |
| - parsed_docstring[cur_var] += f' {line}' |
40 |
| - continue |
41 |
| - |
42 |
| - # For the case with: `param_name (type)`: ... |
43 |
| - if '(' in line: |
44 |
| - param_name = line.split('(')[0] |
45 |
| - param_desc = line.split('):')[1] |
46 |
| - |
47 |
| - # For the case with: `param_name: ...` |
48 |
| - else: |
49 |
| - param_name, param_desc = line.split(':', 1) |
| 33 | + if ':' in line and not line.startswith('args'): |
| 34 | + # Split on first occurrence of '(' or ':' to separate arg name from description |
| 35 | + split_char = '(' if '(' in line else ':' |
| 36 | + arg_name, rest = line.split(split_char, 1) |
| 37 | + |
| 38 | + last_key = arg_name.strip() |
| 39 | + # Get description after the colon |
| 40 | + arg_description = rest.split(':', 1)[1].strip() if split_char == '(' else rest.strip() |
| 41 | + parsed_docstring[last_key] = arg_description |
50 | 42 |
|
51 |
| - parsed_docstring[param_name.strip()] = param_desc.strip() |
52 |
| - cur_var = param_name.strip() |
| 43 | + elif last_key and line: |
| 44 | + parsed_docstring[last_key] += ' ' + line |
53 | 45 |
|
54 | 46 | return parsed_docstring
|
55 | 47 |
|
56 | 48 |
|
57 | 49 | def convert_function_to_tool(func: Callable) -> Tool:
|
| 50 | + parsed_docstring = _parse_docstring(inspect.getdoc(func)) |
58 | 51 | schema = type(
|
59 | 52 | func.__name__,
|
60 | 53 | (pydantic.BaseModel,),
|
61 | 54 | {
|
62 |
| - '__annotations__': {k: v.annotation for k, v in inspect.signature(func).parameters.items()}, |
| 55 | + '__annotations__': {k: v.annotation if v.annotation != inspect._empty else str for k, v in inspect.signature(func).parameters.items()}, |
63 | 56 | '__signature__': inspect.signature(func),
|
64 |
| - '__doc__': inspect.getdoc(func), |
| 57 | + '__doc__': parsed_docstring.get('func_description', ''), |
65 | 58 | },
|
66 | 59 | ).model_json_schema()
|
67 | 60 |
|
68 |
| - properties = {} |
69 |
| - required = [] |
70 |
| - parsed_docstring = _parse_docstring(schema.get('description')) |
71 | 61 | for k, v in schema.get('properties', {}).items():
|
72 |
| - prop = { |
| 62 | + # think about how no type is handled |
| 63 | + types = {t.get('type', 'string') for t in v.get('anyOf', [])} if 'anyOf' in v else {v.get('type', 'string')} |
| 64 | + if 'null' in types: |
| 65 | + schema['required'].remove(k) |
| 66 | + types.discard('null') |
| 67 | + |
| 68 | + schema['properties'][k] = { |
73 | 69 | 'description': parsed_docstring.get(k, ''),
|
74 |
| - 'type': v.get('type'), |
| 70 | + 'type': ', '.join(types), |
75 | 71 | }
|
76 | 72 |
|
77 |
| - if 'anyOf' in v: |
78 |
| - is_optional = any(t.get('type') == 'null' for t in v['anyOf']) |
79 |
| - types = [t.get('type', 'string') for t in v['anyOf'] if t.get('type') != 'null'] |
80 |
| - prop['type'] = types[0] if len(types) == 1 else str(types) |
81 |
| - if not is_optional: |
82 |
| - required.append(k) |
83 |
| - else: |
84 |
| - if prop['type'] != 'null': |
85 |
| - required.append(k) |
86 |
| - |
87 |
| - properties[k] = prop |
88 |
| - |
89 |
| - schema['properties'] = properties |
90 |
| - |
91 | 73 | tool = Tool(
|
92 | 74 | function=Tool.Function(
|
93 | 75 | name=func.__name__,
|
94 |
| - description=parsed_docstring.get('description'), |
95 |
| - parameters=Tool.Function.Parameters( |
96 |
| - type='object', |
97 |
| - properties=schema.get('properties', {}), |
98 |
| - required=required, |
99 |
| - ), |
| 76 | + description=schema.get('description', ''), |
| 77 | + parameters=Tool.Function.Parameters(**schema), |
100 | 78 | )
|
101 | 79 | )
|
102 | 80 |
|
|
0 commit comments