diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py
index 9c9abd79..22466344 100644
--- a/src/codegate/db/connection.py
+++ b/src/codegate/db/connection.py
@@ -129,6 +129,7 @@ async def record_request(self, prompt_params: Optional[Prompt] = None) -> Option
         active_workspace = await DbReader().get_active_workspace()
         workspace_id = active_workspace.id if active_workspace else "1"
         prompt_params.workspace_id = workspace_id
+
         sql = text(
             """
                 INSERT INTO prompts (id, timestamp, provider, request, type, workspace_id)
@@ -302,7 +303,7 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
             await self.record_outputs(context.output_responses, initial_id)
             await self.record_alerts(context.alerts_raised, initial_id)
             logger.info(
-                f"Recorded context in DB. Output chunks: {len(context.output_responses)}. "
+                f"Updated context in DB. Output chunks: {len(context.output_responses)}. "
                 f"Alerts: {len(context.alerts_raised)}."
             )
         except Exception as e:
diff --git a/src/codegate/pipeline/output.py b/src/codegate/pipeline/output.py
index 6f990e03..16672a60 100644
--- a/src/codegate/pipeline/output.py
+++ b/src/codegate/pipeline/output.py
@@ -127,7 +127,10 @@ def _record_to_db(self) -> None:
         loop.create_task(self._db_recorder.record_context(self._input_context))

     async def process_stream(
-        self, stream: AsyncIterator[ModelResponse], cleanup_sensitive: bool = True
+        self,
+        stream: AsyncIterator[ModelResponse],
+        cleanup_sensitive: bool = True,
+        finish_stream: bool = True,
     ) -> AsyncIterator[ModelResponse]:
         """
         Process a stream through all pipeline steps
@@ -167,7 +170,7 @@ async def process_stream(
         finally:
             # NOTE: Don't use await in finally block, it will break the stream
             # Don't flush the buffer if we assume we'll call the pipeline again
-            if cleanup_sensitive is False:
+            if cleanup_sensitive is False and finish_stream:
                 self._record_to_db()
                 return

@@ -194,7 +197,8 @@ async def process_stream(
                     yield chunk
             self._context.buffer.clear()

-        self._record_to_db()
+        if finish_stream:
+            self._record_to_db()
         # Cleanup sensitive data through the input context
         if cleanup_sensitive and self._input_context and self._input_context.sensitive:
             self._input_context.sensitive.secure_cleanup()
diff --git a/src/codegate/providers/copilot/provider.py b/src/codegate/providers/copilot/provider.py
index bf711210..4e1c6f1d 100644
--- a/src/codegate/providers/copilot/provider.py
+++ b/src/codegate/providers/copilot/provider.py
@@ -905,8 +905,16 @@ async def stream_iterator():
                 )
                 yield mr

+            # needs to be set as the flag gets reset on finish_data
+            finish_stream_flag = any(
+                choice.get("finish_reason") == "stop"
+                for record in list(self.stream_queue._queue)
+                for choice in record.get("content", {}).get("choices", [])
+            )
             async for record in self.output_pipeline_instance.process_stream(
-                stream_iterator(), cleanup_sensitive=False
+                stream_iterator(),
+                cleanup_sensitive=False,
+                finish_stream=finish_stream_flag,
             ):
                 chunk = record.model_dump_json(exclude_none=True, exclude_unset=True)
                 sse_data = f"data: {chunk}\n\n".encode("utf-8")