205 changes: 17 additions & 188 deletions lightrag/llm/azure_openai.py
@@ -1,193 +1,22 @@
-from collections.abc import Iterable
-import os
-import pipmaster as pm  # Pipmaster for dynamic library install
+"""
+Azure OpenAI compatibility layer.
 
-# install specific modules
-if not pm.is_installed("openai"):
-    pm.install("openai")
+This module provides backward compatibility by re-exporting Azure OpenAI functions
+from the main openai module where the actual implementation resides.
 
-from openai import (
-    AsyncAzureOpenAI,
-    APIConnectionError,
-    RateLimitError,
-    APITimeoutError,
-)
-from openai.types.chat import ChatCompletionMessageParam
-
-from tenacity import (
-    retry,
-    stop_after_attempt,
-    wait_exponential,
-    retry_if_exception_type,
-)
-
-from lightrag.utils import (
-    wrap_embedding_func_with_attrs,
-    safe_unicode_decode,
-    logger,
-)
-from lightrag.types import GPTKeywordExtractionFormat
-
-import numpy as np
-
-
-@retry(
-    stop=stop_after_attempt(3),
-    wait=wait_exponential(multiplier=1, min=4, max=10),
-    retry=retry_if_exception_type(
-        (RateLimitError, APIConnectionError, APITimeoutError)
-    ),
-)
-async def azure_openai_complete_if_cache(
-    model,
-    prompt,
-    system_prompt: str | None = None,
-    history_messages: Iterable[ChatCompletionMessageParam] | None = None,
-    enable_cot: bool = False,
-    base_url: str | None = None,
-    api_key: str | None = None,
-    api_version: str | None = None,
-    keyword_extraction: bool = False,
-    **kwargs,
-):
-    if enable_cot:
-        logger.debug(
-            "enable_cot=True is not supported for the Azure OpenAI API and will be ignored."
-        )
-    deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL")
-    base_url = (
-        base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST")
-    )
-    api_key = (
-        api_key or os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("LLM_BINDING_API_KEY")
-    )
-    api_version = (
-        api_version
-        or os.getenv("AZURE_OPENAI_API_VERSION")
-        or os.getenv("OPENAI_API_VERSION")
-    )
-
-    kwargs.pop("hashing_kv", None)
-    timeout = kwargs.pop("timeout", None)
-
-    # Handle keyword extraction mode
-    if keyword_extraction:
-        kwargs["response_format"] = GPTKeywordExtractionFormat
+All core logic for both OpenAI and Azure OpenAI now lives in lightrag.llm.openai,
+with this module serving as a thin compatibility wrapper for existing code that
+imports from lightrag.llm.azure_openai.
+"""
 
-    openai_async_client = AsyncAzureOpenAI(
-        azure_endpoint=base_url,
-        azure_deployment=deployment,
-        api_key=api_key,
-        api_version=api_version,
-        timeout=timeout,
-    )
-    messages = []
-    if system_prompt:
-        messages.append({"role": "system", "content": system_prompt})
-    if history_messages:
-        messages.extend(history_messages)
-    if prompt is not None:
-        messages.append({"role": "user", "content": prompt})
-
-    if "response_format" in kwargs:
-        response = await openai_async_client.chat.completions.parse(
-            model=model, messages=messages, **kwargs
-        )
-    else:
-        response = await openai_async_client.chat.completions.create(
-            model=model, messages=messages, **kwargs
-        )
-
-    if hasattr(response, "__aiter__"):
-
-        async def inner():
-            async for chunk in response:
-                if len(chunk.choices) == 0:
-                    continue
-                content = chunk.choices[0].delta.content
-                if content is None:
-                    continue
-                if r"\u" in content:
-                    content = safe_unicode_decode(content.encode("utf-8"))
-                yield content
-
-        return inner()
-    else:
-        message = response.choices[0].message
-
-        # Handle parsed responses (structured output via response_format)
-        # When using chat.completions.parse(), the response is in message.parsed
-        if hasattr(message, "parsed") and message.parsed is not None:
-            # Serialize the parsed structured response to JSON
-            content = message.parsed.model_dump_json()
-            logger.debug("Using parsed structured response from API")
-        else:
-            # Handle regular content responses
-            content = message.content
-            if content and r"\u" in content:
-                content = safe_unicode_decode(content.encode("utf-8"))
-
-        return content
-
-
-async def azure_openai_complete(
-    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
-    result = await azure_openai_complete_if_cache(
-        os.getenv("LLM_MODEL", "gpt-4o-mini"),
-        prompt,
-        system_prompt=system_prompt,
-        history_messages=history_messages,
-        keyword_extraction=keyword_extraction,
-        **kwargs,
-    )
-    return result
-
-
-@wrap_embedding_func_with_attrs(embedding_dim=1536)
-@retry(
-    stop=stop_after_attempt(3),
-    wait=wait_exponential(multiplier=1, min=4, max=10),
-    retry=retry_if_exception_type(
-        (RateLimitError, APIConnectionError, APITimeoutError)
-    ),
+from lightrag.llm.openai import (
+    azure_openai_complete_if_cache,
+    azure_openai_complete,
+    azure_openai_embed,
 )
-async def azure_openai_embed(
-    texts: list[str],
-    model: str | None = None,
-    base_url: str | None = None,
-    api_key: str | None = None,
-    api_version: str | None = None,
-) -> np.ndarray:
-    deployment = (
-        os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
-        or model
-        or os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
-    )
-    base_url = (
-        base_url
-        or os.getenv("AZURE_EMBEDDING_ENDPOINT")
-        or os.getenv("EMBEDDING_BINDING_HOST")
-    )
-    api_key = (
-        api_key
-        or os.getenv("AZURE_EMBEDDING_API_KEY")
-        or os.getenv("EMBEDDING_BINDING_API_KEY")
-    )
-    api_version = (
-        api_version
-        or os.getenv("AZURE_EMBEDDING_API_VERSION")
-        or os.getenv("OPENAI_API_VERSION")
-    )
-
-    openai_async_client = AsyncAzureOpenAI(
-        azure_endpoint=base_url,
-        azure_deployment=deployment,
-        api_key=api_key,
-        api_version=api_version,
-    )
 
-    response = await openai_async_client.embeddings.create(
-        model=model or deployment, input=texts, encoding_format="float"
-    )
-    return np.array([dp.embedding for dp in response.data])
+__all__ = [
+    "azure_openai_complete_if_cache",
+    "azure_openai_complete",
+    "azure_openai_embed",
+]
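
Because the three public names are simply re-exported, code written against the old module path keeps working unchanged. Below is a minimal usage sketch, assuming the usual Azure environment variables (AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_KEY, AZURE_OPENAI_API_VERSION, plus the deployment/model variables used in the deleted code) are set, and assuming the re-exported functions keep the signatures shown above:

```python
import asyncio

# The old import path still works: these names are now re-exported from
# lightrag.llm.openai, where the shared implementation lives.
from lightrag.llm.azure_openai import azure_openai_complete, azure_openai_embed


async def main() -> None:
    # Chat completion through the compatibility wrapper; the deployment is
    # resolved from environment variables (e.g. AZURE_OPENAI_DEPLOYMENT).
    answer = await azure_openai_complete("Summarize what LightRAG does.")
    print(answer)

    # Embeddings through the same re-exported surface; per the signature in
    # the deleted code, this returns a numpy array of shape (n_texts, dim).
    vectors = await azure_openai_embed(["hello", "world"])
    print(vectors.shape)


asyncio.run(main())
```

Keeping the wrapper as a pure re-export plus `__all__` lets the OpenAI and Azure OpenAI paths share one implementation of retry, streaming, and structured-output handling, while existing imports of lightrag.llm.azure_openai stay stable.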