Skip to content

Commit 7108c04

Browse files
authored
Add database agent (#1540)
* Add database agent * qol stuff * add planner db agent to planner docstring
1 parent b230bdd commit 7108c04

File tree

18 files changed

+474
-37
lines changed

18 files changed

+474
-37
lines changed
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
import asyncio
2+
from functools import partial
3+
from typing import Any, Optional, Union
4+
5+
from pydantic import BaseModel
6+
from pydantic_ai import Agent
7+
from pydantic_ai.agent import AgentRunResult
8+
9+
from patchwork.common.client.llm.protocol import LlmClient
10+
from patchwork.common.client.llm.utils import example_json_to_base_model
11+
from patchwork.common.tools import Tool
12+
13+
14+
# Structured reply requested from the planner after each executed step
# (passed as ``result_type`` in ``PlanningStrategy.run``). The plan cursor is
# only advanced when ``is_step_completed`` is true.
class StepCompletedResult(BaseModel):
    is_step_completed: bool  # planner's verdict on the step that was just executed
16+
17+
18+
# Structured reply requested from the planner once the cursor has moved past
# the last step of the plan (passed as ``result_type`` in
# ``PlanningStrategy.run``). A true ``is_plan_completed`` ends the loop.
class PlanCompletedResult(BaseModel):
    is_plan_completed: bool  # planner's verdict on whether the whole task is done
20+
21+
22+
# Structured result type of the executor agent (its ``result_type`` in
# ``PlanningStrategy.__init__``). The whole object is interpolated into the
# planner prompt after every step; the individual fields are not read by the
# code in this module.
class ExecutionResult(BaseModel):
    json_data: str  # presumably the step's raw output serialised as JSON - TODO confirm
    message: str  # presumably a human-readable account of what was done - TODO confirm
    is_completed: bool  # presumably the executor's own view of success - TODO confirm
26+
27+
28+
class _Plan:
    """Ordered list of plan steps plus a cursor marking the step in progress.

    The planner agent edits the list through the tool methods registered by
    :meth:`register_steps`; the driving loop reads the current step and moves
    the cursor with :meth:`advance`. Steps are numbered from 0 everywhere.
    """

    def __init__(self, initial_plan: Optional[list[str]] = None):
        """Start a plan at step 0.

        Args:
            initial_plan: Optional steps to seed the plan with. The list is
                copied so later edits never mutate the caller's list.
        """
        self.__plan = list(initial_plan) if initial_plan else []
        self.__cursor = 0

    def advance(self) -> bool:
        """Move the cursor to the next step.

        Returns:
            True if the cursor still points at an existing step, False once
            it has moved past the end of the plan.
        """
        self.__cursor += 1
        return self.__cursor < len(self.__plan)

    def is_empty(self) -> bool:
        """Return True when the plan contains no steps."""
        return not self.__plan

    def register_steps(self, agent: "Agent"):
        """Expose the plan inspection and editing methods as tools of ``agent``.

        The annotation is quoted so defining this class never needs the
        agent type to be resolvable at import time.
        """
        agent.tool_plain(self.get_current_plan)
        agent.tool_plain(self.get_current_step)
        agent.tool_plain(self.get_current_step_index)
        agent.tool_plain(self.add_step)
        agent.tool_plain(self.delete_step)

    def get_current_plan(self) -> str:
        """Get every step of the plan, one per line, prefixed with its 0-based index."""
        return "\n".join(f"{i}. {step}" for i, step in enumerate(self.__plan))

    def get_current_step(self) -> str:
        """Get the text of the step that is currently being worked on."""
        if not self.__plan:
            return "There is currently no plan"

        # The cursor sits one past the end after the last step has been
        # advanced over, and deletions can shrink the plan below it; report
        # that instead of raising IndexError.
        if self.__cursor >= len(self.__plan):
            return "All steps in the current plan have been completed"

        return self.__plan[self.__cursor]

    def get_current_step_index(self) -> int:
        """Get the 0-based index of the step that is currently being worked on."""
        return self.__cursor

    def add_step(self, index: int, step: str) -> str:
        """Insert a step into the plan.

        Args:
            index: 0-based position to insert the step at. An index at or
                past the end of the plan appends the step.
            step: Text of the step to add.

        Returns:
            A confirmation followed by the updated plan, or an error message
            when the index is negative.
        """
        if index < 0:
            return "index cannot be a negative number"

        # list.insert already appends when the index is past the end.
        self.__plan.insert(index, step)
        return "Added step\nCurrent plan:\n" + self.get_current_plan()

    def delete_step(self, step: str) -> str:
        """Remove the first step whose text exactly matches the given text.

        Args:
            step: Exact text of the step to remove.

        Returns:
            The updated plan, or a not-found message followed by the
            unchanged plan.
        """
        try:
            position = self.__plan.index(step)
        except ValueError:
            return "Step not found in plan\nCurrent plan:\n" + self.get_current_plan()

        self.__plan.pop(position)
        # Removing an earlier step shifts the rest one slot left; follow it so
        # the step in progress is not silently skipped.
        if position < self.__cursor:
            self.__cursor -= 1
        return self.get_current_plan()
78+
79+
80+
class PlanningStrategy:
    """Plan-and-execute loop driven by three pydantic-ai agents.

    A planner agent maintains a step list through the ``_Plan`` tools, an
    executor agent carries out one step at a time with the supplied tool set,
    and a summariser agent turns the planner conversation into the final
    structured result. Request and response token counts of every agent run
    are accumulated and exposed through :meth:`usage`.
    """

    # Model name every agent of the strategy is configured with.
    _MODEL_NAME = "gemini-2.0-flash"

    def __init__(
        self,
        llm_client: LlmClient,
        planner_system_prompt: str,
        executor_system_prompt: str,
        executor_tool_set: dict[str, Tool],
        example_json: Union[str, dict[str, Any]] = '{"output":"output text"}',
    ):
        """Create the planner, executor and summariser agents.

        Args:
            llm_client: Client handed to every agent as its model.
            planner_system_prompt: System prompt of the planner agent.
            executor_system_prompt: System prompt of the executor agent.
            executor_tool_set: Tools made available to the executor; only the
                values of the mapping are used.
            example_json: Example of the desired final output; it is converted
                into the summariser's result model.

        Use this like this::

            class DatabaseAgent(Step, input_class=DatabaseAgentInputs, output_class=DatabaseAgentOutputs):
                def __init__(self, inputs):
                    super().__init__(inputs)

                    llm_client = AioLlmClient.create_aio_client(inputs)

                    data = inputs.get("prompt_value", {})
                    self.task = mustache_render(inputs["task"], data)

                    db_dialect = inputs["db_dialect"]
                    self.planner = PlanningStrategy(
                        llm_client,
                        planner_system_prompt=f'''\\
            You are a {db_dialect} database query planning assistant. You are tasked to plan the steps to assist with the provided task.
            You will not execute the steps in the plan. The user will do that instead.
            The first step of the plan should be as follows:
            1. Tell me all tables currently available.
            After the list of table names is provided, get the DDL of the tables that is relevant.
            Your steps should be clear and concise like the following example:
            1. Tell me the column descriptions of the table `orders`.
            2. Execute the SQL Query: `SELECT * FROM orders`
            After every step, you will be asked to edit the plan so feel free to plan 1 step at a time.
            ''',
                        executor_system_prompt=f'''\\
            You are a {db_dialect} database query execution assistant. You will be provided instructions on what to do.
            ''',
                        executor_tool_set=dict(db_tool=DatabaseQueryTool(inputs)),
                    )

                def run(self) -> dict:
                    planner_response = self.planner.run(self.task, 10)
                    return {**planner_response, **self.planner.usage()}

        """
        self.planner = Agent(
            llm_client,
            name="Planner",
            system_prompt=planner_system_prompt,
            model_settings=self._model_settings(),
        )

        self.plan = _Plan()
        self.plan.register_steps(self.planner)

        self.executor = Agent(
            llm_client,
            name="Executor",
            system_prompt=executor_system_prompt,
            result_type=ExecutionResult,
            tools=[tool.to_pydantic_ai_function_tool() for tool in executor_tool_set.values()],
            model_settings=self._model_settings(),
        )

        self.__summariser = Agent(
            llm_client,
            result_retries=5,
            system_prompt="""\
Please summarise the conversation given and provide the result in the structure that is asked of you.
""",
            result_type=example_json_to_base_model(example_json),
            model_settings=self._model_settings(),
        )

        self.reset()

    @classmethod
    def _model_settings(cls) -> dict:
        """Return a fresh copy of the model settings shared by all three agents."""
        return dict(parallel_tool_calls=False, model=cls._MODEL_NAME)

    def reset(self):
        """Zero the accumulated token counters."""
        self.__request_tokens = 0
        self.__response_tokens = 0

    def usage(self):
        """Return the request/response tokens accumulated since the last reset."""
        return {
            "request_tokens": self.__request_tokens,
            "response_tokens": self.__response_tokens,
        }

    def __agent_run(self, agent: Agent, prompt: str, **kwargs) -> AgentRunResult[Any]:
        """Run ``agent`` synchronously on a private event loop and record its token usage.

        Args:
            agent: Agent to run.
            prompt: User prompt for this run.
            **kwargs: Forwarded unchanged to ``agent.run``.

        Returns:
            The run result of the agent.
        """
        loop = asyncio.new_event_loop()
        try:
            agent_response = loop.run_until_complete(agent.run(prompt, **kwargs))
        finally:
            # Close the loop even when the run raises so it is never leaked.
            loop.close()

        run_usage = agent_response.usage()
        # Token counts may be reported as None; count those as zero.
        self.__request_tokens += run_usage.request_tokens or 0
        self.__response_tokens += run_usage.response_tokens or 0

        return agent_response

    def run(self, task: str, conversation_limit: int = 10) -> dict:
        """Plan ``task``, execute the plan step by step and summarise the outcome.

        Args:
            task: Description of the task to accomplish.
            conversation_limit: Maximum number of executor/planner rounds.

        Returns:
            The summariser's structured result as a plain dict.
        """
        planner_response = self.__agent_run(self.planner, f"Produce the initial plan for {task}")
        planner_history = planner_response.all_messages()
        if self.plan.is_empty():
            # The planner answered without calling the plan tools; nudge it once.
            planner_response = self.__agent_run(
                self.planner, "Please use the tools provided to setup the plan", message_history=planner_history
            )
            planner_history = planner_response.all_messages()

        for _ in range(conversation_limit):
            step = self.plan.get_current_step()
            executor_prompt = f"Please execute the following task: {step}"
            # The executor starts from a clean history on every step.
            response = self.__agent_run(self.executor, executor_prompt)

            plan_str = self.plan.get_current_plan()
            step_index = self.plan.get_current_step_index()
            planner_prompt = f"""\
The current plan is:
{plan_str}

We are current at {step_index}.
If the current step is not completed, edit the current step.

The execution result for the step {step_index} is:
{response.data}

"""
            planner_response = self.__agent_run(
                self.planner,
                planner_prompt,
                message_history=planner_history,
                result_type=StepCompletedResult,
            )
            planner_history = planner_response.all_messages()
            if not planner_response.data.is_step_completed:
                # Retry; the planner may have edited the current step.
                continue

            if self.plan.advance():
                continue

            # Every planned step is done: let the planner finish or extend the plan.
            planner_response = self.__agent_run(
                self.planner,
                "Is the task completed? If the task is not completed please add more steps using the tools provided.",
                message_history=planner_history,
                result_type=PlanCompletedResult,
            )
            # Keep this exchange so later prompts and the summariser see it too.
            planner_history = planner_response.all_messages()
            if planner_response.data.is_plan_completed:
                break

        final_result = self.__agent_run(
            self.__summariser,
            "From the actions taken by the assistant. Please give me the result.",
            message_history=planner_history,
        )

        # model_dump() is the pydantic v2 replacement for the deprecated dict().
        return final_result.data.model_dump()
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from typing_extensions import Any, Union
2+
3+
from patchwork.common.tools import Tool
4+
from patchwork.steps import CallSQL
5+
6+
7+
class DatabaseQueryTool(Tool, tool_name="db_query_tool"):
    """Tool that runs a query against the configured database through ``CallSQL``."""

    def __init__(self, inputs: dict[str, Any]):
        super().__init__()
        # Private copy so later changes to the caller's mapping do not leak in.
        self.db_settings = inputs.copy()
        # Only used in the tool description; falls back to the generic "SQL".
        self.db_dialect = inputs.get("db_dialect", "SQL")

    @property
    def json_schema(self) -> dict:
        """Schema describing this tool: a single required string argument ``query``."""
        query_property = {
            "type": "string",
            "description": "Database query to run.",
        }
        input_schema = {
            "type": "object",
            "properties": {"query": query_property},
            "required": ["query"],
        }
        description = f"""\
Run SQL Query on current {self.db_dialect} database.
"""
        return {
            "name": "db_query_tool",
            "description": description,
            "input_schema": input_schema,
        }

    # Runs ``query`` with the stored settings and returns the "results" entry
    # of the CallSQL output (an empty list when that entry is absent). Any
    # exception is returned as its message text rather than raised, so the
    # caller receives the error as an ordinary value.
    def execute(self, query: str) -> Union[list[dict[str, Any]], str]:
        call_inputs = self.db_settings.copy()
        call_inputs["db_query"] = query
        try:
            step_output = CallSQL(call_inputs).run()
            return step_output.get("results", [])
        except Exception as e:
            return str(e)

patchwork/steps/AgenticLLM/typed.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,14 @@ class AgenticLLMInputs(TypedDict, total=False):
1111
user_prompt: str
1212
max_llm_calls: Annotated[int, StepTypeConfig(is_config=True)]
1313
openai_api_key: Annotated[
14-
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "anthropic_api_key"])
14+
str,
15+
StepTypeConfig(
16+
is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "anthropic_api_key"]
17+
),
1518
]
1619
anthropic_api_key: Annotated[
17-
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "openai_api_key"])
20+
str,
21+
StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "openai_api_key"]),
1822
]
1923
patched_api_key: Annotated[
2024
str,
@@ -31,10 +35,16 @@ class AgenticLLMInputs(TypedDict, total=False):
3135
),
3236
]
3337
google_api_key: Annotated[
34-
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "client_is_gcp"])
38+
str,
39+
StepTypeConfig(
40+
is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "client_is_gcp"]
41+
),
3542
]
3643
client_is_gcp: Annotated[
37-
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "google_api_key"])
44+
str,
45+
StepTypeConfig(
46+
is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "google_api_key"]
47+
),
3848
]
3949

4050

patchwork/steps/CallLLM/typed.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,14 @@ class CallLLMInputs(TypedDict, total=False):
1313
model_args: Annotated[str, StepTypeConfig(is_config=True)]
1414
client_args: Annotated[str, StepTypeConfig(is_config=True)]
1515
openai_api_key: Annotated[
16-
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "anthropic_api_key"])
16+
str,
17+
StepTypeConfig(
18+
is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "anthropic_api_key"]
19+
),
1720
]
1821
anthropic_api_key: Annotated[
19-
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "openai_api_key"])
22+
str,
23+
StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "openai_api_key"]),
2024
]
2125
patched_api_key: Annotated[
2226
str,
@@ -33,10 +37,16 @@ class CallLLMInputs(TypedDict, total=False):
3337
),
3438
]
3539
google_api_key: Annotated[
36-
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "client_is_gcp"])
40+
str,
41+
StepTypeConfig(
42+
is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "client_is_gcp"]
43+
),
3744
]
3845
client_is_gcp: Annotated[
39-
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "google_api_key"])
46+
str,
47+
StepTypeConfig(
48+
is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "google_api_key"]
49+
),
4050
]
4151
file: Annotated[str, StepTypeConfig(is_path=True)]
4252

0 commit comments

Comments
 (0)