Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

fix: only record db content if it is the last chunk in stream #1091

Merged
merged 2 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ async def record_request(self, prompt_params: Optional[Prompt] = None) -> Option
active_workspace = await DbReader().get_active_workspace()
workspace_id = active_workspace.id if active_workspace else "1"
prompt_params.workspace_id = workspace_id

sql = text(
"""
INSERT INTO prompts (id, timestamp, provider, request, type, workspace_id)
Expand Down Expand Up @@ -302,7 +303,7 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
await self.record_outputs(context.output_responses, initial_id)
await self.record_alerts(context.alerts_raised, initial_id)
logger.info(
- f"Recorded context in DB. Output chunks: {len(context.output_responses)}. "
+ f"Updated context in DB. Output chunks: {len(context.output_responses)}. "
f"Alerts: {len(context.alerts_raised)}."
)
except Exception as e:
Expand Down
10 changes: 7 additions & 3 deletions src/codegate/pipeline/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,10 @@ def _record_to_db(self) -> None:
loop.create_task(self._db_recorder.record_context(self._input_context))

async def process_stream(
- self, stream: AsyncIterator[ModelResponse], cleanup_sensitive: bool = True
+ self,
+ stream: AsyncIterator[ModelResponse],
+ cleanup_sensitive: bool = True,
+ finish_stream: bool = True,
) -> AsyncIterator[ModelResponse]:
"""
Process a stream through all pipeline steps
Expand Down Expand Up @@ -167,7 +170,7 @@ async def process_stream(
finally:
# NOTE: Don't use await in finally block, it will break the stream
# Don't flush the buffer if we assume we'll call the pipeline again
- if cleanup_sensitive is False:
+ if cleanup_sensitive is False and finish_stream:
self._record_to_db()
return

Expand All @@ -194,7 +197,8 @@ async def process_stream(
yield chunk
self._context.buffer.clear()

- self._record_to_db()
+ if finish_stream:
+ self._record_to_db()
# Cleanup sensitive data through the input context
if cleanup_sensitive and self._input_context and self._input_context.sensitive:
self._input_context.sensitive.secure_cleanup()
Expand Down
10 changes: 9 additions & 1 deletion src/codegate/providers/copilot/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,8 +905,16 @@ async def stream_iterator():
)
yield mr

+ # needs to be set as the flag gets reset on finish_data
+ finish_stream_flag = any(
+ choice.get("finish_reason") == "stop"
+ for record in list(self.stream_queue._queue)
+ for choice in record.get("content", {}).get("choices", [])
+ )
async for record in self.output_pipeline_instance.process_stream(
- stream_iterator(), cleanup_sensitive=False
+ stream_iterator(),
+ cleanup_sensitive=False,
+ finish_stream=finish_stream_flag,
):
chunk = record.model_dump_json(exclude_none=True, exclude_unset=True)
sse_data = f"data: {chunk}\n\n".encode("utf-8")
Expand Down
Loading