Skip to content

Commit ac9f257

Browse files
committed
Improve Azure OpenAI wrapper functions with full parameter support
• Add missing parameters to wrappers • Update docstrings for clarity • Ensure API consistency • Fix parameter forwarding • Maintain backward compatibility
1 parent 45f4f82 commit ac9f257

File tree

1 file changed

+48
-13
lines changed

1 file changed

+48
-13
lines changed

lightrag/llm/openai.py

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -205,23 +205,33 @@ async def openai_complete_if_cache(
205205
6. For non-streaming: COT content is prepended to regular content with <think> tags.
206206
207207
Args:
208-
model: The OpenAI model to use.
208+
model: The OpenAI model to use. For Azure, this can be the deployment name.
209209
prompt: The prompt to complete.
210210
system_prompt: Optional system prompt to include.
211211
history_messages: Optional list of previous messages in the conversation.
212-
base_url: Optional base URL for the OpenAI API.
213-
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
214-
token_tracker: Optional token usage tracker for monitoring API usage.
215212
enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False.
213+
base_url: Optional base URL for the OpenAI API. For Azure, this should be the
214+
Azure OpenAI endpoint (e.g., https://your-resource.openai.azure.com/).
215+
api_key: Optional API key. For standard OpenAI, uses OPENAI_API_KEY environment
216+
variable if None. For Azure, uses AZURE_OPENAI_API_KEY if None.
217+
token_tracker: Optional token usage tracker for monitoring API usage.
216218
stream: Whether to stream the response. Default is False.
217219
timeout: Request timeout in seconds. Default is None.
218220
keyword_extraction: Whether to enable keyword extraction mode. When True, triggers
219221
special response formatting for keyword extraction. Default is False.
222+
use_azure: Whether to use Azure OpenAI service instead of standard OpenAI.
223+
When True, creates an AsyncAzureOpenAI client. Default is False.
224+
azure_deployment: Azure OpenAI deployment name. Only used when use_azure=True.
225+
If not specified, falls back to AZURE_OPENAI_DEPLOYMENT environment variable.
226+
api_version: Azure OpenAI API version (e.g., "2024-02-15-preview"). Only used
227+
when use_azure=True. If not specified, falls back to AZURE_OPENAI_API_VERSION
228+
environment variable.
220229
**kwargs: Additional keyword arguments to pass to the OpenAI API.
221230
Special kwargs:
222231
- openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
223232
These will be passed to the client constructor but will be overridden by
224-
explicit parameters (api_key, base_url).
233+
explicit parameters (api_key, base_url). Supports proxy configuration,
234+
custom headers, retry policies, etc.
225235
226236
Returns:
227237
The completed text (with integrated COT content if available) or an async iterator
@@ -684,21 +694,34 @@ async def openai_embed(
684694
) -> np.ndarray:
685695
"""Generate embeddings for a list of texts using OpenAI's API.
686696
697+
This function supports both standard OpenAI and Azure OpenAI services.
698+
687699
Args:
688700
texts: List of texts to embed.
689-
model: The OpenAI embedding model to use.
690-
base_url: Optional base URL for the OpenAI API.
691-
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
701+
model: The embedding model to use. For standard OpenAI (e.g., "text-embedding-3-small").
702+
For Azure, this can be the deployment name.
703+
base_url: Optional base URL for the API. For standard OpenAI, uses default OpenAI endpoint.
704+
For Azure, this should be the Azure OpenAI endpoint (e.g., https://your-resource.openai.azure.com/).
705+
api_key: Optional API key. For standard OpenAI, uses OPENAI_API_KEY environment variable if None.
706+
For Azure, uses AZURE_EMBEDDING_API_KEY environment variable if None.
692707
embedding_dim: Optional embedding dimension for dynamic dimension reduction.
693708
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
694709
Do NOT manually pass this parameter when calling the function directly.
695710
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
696711
Manually passing a different value will trigger a warning and be ignored.
697712
When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction.
698-
client_configs: Additional configuration options for the AsyncOpenAI client.
713+
client_configs: Additional configuration options for the AsyncOpenAI/AsyncAzureOpenAI client.
699714
These will override any default configurations but will be overridden by
700-
explicit parameters (api_key, base_url).
715+
explicit parameters (api_key, base_url). Supports proxy configuration,
716+
custom headers, retry policies, etc.
701717
token_tracker: Optional token usage tracker for monitoring API usage.
718+
use_azure: Whether to use Azure OpenAI service instead of standard OpenAI.
719+
When True, creates an AsyncAzureOpenAI client. Default is False.
720+
azure_deployment: Azure OpenAI deployment name. Only used when use_azure=True.
721+
If not specified, falls back to AZURE_EMBEDDING_DEPLOYMENT environment variable.
722+
api_version: Azure OpenAI API version (e.g., "2024-02-15-preview"). Only used
723+
when use_azure=True. If not specified, falls back to AZURE_EMBEDDING_API_VERSION
724+
environment variable.
702725
703726
Returns:
704727
A numpy array of embeddings, one per input text.
@@ -759,6 +782,9 @@ async def azure_openai_complete_if_cache(
759782
enable_cot: bool = False,
760783
base_url: str | None = None,
761784
api_key: str | None = None,
785+
token_tracker: Any | None = None,
786+
stream: bool | None = None,
787+
timeout: int | None = None,
762788
api_version: str | None = None,
763789
keyword_extraction: bool = False,
764790
**kwargs,
@@ -767,6 +793,9 @@ async def azure_openai_complete_if_cache(
767793
768794
This function provides backward compatibility by wrapping the unified
769795
openai_complete_if_cache implementation with Azure-specific parameter handling.
796+
797+
All parameters from the underlying openai_complete_if_cache are exposed to ensure
798+
full feature parity and API consistency.
770799
"""
771800
# Handle Azure-specific environment variables and parameters
772801
deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL")
@@ -782,9 +811,6 @@ async def azure_openai_complete_if_cache(
782811
or os.getenv("OPENAI_API_VERSION")
783812
)
784813

785-
# Pop timeout from kwargs if present (will be handled by openai_complete_if_cache)
786-
timeout = kwargs.pop("timeout", None)
787-
788814
# Call the unified implementation with Azure-specific parameters
789815
return await openai_complete_if_cache(
790816
model=model,
@@ -794,6 +820,8 @@ async def azure_openai_complete_if_cache(
794820
enable_cot=enable_cot,
795821
base_url=base_url,
796822
api_key=api_key,
823+
token_tracker=token_tracker,
824+
stream=stream,
797825
timeout=timeout,
798826
use_azure=True,
799827
azure_deployment=deployment,
@@ -833,13 +861,18 @@ async def azure_openai_embed(
833861
model: str | None = None,
834862
base_url: str | None = None,
835863
api_key: str | None = None,
864+
token_tracker: Any | None = None,
865+
client_configs: dict[str, Any] | None = None,
836866
api_version: str | None = None,
837867
) -> np.ndarray:
838868
"""Azure OpenAI embedding wrapper function.
839869
840870
This function provides backward compatibility by wrapping the unified
841871
openai_embed implementation with Azure-specific parameter handling.
842872
873+
All parameters from the underlying openai_embed are exposed to ensure
874+
full feature parity and API consistency.
875+
843876
IMPORTANT - Decorator Usage:
844877
845878
1. This function is decorated with @wrap_embedding_func_with_attrs to provide
@@ -898,6 +931,8 @@ async def azure_openai_embed(
898931
model=model or deployment,
899932
base_url=base_url,
900933
api_key=api_key,
934+
token_tracker=token_tracker,
935+
client_configs=client_configs,
901936
use_azure=True,
902937
azure_deployment=deployment,
903938
api_version=api_version,

0 commit comments

Comments
 (0)