@@ -1,193 +1,22 @@
-from collections.abc import Iterable
-import os
-import pipmaster as pm  # Pipmaster for dynamic library install
+"""
+Azure OpenAI compatibility layer.
 
-# install specific modules
-if not pm.is_installed("openai"):
-    pm.install("openai")
+This module provides backward compatibility by re-exporting Azure OpenAI functions
+from the main openai module where the actual implementation resides.
 
-from openai import (
-    AsyncAzureOpenAI,
-    APIConnectionError,
-    RateLimitError,
-    APITimeoutError,
-)
-from openai.types.chat import ChatCompletionMessageParam
-
-from tenacity import (
-    retry,
-    stop_after_attempt,
-    wait_exponential,
-    retry_if_exception_type,
-)
-
-from lightrag.utils import (
-    wrap_embedding_func_with_attrs,
-    safe_unicode_decode,
-    logger,
-)
-from lightrag.types import GPTKeywordExtractionFormat
-
-import numpy as np
-
-
-@retry(
-    stop=stop_after_attempt(3),
-    wait=wait_exponential(multiplier=1, min=4, max=10),
-    retry=retry_if_exception_type(
-        (RateLimitError, APIConnectionError, APIConnectionError)
-    ),
-)
-async def azure_openai_complete_if_cache(
-    model,
-    prompt,
-    system_prompt: str | None = None,
-    history_messages: Iterable[ChatCompletionMessageParam] | None = None,
-    enable_cot: bool = False,
-    base_url: str | None = None,
-    api_key: str | None = None,
-    api_version: str | None = None,
-    keyword_extraction: bool = False,
-    **kwargs,
-):
-    if enable_cot:
-        logger.debug(
-            "enable_cot=True is not supported for the Azure OpenAI API and will be ignored."
-        )
-    deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL")
-    base_url = (
-        base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST")
-    )
-    api_key = (
-        api_key or os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("LLM_BINDING_API_KEY")
-    )
-    api_version = (
-        api_version
-        or os.getenv("AZURE_OPENAI_API_VERSION")
-        or os.getenv("OPENAI_API_VERSION")
-    )
-
-    kwargs.pop("hashing_kv", None)
-    timeout = kwargs.pop("timeout", None)
-
-    # Handle keyword extraction mode
-    if keyword_extraction:
-        kwargs["response_format"] = GPTKeywordExtractionFormat
+All core logic for both OpenAI and Azure OpenAI now lives in lightrag.llm.openai,
+with this module serving as a thin compatibility wrapper for existing code that
+imports from lightrag.llm.azure_openai.
+"""
 
-    openai_async_client = AsyncAzureOpenAI(
-        azure_endpoint=base_url,
-        azure_deployment=deployment,
-        api_key=api_key,
-        api_version=api_version,
-        timeout=timeout,
-    )
-    messages = []
-    if system_prompt:
-        messages.append({"role": "system", "content": system_prompt})
-    if history_messages:
-        messages.extend(history_messages)
-    if prompt is not None:
-        messages.append({"role": "user", "content": prompt})
-
-    if "response_format" in kwargs:
-        response = await openai_async_client.chat.completions.parse(
-            model=model, messages=messages, **kwargs
-        )
-    else:
-        response = await openai_async_client.chat.completions.create(
-            model=model, messages=messages, **kwargs
-        )
-
-    if hasattr(response, "__aiter__"):
-
-        async def inner():
-            async for chunk in response:
-                if len(chunk.choices) == 0:
-                    continue
-                content = chunk.choices[0].delta.content
-                if content is None:
-                    continue
-                if r"\u" in content:
-                    content = safe_unicode_decode(content.encode("utf-8"))
-                yield content
-
-        return inner()
-    else:
-        message = response.choices[0].message
-
-        # Handle parsed responses (structured output via response_format)
-        # When using beta.chat.completions.parse(), the response is in message.parsed
-        if hasattr(message, "parsed") and message.parsed is not None:
-            # Serialize the parsed structured response to JSON
-            content = message.parsed.model_dump_json()
-            logger.debug("Using parsed structured response from API")
-        else:
-            # Handle regular content responses
-            content = message.content
-            if content and r"\u" in content:
-                content = safe_unicode_decode(content.encode("utf-8"))
-
-        return content
-
-
-async def azure_openai_complete(
-    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
-    result = await azure_openai_complete_if_cache(
-        os.getenv("LLM_MODEL", "gpt-4o-mini"),
-        prompt,
-        system_prompt=system_prompt,
-        history_messages=history_messages,
-        keyword_extraction=keyword_extraction,
-        **kwargs,
-    )
-    return result
-
-
-@wrap_embedding_func_with_attrs(embedding_dim=1536)
-@retry(
-    stop=stop_after_attempt(3),
-    wait=wait_exponential(multiplier=1, min=4, max=10),
-    retry=retry_if_exception_type(
-        (RateLimitError, APIConnectionError, APITimeoutError)
-    ),
+from lightrag.llm.openai import (
+    azure_openai_complete_if_cache,
+    azure_openai_complete,
+    azure_openai_embed,
 )
-async def azure_openai_embed(
-    texts: list[str],
-    model: str | None = None,
-    base_url: str | None = None,
-    api_key: str | None = None,
-    api_version: str | None = None,
-) -> np.ndarray:
-    deployment = (
-        os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
-        or model
-        or os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
-    )
-    base_url = (
-        base_url
-        or os.getenv("AZURE_EMBEDDING_ENDPOINT")
-        or os.getenv("EMBEDDING_BINDING_HOST")
-    )
-    api_key = (
-        api_key
-        or os.getenv("AZURE_EMBEDDING_API_KEY")
-        or os.getenv("EMBEDDING_BINDING_API_KEY")
-    )
-    api_version = (
-        api_version
-        or os.getenv("AZURE_EMBEDDING_API_VERSION")
-        or os.getenv("OPENAI_API_VERSION")
-    )
-
-    openai_async_client = AsyncAzureOpenAI(
-        azure_endpoint=base_url,
-        azure_deployment=deployment,
-        api_key=api_key,
-        api_version=api_version,
-    )
-
-    response = await openai_async_client.embeddings.create(
-        model=model or deployment, input=texts, encoding_format="float"
-    )
-    return np.array([dp.embedding for dp in response.data])
+__all__ = [
+    "azure_openai_complete_if_cache",
+    "azure_openai_complete",
+    "azure_openai_embed",
+]
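
Because this module now only re-exports, the legacy import path stays alive and resolves to the very same objects as the canonical ones in `lightrag.llm.openai`. A minimal compatibility check, assuming both modules are importable:

```python
# Compatibility check: the legacy path and the canonical path expose
# identical function objects, since azure_openai only re-exports them.
from lightrag.llm import azure_openai, openai

assert azure_openai.azure_openai_complete is openai.azure_openai_complete
assert azure_openai.azure_openai_embed is openai.azure_openai_embed
assert (
    azure_openai.azure_openai_complete_if_cache
    is openai.azure_openai_complete_if_cache
)
```

Since the objects are identical, existing call sites that import from `lightrag.llm.azure_openai` automatically pick up any fixes made in `lightrag.llm.openai`.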
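
For environment-driven configuration, a hypothetical end-to-end sketch, assuming the relocated implementation preserves the environment-variable fallbacks visible in the removed code (`AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_API_VERSION`, `AZURE_OPENAI_DEPLOYMENT`); the endpoint, key, and version values below are placeholders:

```python
import asyncio
import os

# Legacy import path, now served by the compatibility wrapper.
from lightrag.llm.azure_openai import azure_openai_complete

# Placeholder credentials; real values come from your Azure resource.
os.environ.setdefault("AZURE_OPENAI_ENDPOINT", "https://<resource>.openai.azure.com")
os.environ.setdefault("AZURE_OPENAI_API_KEY", "<api-key>")
os.environ.setdefault("AZURE_OPENAI_API_VERSION", "<api-version>")
os.environ.setdefault("AZURE_OPENAI_DEPLOYMENT", "gpt-4o-mini")


async def main() -> None:
    # azure_openai_complete(prompt, system_prompt=None, ...) matches the
    # signature shown in the removed implementation above.
    answer = await azure_openai_complete("Say hello.")
    print(answer)


if __name__ == "__main__":
    asyncio.run(main())
```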