Skip to content

Commit 6adaac7

Browse files
committed
fix: comfort precommit
1 parent a016a28 commit 6adaac7

File tree

4 files changed

+23
-16
lines changed

4 files changed

+23
-16
lines changed

libs/kotaemon/kotaemon/base/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
ExtractorOutput,
88
HumanMessage,
99
LLMInterface,
10-
StructuredOutputLLMInterface,
1110
RetrievedDocument,
11+
StructuredOutputLLMInterface,
1212
SystemMessage,
1313
)
1414

libs/kotaemon/kotaemon/base/schema.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,10 @@ class LLMInterface(AIMessage):
142142
messages: list[AIMessage] = Field(default_factory=list)
143143
logprobs: list[float] = []
144144

145+
145146
class StructuredOutputLLMInterface(LLMInterface):
146147
parsed: Any
147-
refusal: str = ''
148+
refusal: str = ""
148149

149150

150151
class ExtractorOutput(Document):

libs/kotaemon/kotaemon/llms/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
AzureChatOpenAI,
77
ChatLLM,
88
ChatOpenAI,
9-
StructuredOutputChatOpenAI,
109
EndpointChatLLM,
1110
LCAnthropicChat,
1211
LCAzureChatOpenAI,
1312
LCChatOpenAI,
1413
LCCohereChat,
1514
LCGeminiChat,
1615
LlamaCppChat,
16+
StructuredOutputChatOpenAI,
1717
)
1818
from .completions import LLM, AzureOpenAI, LlamaCpp, OpenAI
1919
from .cot import ManualSequentialChainOfThought, Thought

libs/kotaemon/kotaemon/llms/chats/openai.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
from typing import TYPE_CHECKING, AsyncGenerator, Iterator, Optional, Type
22

3+
from pydantic import BaseModel
34
from theflow.utils.modules import import_dotted_string
45

5-
from kotaemon.base import AIMessage, BaseMessage, HumanMessage, LLMInterface, Param, StructuredOutputLLMInterface
6+
from kotaemon.base import (
7+
AIMessage,
8+
BaseMessage,
9+
HumanMessage,
10+
LLMInterface,
11+
Param,
12+
StructuredOutputLLMInterface,
13+
)
614

715
from .base import ChatLLM
8-
from pydantic import BaseModel
916

1017
if TYPE_CHECKING:
1118
from openai.types.chat.chat_completion_message_param import (
@@ -329,12 +336,14 @@ def openai_response(self, client, **kwargs):
329336
async def aopenai_response(self, client, **kwargs):
330337
params = self.prepare_params(**kwargs)
331338
return await client.chat.completions.create(**params)
332-
339+
333340

334341
class StructuredOutputChatOpenAI(ChatOpenAI):
335342
"""OpenAI chat model that returns structured output"""
336-
response_schema: Type[BaseModel] = Param(help="class that subclasses pydantics BaseModel", required = True)
337-
343+
344+
response_schema: Type[BaseModel] = Param(
345+
help="class that subclasses pydantics BaseModel", required=True
346+
)
338347

339348
def prepare_output(self, resp: dict) -> StructuredOutputLLMInterface:
340349
"""Convert the OpenAI response into StructuredOutputLLMInterface"""
@@ -354,7 +363,7 @@ def prepare_output(self, resp: dict) -> StructuredOutputLLMInterface:
354363
)
355364

356365
output = StructuredOutputLLMInterface(
357-
parsed = resp["choices"][0]["message"]["parsed"],
366+
parsed=resp["choices"][0]["message"]["parsed"],
358367
candidates=[(_["message"]["content"] or "") for _ in resp["choices"]],
359368
content=resp["choices"][0]["message"]["content"] or "",
360369
total_tokens=resp["usage"]["total_tokens"],
@@ -366,11 +375,10 @@ def prepare_output(self, resp: dict) -> StructuredOutputLLMInterface:
366375
],
367376
additional_kwargs=additional_kwargs,
368377
logprobs=logprobs,
369-
370378
)
371379

372380
return output
373-
381+
374382
def prepare_params(self, **kwargs):
375383
if "tools_pydantic" in kwargs:
376384
kwargs.pop("tools_pydantic")
@@ -395,23 +403,21 @@ def prepare_params(self, **kwargs):
395403
params.update(kwargs)
396404

397405
# doesn't do streaming
398-
params.pop('stream')
406+
params.pop("stream")
399407

400408
return params
401-
409+
402410
def openai_response(self, client, **kwargs):
403411
"""Get the openai response"""
404412
params = self.prepare_params(**kwargs)
405413

406414
return client.beta.chat.completions.parse(**params)
407-
408415

409416
async def aopenai_response(self, client, **kwargs):
410417
"""Get the openai response"""
411418
params = self.prepare_params(**kwargs)
412-
413-
return await client.beta.chat.completions.parse(**params)
414419

420+
return await client.beta.chat.completions.parse(**params)
415421

416422

417423
class AzureChatOpenAI(BaseChatOpenAI):

0 commit comments

Comments (0)