|
51 | 51 | logger = camel_get_logger('base_model') |
52 | 52 |
|
53 | 53 |
|
| 54 | +class _StreamLogger: |
| 55 | + r"""Base for stream logging wrappers.""" |
| 56 | + |
| 57 | + def __init__(self, log_path: Optional[str], log_enabled: bool): |
| 58 | + self._log_path = log_path |
| 59 | + self._log_enabled = log_enabled |
| 60 | + self._id = self._model = self._content = "" |
| 61 | + self._finish_reason: Optional[str] = None |
| 62 | + self._usage: Optional[Dict[str, Any]] = None |
| 63 | + self._logged = False |
| 64 | + |
| 65 | + def _collect(self, chunk: ChatCompletionChunk) -> None: |
| 66 | + self._id = self._id or getattr(chunk, 'id', '') |
| 67 | + self._model = self._model or getattr(chunk, 'model', '') |
| 68 | + if chunk.usage: |
| 69 | + u = chunk.usage |
| 70 | + self._usage = ( |
| 71 | + u.model_dump() if hasattr(u, 'model_dump') else u.dict() |
| 72 | + ) |
| 73 | + if chunk.choices: |
| 74 | + choice = chunk.choices[0] |
| 75 | + if choice.delta and choice.delta.content: |
| 76 | + self._content += choice.delta.content |
| 77 | + if choice.finish_reason: |
| 78 | + self._finish_reason = choice.finish_reason |
| 79 | + |
| 80 | + def _log(self) -> None: |
| 81 | + if self._logged or not self._log_enabled or not self._log_path: |
| 82 | + return |
| 83 | + self._logged = True |
| 84 | + import json |
| 85 | + from datetime import datetime |
| 86 | + |
| 87 | + try: |
| 88 | + with open(self._log_path, "r+") as f: |
| 89 | + data = json.load(f) |
| 90 | + data["response_timestamp"] = datetime.now().isoformat() |
| 91 | + data["response"] = { |
| 92 | + "id": self._id, |
| 93 | + "model": self._model, |
| 94 | + "content": self._content, |
| 95 | + "finish_reason": self._finish_reason, |
| 96 | + "usage": self._usage, |
| 97 | + "streaming": True, |
| 98 | + } |
| 99 | + f.seek(0) |
| 100 | + json.dump(data, f, indent=4) |
| 101 | + f.truncate() |
| 102 | +        except Exception: |
| 103 | +            pass  # never let logging failures break the stream |
| 104 | + |
| 105 | + |
| 106 | +class _SyncStreamWrapper(_StreamLogger): |
| 107 | + r"""Sync stream wrapper with logging.""" |
| 108 | + |
| 109 | + def __init__( |
| 110 | + self, |
| 111 | + stream: Stream[ChatCompletionChunk], |
| 112 | + log_path: Optional[str], |
| 113 | + log_enabled: bool, |
| 114 | + ): |
| 115 | + super().__init__(log_path, log_enabled) |
| 116 | + self._stream = stream |
| 117 | + |
| 118 | + def __iter__(self): |
| 119 | + return self |
| 120 | + |
| 121 | + def __next__(self) -> ChatCompletionChunk: |
| 122 | + try: |
| 123 | + chunk = next(self._stream) |
| 124 | + self._collect(chunk) |
| 125 | + return chunk |
| 126 | + except StopIteration: |
| 127 | + self._log() |
| 128 | + raise |
| 129 | + |
| 130 | + def __enter__(self): |
| 131 | + return self |
| 132 | + |
| 133 | + def __exit__(self, *_): |
| 134 | + return False |
| 135 | + |
| 136 | + def __del__(self): |
| 137 | + self._log() |
| 138 | + |
| 139 | + |
| 140 | +class _AsyncStreamWrapper(_StreamLogger): |
| 141 | + r"""Async stream wrapper with logging.""" |
| 142 | + |
| 143 | + def __init__( |
| 144 | + self, |
| 145 | + stream: AsyncStream[ChatCompletionChunk], |
| 146 | + log_path: Optional[str], |
| 147 | + log_enabled: bool, |
| 148 | + ): |
| 149 | + super().__init__(log_path, log_enabled) |
| 150 | + self._stream = stream |
| 151 | + |
| 152 | + def __aiter__(self): |
| 153 | + return self |
| 154 | + |
| 155 | + async def __anext__(self) -> ChatCompletionChunk: |
| 156 | + try: |
| 157 | + chunk = await self._stream.__anext__() |
| 158 | + self._collect(chunk) |
| 159 | + return chunk |
| 160 | + except StopAsyncIteration: |
| 161 | + self._log() |
| 162 | + raise |
| 163 | + |
| 164 | + async def __aenter__(self): |
| 165 | + return self |
| 166 | + |
| 167 | + async def __aexit__(self, *_): |
| 168 | + return False |
| 169 | + |
| 170 | + def __del__(self): |
| 171 | + self._log() |
| 172 | + |
| 173 | + |
54 | 174 | class ModelBackendMeta(abc.ABCMeta): |
55 | 175 | r"""Metaclass that automatically preprocesses messages in run method. |
56 | 176 |
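Side note on the `_log()` method above: it assumes the request was already written to `log_path` as a JSON document, then re-reads it, adds the response fields, rewinds, rewrites, and truncates. A minimal, self-contained sketch of that same read-modify-write pattern (the file name and the pre-existing `request` key are illustrative assumptions, not taken from this diff):

```python
# Sketch of the "r+" read-modify-write used by _StreamLogger._log().
# The file name and the pre-existing "request" key are assumptions here.
import json
import os
import tempfile
from datetime import datetime

log_path = os.path.join(tempfile.mkdtemp(), "request_log.json")
with open(log_path, "w") as f:
    json.dump({"request": {"messages": []}}, f, indent=4)

with open(log_path, "r+") as f:
    data = json.load(f)  # read the existing request entry
    data["response_timestamp"] = datetime.now().isoformat()
    data["response"] = {"content": "Hello!", "streaming": True}
    f.seek(0)  # rewind to overwrite the document in place
    json.dump(data, f, indent=4)
    f.truncate()  # drop any leftover trailing bytes

print(open(log_path).read())
```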
|
@@ -428,10 +548,13 @@ def run( |
428 | 548 |         result = self._run(messages, response_format, tools) |
429 | 549 |         logger.info("Result: %s", result) |
430 | 550 |
|
431 | | -        # Log the response if logging is enabled |
| 551 | +        # For streaming responses, wrap with logging; otherwise log immediately |
| 552 | +        if isinstance(result, Stream): |
| 553 | +            return _SyncStreamWrapper(  # type: ignore[return-value] |
| 554 | +                result, log_path, self._log_enabled |
| 555 | +            ) |
432 | 556 |         if log_path: |
433 | 557 |             self._log_response(log_path, result) |
434 | | - |
435 | 558 |         return result |
436 | 559 |
|
437 | 560 |     @observe() |
@@ -480,10 +603,13 @@ async def arun( |
480 | 603 |         result = await self._arun(messages, response_format, tools) |
481 | 604 |         logger.info("Result: %s", result) |
482 | 605 |
|
483 | | -        # Log the response if logging is enabled |
| 606 | +        # For streaming responses, wrap with logging; otherwise log immediately |
| 607 | +        if isinstance(result, AsyncStream): |
| 608 | +            return _AsyncStreamWrapper(  # type: ignore[return-value] |
| 609 | +                result, log_path, self._log_enabled |
| 610 | +            ) |
484 | 611 |         if log_path: |
485 | 612 |             self._log_response(log_path, result) |
486 | | - |
487 | 613 |         return result |
488 | 614 |
|
489 | 615 |     def count_tokens_from_messages(self, messages: List[OpenAIMessage]) -> int: |
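For reviewers who want to poke at the wrapper locally, here is a rough, self-contained sanity check that drives `_SyncStreamWrapper` with `SimpleNamespace` stand-ins instead of real `ChatCompletionChunk` objects. The import path for the private class and the stand-in chunk shape are assumptions and may need adjusting:

```python
# Hypothetical sanity check; the module path and the SimpleNamespace
# stand-ins are assumptions, not part of this PR.
import json
import os
import tempfile
from types import SimpleNamespace

from camel.models.base_model import _SyncStreamWrapper  # assumed location

log_path = os.path.join(tempfile.mkdtemp(), "request_log.json")
with open(log_path, "w") as f:
    json.dump({"request": "..."}, f, indent=4)

def fake_chunks():
    # Only the attributes _collect() reads are modelled here.
    def chunk(text, finish=None):
        delta = SimpleNamespace(content=text)
        choice = SimpleNamespace(delta=delta, finish_reason=finish)
        return SimpleNamespace(
            id="c1", model="demo", usage=None, choices=[choice]
        )

    yield chunk("Hel")
    yield chunk("lo", finish="stop")

wrapped = _SyncStreamWrapper(fake_chunks(), log_path, log_enabled=True)
text = "".join(c.choices[0].delta.content for c in wrapped)
print(text)                                   # Hello
print(json.load(open(log_path))["response"])  # written once the stream ends
```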
|