Skip to content

Commit 021b637

Browse files
authored
Merge pull request #2403 from danielaskdd/azure-cot-handling
Refact: Consolidate Azure OpenAI and OpenAI implementations
2 parents 66d6c7d + ac9f257 commit 021b637

File tree

3 files changed

+370
-228
lines changed

3 files changed

+370
-228
lines changed

lightrag/llm/azure_openai.py

Lines changed: 17 additions & 188 deletions
Original file line numberDiff line numberDiff line change
@@ -1,193 +1,22 @@
1-
from collections.abc import Iterable
2-
import os
3-
import pipmaster as pm # Pipmaster for dynamic library install
1+
"""
2+
Azure OpenAI compatibility layer.
43
5-
# install specific modules
6-
if not pm.is_installed("openai"):
7-
pm.install("openai")
4+
This module provides backward compatibility by re-exporting Azure OpenAI functions
5+
from the main openai module where the actual implementation resides.
86
9-
from openai import (
10-
AsyncAzureOpenAI,
11-
APIConnectionError,
12-
RateLimitError,
13-
APITimeoutError,
14-
)
15-
from openai.types.chat import ChatCompletionMessageParam
16-
17-
from tenacity import (
18-
retry,
19-
stop_after_attempt,
20-
wait_exponential,
21-
retry_if_exception_type,
22-
)
23-
24-
from lightrag.utils import (
25-
wrap_embedding_func_with_attrs,
26-
safe_unicode_decode,
27-
logger,
28-
)
29-
from lightrag.types import GPTKeywordExtractionFormat
30-
31-
import numpy as np
32-
33-
34-
@retry(
35-
stop=stop_after_attempt(3),
36-
wait=wait_exponential(multiplier=1, min=4, max=10),
37-
retry=retry_if_exception_type(
38-
(RateLimitError, APIConnectionError, APIConnectionError)
39-
),
40-
)
41-
async def azure_openai_complete_if_cache(
42-
model,
43-
prompt,
44-
system_prompt: str | None = None,
45-
history_messages: Iterable[ChatCompletionMessageParam] | None = None,
46-
enable_cot: bool = False,
47-
base_url: str | None = None,
48-
api_key: str | None = None,
49-
api_version: str | None = None,
50-
keyword_extraction: bool = False,
51-
**kwargs,
52-
):
53-
if enable_cot:
54-
logger.debug(
55-
"enable_cot=True is not supported for the Azure OpenAI API and will be ignored."
56-
)
57-
deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL")
58-
base_url = (
59-
base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST")
60-
)
61-
api_key = (
62-
api_key or os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("LLM_BINDING_API_KEY")
63-
)
64-
api_version = (
65-
api_version
66-
or os.getenv("AZURE_OPENAI_API_VERSION")
67-
or os.getenv("OPENAI_API_VERSION")
68-
)
69-
70-
kwargs.pop("hashing_kv", None)
71-
timeout = kwargs.pop("timeout", None)
72-
73-
# Handle keyword extraction mode
74-
if keyword_extraction:
75-
kwargs["response_format"] = GPTKeywordExtractionFormat
7+
All core logic for both OpenAI and Azure OpenAI now lives in lightrag.llm.openai,
8+
with this module serving as a thin compatibility wrapper for existing code that
9+
imports from lightrag.llm.azure_openai.
10+
"""
7611

77-
openai_async_client = AsyncAzureOpenAI(
78-
azure_endpoint=base_url,
79-
azure_deployment=deployment,
80-
api_key=api_key,
81-
api_version=api_version,
82-
timeout=timeout,
83-
)
84-
messages = []
85-
if system_prompt:
86-
messages.append({"role": "system", "content": system_prompt})
87-
if history_messages:
88-
messages.extend(history_messages)
89-
if prompt is not None:
90-
messages.append({"role": "user", "content": prompt})
91-
92-
if "response_format" in kwargs:
93-
response = await openai_async_client.chat.completions.parse(
94-
model=model, messages=messages, **kwargs
95-
)
96-
else:
97-
response = await openai_async_client.chat.completions.create(
98-
model=model, messages=messages, **kwargs
99-
)
100-
101-
if hasattr(response, "__aiter__"):
102-
103-
async def inner():
104-
async for chunk in response:
105-
if len(chunk.choices) == 0:
106-
continue
107-
content = chunk.choices[0].delta.content
108-
if content is None:
109-
continue
110-
if r"\u" in content:
111-
content = safe_unicode_decode(content.encode("utf-8"))
112-
yield content
113-
114-
return inner()
115-
else:
116-
message = response.choices[0].message
117-
118-
# Handle parsed responses (structured output via response_format)
119-
# When using beta.chat.completions.parse(), the response is in message.parsed
120-
if hasattr(message, "parsed") and message.parsed is not None:
121-
# Serialize the parsed structured response to JSON
122-
content = message.parsed.model_dump_json()
123-
logger.debug("Using parsed structured response from API")
124-
else:
125-
# Handle regular content responses
126-
content = message.content
127-
if content and r"\u" in content:
128-
content = safe_unicode_decode(content.encode("utf-8"))
129-
130-
return content
131-
132-
133-
async def azure_openai_complete(
134-
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
135-
) -> str:
136-
result = await azure_openai_complete_if_cache(
137-
os.getenv("LLM_MODEL", "gpt-4o-mini"),
138-
prompt,
139-
system_prompt=system_prompt,
140-
history_messages=history_messages,
141-
keyword_extraction=keyword_extraction,
142-
**kwargs,
143-
)
144-
return result
145-
146-
147-
@wrap_embedding_func_with_attrs(embedding_dim=1536)
148-
@retry(
149-
stop=stop_after_attempt(3),
150-
wait=wait_exponential(multiplier=1, min=4, max=10),
151-
retry=retry_if_exception_type(
152-
(RateLimitError, APIConnectionError, APITimeoutError)
153-
),
12+
from lightrag.llm.openai import (
13+
azure_openai_complete_if_cache,
14+
azure_openai_complete,
15+
azure_openai_embed,
15416
)
155-
async def azure_openai_embed(
156-
texts: list[str],
157-
model: str | None = None,
158-
base_url: str | None = None,
159-
api_key: str | None = None,
160-
api_version: str | None = None,
161-
) -> np.ndarray:
162-
deployment = (
163-
os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
164-
or model
165-
or os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
166-
)
167-
base_url = (
168-
base_url
169-
or os.getenv("AZURE_EMBEDDING_ENDPOINT")
170-
or os.getenv("EMBEDDING_BINDING_HOST")
171-
)
172-
api_key = (
173-
api_key
174-
or os.getenv("AZURE_EMBEDDING_API_KEY")
175-
or os.getenv("EMBEDDING_BINDING_API_KEY")
176-
)
177-
api_version = (
178-
api_version
179-
or os.getenv("AZURE_EMBEDDING_API_VERSION")
180-
or os.getenv("OPENAI_API_VERSION")
181-
)
182-
183-
openai_async_client = AsyncAzureOpenAI(
184-
azure_endpoint=base_url,
185-
azure_deployment=deployment,
186-
api_key=api_key,
187-
api_version=api_version,
188-
)
18917

190-
response = await openai_async_client.embeddings.create(
191-
model=model or deployment, input=texts, encoding_format="float"
192-
)
193-
return np.array([dp.embedding for dp in response.data])
18+
__all__ = [
19+
"azure_openai_complete_if_cache",
20+
"azure_openai_complete",
21+
"azure_openai_embed",
22+
]

0 commit comments

Comments
 (0)