Skip to content

Commit a162725

Browse files
authored
simplify example in-memory store implementation (#42)
1 parent dc51b33 commit a162725

File tree

4 files changed

+235
-582
lines changed

4 files changed

+235
-582
lines changed

examples/cat-lounge/backend/app/memory_store.py

Lines changed: 64 additions & 155 deletions
Original file line numberDiff line numberDiff line change
@@ -5,180 +5,89 @@
55

66
from __future__ import annotations
77

8-
from dataclasses import dataclass
9-
from datetime import datetime
10-
from typing import Any, Dict, List
8+
from collections import defaultdict
119

1210
from chatkit.store import NotFoundError, Store
13-
from chatkit.types import Attachment, Page, Thread, ThreadItem, ThreadMetadata
11+
from chatkit.types import Attachment, Page, ThreadItem, ThreadMetadata
1412

1513

16-
@dataclass
17-
class _ThreadState:
18-
thread: ThreadMetadata
19-
items: List[ThreadItem]
14+
class MemoryStore(Store[dict]):
15+
def __init__(self):
16+
self.threads: dict[str, ThreadMetadata] = {}
17+
self.items: dict[str, list[ThreadItem]] = defaultdict(list)
2018

21-
22-
class MemoryStore(Store[dict[str, Any]]):
23-
"""Simple in-memory store compatible with the ChatKit Store interface."""
24-
25-
def __init__(self) -> None:
26-
self._threads: Dict[str, _ThreadState] = {}
27-
# Attachments intentionally unsupported; use a real store that enforces auth.
28-
29-
@staticmethod
30-
def _coerce_thread_metadata(thread: ThreadMetadata | Thread) -> ThreadMetadata:
31-
"""Return thread metadata without any embedded items."""
32-
has_items = isinstance(thread, Thread) or "items" in getattr(
33-
thread, "model_fields_set", set()
34-
)
35-
if not has_items:
36-
return thread.model_copy(deep=True)
37-
38-
data = thread.model_dump()
39-
data.pop("items", None)
40-
return ThreadMetadata(**data).model_copy(deep=True)
41-
42-
# -- Thread metadata -------------------------------------------------
43-
async def load_thread(self, thread_id: str, context: dict[str, Any]) -> ThreadMetadata:
44-
state = self._threads.get(thread_id)
45-
if not state:
19+
async def load_thread(self, thread_id: str, context: dict) -> ThreadMetadata:
20+
if thread_id not in self.threads:
4621
raise NotFoundError(f"Thread {thread_id} not found")
47-
return self._coerce_thread_metadata(state.thread)
48-
49-
async def save_thread(self, thread: ThreadMetadata, context: dict[str, Any]) -> None:
50-
metadata = self._coerce_thread_metadata(thread)
51-
state = self._threads.get(thread.id)
52-
if state:
53-
state.thread = metadata
54-
else:
55-
self._threads[thread.id] = _ThreadState(
56-
thread=metadata,
57-
items=[],
58-
)
22+
return self.threads[thread_id]
23+
24+
async def save_thread(self, thread: ThreadMetadata, context: dict) -> None:
25+
self.threads[thread.id] = thread
5926

6027
async def load_threads(
61-
self,
62-
limit: int,
63-
after: str | None,
64-
order: str,
65-
context: dict[str, Any],
28+
self, limit: int, after: str | None, order: str, context: dict
6629
) -> Page[ThreadMetadata]:
67-
threads = sorted(
68-
(self._coerce_thread_metadata(state.thread) for state in self._threads.values()),
69-
key=lambda t: t.created_at or datetime.min,
70-
reverse=(order == "desc"),
30+
threads = list(self.threads.values())
31+
return self._paginate(
32+
threads, after, limit, order, sort_key=lambda t: t.created_at, cursor_key=lambda t: t.id
7133
)
7234

73-
if after:
74-
index_map = {thread.id: idx for idx, thread in enumerate(threads)}
75-
start = index_map.get(after, -1) + 1
76-
else:
77-
start = 0
78-
79-
slice_threads = threads[start : start + limit + 1]
80-
has_more = len(slice_threads) > limit
81-
slice_threads = slice_threads[:limit]
82-
next_after = slice_threads[-1].id if has_more and slice_threads else None
83-
return Page(
84-
data=slice_threads,
85-
has_more=has_more,
86-
after=next_after,
87-
)
88-
89-
async def delete_thread(self, thread_id: str, context: dict[str, Any]) -> None:
90-
self._threads.pop(thread_id, None)
91-
92-
# -- Thread items ----------------------------------------------------
93-
def _thread_state(self, thread_id: str) -> _ThreadState:
94-
state = self._threads.get(thread_id)
95-
if state is None:
96-
state = _ThreadState(
97-
thread=ThreadMetadata(id=thread_id, created_at=datetime.utcnow()),
98-
items=[],
99-
)
100-
self._threads[thread_id] = state
101-
return state
102-
103-
def _items(self, thread_id: str) -> List[ThreadItem]:
104-
state = self._thread_state(thread_id)
105-
return state.items
106-
10735
async def load_thread_items(
108-
self,
109-
thread_id: str,
110-
after: str | None,
111-
limit: int,
112-
order: str,
113-
context: dict[str, Any],
36+
self, thread_id: str, after: str | None, limit: int, order: str, context: dict
11437
) -> Page[ThreadItem]:
115-
items = [item.model_copy(deep=True) for item in self._items(thread_id)]
116-
items.sort(
117-
key=lambda item: getattr(item, "created_at", datetime.utcnow()),
118-
reverse=(order == "desc"),
38+
items = self.items.get(thread_id, [])
39+
return self._paginate(
40+
items, after, limit, order, sort_key=lambda i: i.created_at, cursor_key=lambda i: i.id
11941
)
12042

121-
if after:
122-
index_map = {item.id: idx for idx, item in enumerate(items)}
123-
start = index_map.get(after, -1) + 1
124-
else:
125-
start = 0
126-
127-
slice_items = items[start : start + limit + 1]
128-
has_more = len(slice_items) > limit
129-
slice_items = slice_items[:limit]
130-
next_after = slice_items[-1].id if has_more and slice_items else None
131-
return Page(data=slice_items, has_more=has_more, after=next_after)
132-
133-
async def add_thread_item(
134-
self, thread_id: str, item: ThreadItem, context: dict[str, Any]
135-
) -> None:
136-
self._items(thread_id).append(item.model_copy(deep=True))
137-
138-
async def save_item(self, thread_id: str, item: ThreadItem, context: dict[str, Any]) -> None:
139-
items = self._items(thread_id)
43+
async def add_thread_item(self, thread_id: str, item: ThreadItem, context: dict) -> None:
44+
self.items[thread_id].append(item)
45+
46+
async def save_item(self, thread_id: str, item: ThreadItem, context: dict) -> None:
47+
items = self.items[thread_id]
14048
for idx, existing in enumerate(items):
14149
if existing.id == item.id:
142-
items[idx] = item.model_copy(deep=True)
50+
items[idx] = item
14351
return
144-
items.append(item.model_copy(deep=True))
52+
items.append(item)
14553

146-
async def load_item(self, thread_id: str, item_id: str, context: dict[str, Any]) -> ThreadItem:
147-
for item in self._items(thread_id):
54+
async def load_item(self, thread_id: str, item_id: str, context: dict) -> ThreadItem:
55+
for item in self.items.get(thread_id, []):
14856
if item.id == item_id:
149-
return item.model_copy(deep=True)
150-
raise NotFoundError(f"Item {item_id} not found")
151-
152-
async def delete_thread_item(
153-
self, thread_id: str, item_id: str, context: dict[str, Any]
154-
) -> None:
155-
items = self._items(thread_id)
156-
self._threads[thread_id].items = [item for item in items if item.id != item_id]
157-
158-
# -- Files -----------------------------------------------------------
159-
# These methods are not currently used but required to be compatible with the Store interface.
160-
161-
async def save_attachment(
162-
self,
163-
attachment: Attachment,
164-
context: dict[str, Any],
165-
) -> None:
166-
raise NotImplementedError(
167-
"MemoryStore does not persist attachments. Provide a Store implementation "
168-
"that enforces authentication and authorization before enabling uploads."
169-
)
57+
return item
58+
raise NotFoundError(f"Item {item_id} not found in thread {thread_id}")
59+
60+
async def delete_thread(self, thread_id: str, context: dict) -> None:
61+
self.threads.pop(thread_id, None)
62+
self.items.pop(thread_id, None)
63+
64+
async def delete_thread_item(self, thread_id: str, item_id: str, context: dict) -> None:
65+
self.items[thread_id] = [
66+
item for item in self.items.get(thread_id, []) if item.id != item_id
67+
]
68+
69+
def _paginate(
70+
self, rows: list, after: str | None, limit: int, order: str, sort_key, cursor_key
71+
):
72+
sorted_rows = sorted(rows, key=sort_key, reverse=order == "desc")
73+
start = 0
74+
if after:
75+
for idx, row in enumerate(sorted_rows):
76+
if cursor_key(row) == after:
77+
start = idx + 1
78+
break
79+
data = sorted_rows[start : start + limit]
80+
has_more = start + limit < len(sorted_rows)
81+
next_after = cursor_key(data[-1]) if has_more and data else None
82+
return Page(data=data, has_more=has_more, after=next_after)
17083

171-
async def load_attachment(
172-
self,
173-
attachment_id: str,
174-
context: dict[str, Any],
175-
) -> Attachment:
176-
raise NotImplementedError(
177-
"MemoryStore does not load attachments. Provide a Store implementation "
178-
"that enforces authentication and authorization before enabling uploads."
179-
)
84+
# Attachments are not implemented in the quickstart store
18085

181-
async def delete_attachment(self, attachment_id: str, context: dict[str, Any]) -> None:
182-
raise NotImplementedError(
183-
"MemoryStore does not delete attachments because they are never stored."
184-
)
86+
async def save_attachment(self, attachment: Attachment, context: dict) -> None:
87+
raise NotImplementedError()
88+
89+
async def load_attachment(self, attachment_id: str, context: dict) -> Attachment:
90+
raise NotImplementedError()
91+
92+
async def delete_attachment(self, attachment_id: str, context: dict) -> None:
93+
raise NotImplementedError()

0 commit comments

Comments
 (0)