diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py index 3edced69..d22afcc0 100644 --- a/src/codegate/providers/base.py +++ b/src/codegate/providers/base.py @@ -133,12 +133,23 @@ async def _run_output_stream_pipeline( denormalized_stream = self._output_normalizer.denormalize_streaming(pipeline_output_stream) return denormalized_stream - def _run_output_pipeline( + async def _run_output_pipeline( self, - normalized_response: ModelResponse, + input_context: PipelineContext, + model_response: Any, ) -> ModelResponse: - # we don't have a pipeline for non-streamed output yet - return normalized_response + """ + Run the output pipeline for a single response. + + For the moment we don't have a pipeline for non-streamed output, so we + just normalize the response and record the context. It is done here to match + the behaviour of the streaming pipeline. + """ + normalized_response = self._output_normalizer.normalize(model_response) + input_context.add_output(normalized_response) + await self._db_recorder.record_context(input_context) + output_result = self._output_normalizer.denormalize(normalized_response) + return output_result async def _run_input_pipeline( self, @@ -263,10 +274,7 @@ async def complete( is_fim_request=is_fim_request, ) if not streaming: - normalized_response = self._output_normalizer.normalize(model_response) - pipeline_output = self._run_output_pipeline(normalized_response) - await self._db_recorder.record_context(input_pipeline_result.context) - return self._output_normalizer.denormalize(pipeline_output) + return await self._run_output_pipeline(input_pipeline_result.context, model_response) pipeline_output_stream = await self._run_output_stream_pipeline( input_pipeline_result.context, model_response, is_fim_request=is_fim_request # type: ignore