|
5 | 5 |
|
6 | 6 | from __future__ import annotations |
7 | 7 |
|
8 | | -from dataclasses import dataclass |
9 | | -from datetime import datetime |
10 | | -from typing import Any, Dict, List |
| 8 | +from collections import defaultdict |
11 | 9 |
|
12 | 10 | from chatkit.store import NotFoundError, Store |
13 | | -from chatkit.types import Attachment, Page, Thread, ThreadItem, ThreadMetadata |
| 11 | +from chatkit.types import Attachment, Page, ThreadItem, ThreadMetadata |
14 | 12 |
|
15 | 13 |
|
16 | | -@dataclass |
17 | | -class _ThreadState: |
18 | | - thread: ThreadMetadata |
19 | | - items: List[ThreadItem] |
| 14 | +class MemoryStore(Store[dict]): |
| 15 | + def __init__(self): |
| 16 | + self.threads: dict[str, ThreadMetadata] = {} |
| 17 | + self.items: dict[str, list[ThreadItem]] = defaultdict(list) |
20 | 18 |
|
21 | | - |
22 | | -class MemoryStore(Store[dict[str, Any]]): |
23 | | - """Simple in-memory store compatible with the ChatKit Store interface.""" |
24 | | - |
25 | | - def __init__(self) -> None: |
26 | | - self._threads: Dict[str, _ThreadState] = {} |
27 | | - # Attachments intentionally unsupported; use a real store that enforces auth. |
28 | | - |
29 | | - @staticmethod |
30 | | - def _coerce_thread_metadata(thread: ThreadMetadata | Thread) -> ThreadMetadata: |
31 | | - """Return thread metadata without any embedded items.""" |
32 | | - has_items = isinstance(thread, Thread) or "items" in getattr( |
33 | | - thread, "model_fields_set", set() |
34 | | - ) |
35 | | - if not has_items: |
36 | | - return thread.model_copy(deep=True) |
37 | | - |
38 | | - data = thread.model_dump() |
39 | | - data.pop("items", None) |
40 | | - return ThreadMetadata(**data).model_copy(deep=True) |
41 | | - |
42 | | - # -- Thread metadata ------------------------------------------------- |
43 | | - async def load_thread(self, thread_id: str, context: dict[str, Any]) -> ThreadMetadata: |
44 | | - state = self._threads.get(thread_id) |
45 | | - if not state: |
| 19 | + async def load_thread(self, thread_id: str, context: dict) -> ThreadMetadata: |
| 20 | + if thread_id not in self.threads: |
46 | 21 | raise NotFoundError(f"Thread {thread_id} not found") |
47 | | - return self._coerce_thread_metadata(state.thread) |
48 | | - |
49 | | - async def save_thread(self, thread: ThreadMetadata, context: dict[str, Any]) -> None: |
50 | | - metadata = self._coerce_thread_metadata(thread) |
51 | | - state = self._threads.get(thread.id) |
52 | | - if state: |
53 | | - state.thread = metadata |
54 | | - else: |
55 | | - self._threads[thread.id] = _ThreadState( |
56 | | - thread=metadata, |
57 | | - items=[], |
58 | | - ) |
| 22 | + return self.threads[thread_id] |
| 23 | + |
| 24 | + async def save_thread(self, thread: ThreadMetadata, context: dict) -> None: |
| 25 | + self.threads[thread.id] = thread |
59 | 26 |
|
60 | 27 | async def load_threads( |
61 | | - self, |
62 | | - limit: int, |
63 | | - after: str | None, |
64 | | - order: str, |
65 | | - context: dict[str, Any], |
| 28 | + self, limit: int, after: str | None, order: str, context: dict |
66 | 29 | ) -> Page[ThreadMetadata]: |
67 | | - threads = sorted( |
68 | | - (self._coerce_thread_metadata(state.thread) for state in self._threads.values()), |
69 | | - key=lambda t: t.created_at or datetime.min, |
70 | | - reverse=(order == "desc"), |
| 30 | + threads = list(self.threads.values()) |
| 31 | + return self._paginate( |
| 32 | + threads, after, limit, order, sort_key=lambda t: t.created_at, cursor_key=lambda t: t.id |
71 | 33 | ) |
72 | 34 |
|
73 | | - if after: |
74 | | - index_map = {thread.id: idx for idx, thread in enumerate(threads)} |
75 | | - start = index_map.get(after, -1) + 1 |
76 | | - else: |
77 | | - start = 0 |
78 | | - |
79 | | - slice_threads = threads[start : start + limit + 1] |
80 | | - has_more = len(slice_threads) > limit |
81 | | - slice_threads = slice_threads[:limit] |
82 | | - next_after = slice_threads[-1].id if has_more and slice_threads else None |
83 | | - return Page( |
84 | | - data=slice_threads, |
85 | | - has_more=has_more, |
86 | | - after=next_after, |
87 | | - ) |
88 | | - |
89 | | - async def delete_thread(self, thread_id: str, context: dict[str, Any]) -> None: |
90 | | - self._threads.pop(thread_id, None) |
91 | | - |
92 | | - # -- Thread items ---------------------------------------------------- |
93 | | - def _thread_state(self, thread_id: str) -> _ThreadState: |
94 | | - state = self._threads.get(thread_id) |
95 | | - if state is None: |
96 | | - state = _ThreadState( |
97 | | - thread=ThreadMetadata(id=thread_id, created_at=datetime.utcnow()), |
98 | | - items=[], |
99 | | - ) |
100 | | - self._threads[thread_id] = state |
101 | | - return state |
102 | | - |
103 | | - def _items(self, thread_id: str) -> List[ThreadItem]: |
104 | | - state = self._thread_state(thread_id) |
105 | | - return state.items |
106 | | - |
107 | 35 | async def load_thread_items( |
108 | | - self, |
109 | | - thread_id: str, |
110 | | - after: str | None, |
111 | | - limit: int, |
112 | | - order: str, |
113 | | - context: dict[str, Any], |
| 36 | + self, thread_id: str, after: str | None, limit: int, order: str, context: dict |
114 | 37 | ) -> Page[ThreadItem]: |
115 | | - items = [item.model_copy(deep=True) for item in self._items(thread_id)] |
116 | | - items.sort( |
117 | | - key=lambda item: getattr(item, "created_at", datetime.utcnow()), |
118 | | - reverse=(order == "desc"), |
| 38 | + items = self.items.get(thread_id, []) |
| 39 | + return self._paginate( |
| 40 | + items, after, limit, order, sort_key=lambda i: i.created_at, cursor_key=lambda i: i.id |
119 | 41 | ) |
120 | 42 |
|
121 | | - if after: |
122 | | - index_map = {item.id: idx for idx, item in enumerate(items)} |
123 | | - start = index_map.get(after, -1) + 1 |
124 | | - else: |
125 | | - start = 0 |
126 | | - |
127 | | - slice_items = items[start : start + limit + 1] |
128 | | - has_more = len(slice_items) > limit |
129 | | - slice_items = slice_items[:limit] |
130 | | - next_after = slice_items[-1].id if has_more and slice_items else None |
131 | | - return Page(data=slice_items, has_more=has_more, after=next_after) |
132 | | - |
133 | | - async def add_thread_item( |
134 | | - self, thread_id: str, item: ThreadItem, context: dict[str, Any] |
135 | | - ) -> None: |
136 | | - self._items(thread_id).append(item.model_copy(deep=True)) |
137 | | - |
138 | | - async def save_item(self, thread_id: str, item: ThreadItem, context: dict[str, Any]) -> None: |
139 | | - items = self._items(thread_id) |
| 43 | + async def add_thread_item(self, thread_id: str, item: ThreadItem, context: dict) -> None: |
| 44 | + self.items[thread_id].append(item) |
| 45 | + |
| 46 | + async def save_item(self, thread_id: str, item: ThreadItem, context: dict) -> None: |
| 47 | + items = self.items[thread_id] |
140 | 48 | for idx, existing in enumerate(items): |
141 | 49 | if existing.id == item.id: |
142 | | - items[idx] = item.model_copy(deep=True) |
| 50 | + items[idx] = item |
143 | 51 | return |
144 | | - items.append(item.model_copy(deep=True)) |
| 52 | + items.append(item) |
145 | 53 |
|
146 | | - async def load_item(self, thread_id: str, item_id: str, context: dict[str, Any]) -> ThreadItem: |
147 | | - for item in self._items(thread_id): |
| 54 | + async def load_item(self, thread_id: str, item_id: str, context: dict) -> ThreadItem: |
| 55 | + for item in self.items.get(thread_id, []): |
148 | 56 | if item.id == item_id: |
149 | | - return item.model_copy(deep=True) |
150 | | - raise NotFoundError(f"Item {item_id} not found") |
151 | | - |
152 | | - async def delete_thread_item( |
153 | | - self, thread_id: str, item_id: str, context: dict[str, Any] |
154 | | - ) -> None: |
155 | | - items = self._items(thread_id) |
156 | | - self._threads[thread_id].items = [item for item in items if item.id != item_id] |
157 | | - |
158 | | - # -- Files ----------------------------------------------------------- |
159 | | - # These methods are not currently used but required to be compatible with the Store interface. |
160 | | - |
161 | | - async def save_attachment( |
162 | | - self, |
163 | | - attachment: Attachment, |
164 | | - context: dict[str, Any], |
165 | | - ) -> None: |
166 | | - raise NotImplementedError( |
167 | | - "MemoryStore does not persist attachments. Provide a Store implementation " |
168 | | - "that enforces authentication and authorization before enabling uploads." |
169 | | - ) |
| 57 | + return item |
| 58 | + raise NotFoundError(f"Item {item_id} not found in thread {thread_id}") |
| 59 | + |
| 60 | + async def delete_thread(self, thread_id: str, context: dict) -> None: |
| 61 | + self.threads.pop(thread_id, None) |
| 62 | + self.items.pop(thread_id, None) |
| 63 | + |
| 64 | + async def delete_thread_item(self, thread_id: str, item_id: str, context: dict) -> None: |
| 65 | + self.items[thread_id] = [ |
| 66 | + item for item in self.items.get(thread_id, []) if item.id != item_id |
| 67 | + ] |
| 68 | + |
| 69 | + def _paginate( |
| 70 | + self, rows: list, after: str | None, limit: int, order: str, sort_key, cursor_key |
| 71 | + ): |
| 72 | + sorted_rows = sorted(rows, key=sort_key, reverse=order == "desc") |
| 73 | + start = 0 |
| 74 | + if after: |
| 75 | + for idx, row in enumerate(sorted_rows): |
| 76 | + if cursor_key(row) == after: |
| 77 | + start = idx + 1 |
| 78 | + break |
| 79 | + data = sorted_rows[start : start + limit] |
| 80 | + has_more = start + limit < len(sorted_rows) |
| 81 | + next_after = cursor_key(data[-1]) if has_more and data else None |
| 82 | + return Page(data=data, has_more=has_more, after=next_after) |
170 | 83 |
|
171 | | - async def load_attachment( |
172 | | - self, |
173 | | - attachment_id: str, |
174 | | - context: dict[str, Any], |
175 | | - ) -> Attachment: |
176 | | - raise NotImplementedError( |
177 | | - "MemoryStore does not load attachments. Provide a Store implementation " |
178 | | - "that enforces authentication and authorization before enabling uploads." |
179 | | - ) |
| 84 | + # Attachments are not implemented in the quickstart store |
180 | 85 |
|
181 | | - async def delete_attachment(self, attachment_id: str, context: dict[str, Any]) -> None: |
182 | | - raise NotImplementedError( |
183 | | - "MemoryStore does not delete attachments because they are never stored." |
184 | | - ) |
| 86 | + async def save_attachment(self, attachment: Attachment, context: dict) -> None: |
| 87 | + raise NotImplementedError() |
| 88 | + |
| 89 | + async def load_attachment(self, attachment_id: str, context: dict) -> Attachment: |
| 90 | + raise NotImplementedError() |
| 91 | + |
| 92 | + async def delete_attachment(self, attachment_id: str, context: dict) -> None: |
| 93 | + raise NotImplementedError() |
0 commit comments