11from typing import TYPE_CHECKING , AsyncGenerator , Iterator , Optional , Type
22
3+ from pydantic import BaseModel
34from theflow .utils .modules import import_dotted_string
45
5- from kotaemon .base import AIMessage , BaseMessage , HumanMessage , LLMInterface , Param , StructuredOutputLLMInterface
6+ from kotaemon .base import (
7+ AIMessage ,
8+ BaseMessage ,
9+ HumanMessage ,
10+ LLMInterface ,
11+ Param ,
12+ StructuredOutputLLMInterface ,
13+ )
614
715from .base import ChatLLM
8- from pydantic import BaseModel
916
1017if TYPE_CHECKING :
1118 from openai .types .chat .chat_completion_message_param import (
@@ -329,12 +336,14 @@ def openai_response(self, client, **kwargs):
async def aopenai_response(self, client, **kwargs):
    """Asynchronously send a chat-completion request and return the raw response.

    Builds the request payload via ``self.prepare_params`` and forwards it to
    the client's async ``chat.completions.create`` endpoint.
    """
    request_params = self.prepare_params(**kwargs)
    return await client.chat.completions.create(**request_params)
332-
339+
333340
334341class StructuredOutputChatOpenAI (ChatOpenAI ):
335342 """OpenAI chat model that returns structured output"""
336- response_schema : Type [BaseModel ] = Param (help = "class that subclasses pydantics BaseModel" , required = True )
337-
343+
344+ response_schema : Type [BaseModel ] = Param (
345+ help = "class that subclasses pydantics BaseModel" , required = True
346+ )
338347
339348 def prepare_output (self , resp : dict ) -> StructuredOutputLLMInterface :
340349 """Convert the OpenAI response into StructuredOutputLLMInterface"""
@@ -354,7 +363,7 @@ def prepare_output(self, resp: dict) -> StructuredOutputLLMInterface:
354363 )
355364
356365 output = StructuredOutputLLMInterface (
357- parsed = resp ["choices" ][0 ]["message" ]["parsed" ],
366+ parsed = resp ["choices" ][0 ]["message" ]["parsed" ],
358367 candidates = [(_ ["message" ]["content" ] or "" ) for _ in resp ["choices" ]],
359368 content = resp ["choices" ][0 ]["message" ]["content" ] or "" ,
360369 total_tokens = resp ["usage" ]["total_tokens" ],
@@ -366,11 +375,10 @@ def prepare_output(self, resp: dict) -> StructuredOutputLLMInterface:
366375 ],
367376 additional_kwargs = additional_kwargs ,
368377 logprobs = logprobs ,
369-
370378 )
371379
372380 return output
373-
381+
374382 def prepare_params (self , ** kwargs ):
375383 if "tools_pydantic" in kwargs :
376384 kwargs .pop ("tools_pydantic" )
@@ -395,23 +403,21 @@ def prepare_params(self, **kwargs):
395403 params .update (kwargs )
396404
397405 # doesn't do streaming
398- params .pop (' stream' )
406+ params .pop (" stream" )
399407
400408 return params
401-
409+
def openai_response(self, client, **kwargs):
    """Send a structured-output chat request and return the parsed response.

    The payload is assembled by ``self.prepare_params`` and dispatched to the
    beta ``chat.completions.parse`` endpoint, which validates the model output
    against the configured response schema.
    """
    request = self.prepare_params(**kwargs)
    parse_endpoint = client.beta.chat.completions.parse
    return parse_endpoint(**request)
407-
408415
async def aopenai_response(self, client, **kwargs):
    """Asynchronously send a structured-output chat request and return the parsed response.

    Mirrors the synchronous variant: assemble the payload with
    ``self.prepare_params`` and await the beta ``chat.completions.parse``
    endpoint so the output is validated against the response schema.
    """
    request = self.prepare_params(**kwargs)
    return await client.beta.chat.completions.parse(**request)
415421
416422
417423class AzureChatOpenAI (BaseChatOpenAI ):
0 commit comments