Skip to content

Commit e1b2042

Browse files
victorjmarinVictor MarinJoJoTheBizarre
authored
fix(runs): enforce thread ownership before creating runs (#337)
Fixes #336. Adds SQL-level user_id ownership check at top of POST /threads/{id}/runs, /runs/stream, and /runs/wait. Prevents cross-user run injection where attacker A with valid JWT could execute runs against user B's threads, read B's checkpoint state via run output, and inject messages into B's conversation history. Returns 404 (not 403) to avoid leaking thread existence. Co-Authored-By: Victor Marin <victor.marin@mercanis.com> Co-Authored-By: Jawhar Djebbi <jawhardjebbi@gmail.com>
1 parent d7d80c8 commit e1b2042

6 files changed

Lines changed: 320 additions & 18 deletions

File tree

libs/aegra-api/src/aegra_api/api/runs.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from aegra_api.core.auth_deps import auth_dependency, get_current_user
1616
from aegra_api.core.auth_handlers import build_auth_context, handle_event
1717
from aegra_api.core.orm import Run as RunORM
18+
from aegra_api.core.orm import Thread as ThreadORM
1819
from aegra_api.core.orm import _get_session_maker, get_session
1920
from aegra_api.core.sse import create_end_event, get_sse_headers
2021
from aegra_api.models import Run, RunCreate, RunStatus, User
@@ -50,6 +51,12 @@ async def create_run(
5051
endpoint to follow progress. Provide either `input` or `command` (for
5152
human-in-the-loop resumption) but not both.
5253
"""
54+
existing_thread = await session.scalar(
55+
select(ThreadORM).where(ThreadORM.thread_id == thread_id)
56+
)
57+
if existing_thread and existing_thread.user_id != user.identity:
58+
raise HTTPException(404, f"Thread '{thread_id}' not found")
59+
5360
# Authorization check (create_run action on threads resource)
5461
ctx = build_auth_context(user, "threads", "create_run")
5562
value = {**request.model_dump(), "thread_id": thread_id}
@@ -92,6 +99,12 @@ async def create_and_stream_run(
9299
after the client disconnects (default is `"cancel"`). Use `stream_mode`
93100
to control which event types are emitted.
94101
"""
102+
existing_thread = await session.scalar(
103+
select(ThreadORM).where(ThreadORM.thread_id == thread_id)
104+
)
105+
if existing_thread and existing_thread.user_id != user.identity:
106+
raise HTTPException(404, f"Thread '{thread_id}' not found")
107+
95108
run_id, run, _job = await _prepare_run(session, thread_id, request, user, initial_status="pending")
96109

97110
# Default to cancel on disconnect - this matches user expectation that clicking
@@ -311,6 +324,12 @@ async def wait_for_run(
311324

312325
# Session block: all pre-execution DB work (validate, create run, submit)
313326
async with maker() as session:
327+
existing_thread = await session.scalar(
328+
select(ThreadORM).where(ThreadORM.thread_id == thread_id)
329+
)
330+
if existing_thread and existing_thread.user_id != user.identity:
331+
raise HTTPException(404, f"Thread '{thread_id}' not found")
332+
314333
run_id, _run, _job = await _prepare_run(session, thread_id, request, user, initial_status="pending")
315334

316335
# No pool connection held from here — safe for long waits
Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
"""E2E tests verifying thread user isolation when authentication is enabled.
2+
3+
Claim under test:
4+
When authentication is enabled, threads are automatically scoped to the
5+
authenticated user. Users can only see and interact with their own threads.
6+
7+
⚠️ MANUAL TESTS - These are skipped by default. Run with: pytest -m manual_auth
8+
9+
Requires a running Aegra server with auth enabled. See README.md for setup.
10+
11+
Run:
12+
pytest tests/e2e/manual_auth_tests/test_thread_user_isolation_e2e.py -v -m manual_auth
13+
"""
14+
15+
import uuid
16+
17+
import httpx
18+
import pytest
19+
20+
from aegra_api.settings import settings
21+
from tests.e2e._utils import elog
22+
23+
24+
def get_server_url() -> str:
25+
return settings.app.SERVER_URL
26+
27+
28+
def get_auth_headers(user_id: str, role: str = "user", team_id: str = "team1") -> dict[str, str]:
29+
token = f"mock-jwt-{user_id}-{role}-{team_id}"
30+
return {"Authorization": f"Bearer {token}"}
31+
32+
33+
def get_client_with_auth(user_id: str, role: str = "user", team_id: str = "team1"):
34+
from langgraph_sdk import get_client
35+
36+
token = f"mock-jwt-{user_id}-{role}-{team_id}"
37+
return get_client(url=get_server_url(), headers={"Authorization": f"Bearer {token}"})
38+
39+
40+
@pytest.mark.e2e
41+
@pytest.mark.manual_auth
42+
class TestThreadOwnership:
43+
"""Threads are created and stored under the authenticated user's identity."""
44+
45+
@pytest.mark.asyncio
46+
async def test_created_thread_is_owned_by_creator(self) -> None:
47+
"""GET /threads/<id> returns the thread to its creator."""
48+
client = get_client_with_auth("alice")
49+
thread = await client.threads.create(metadata={"isolation_test": "ownership"})
50+
thread_id = thread["thread_id"]
51+
elog("Created thread", thread)
52+
53+
fetched = await client.threads.get(thread_id)
54+
assert fetched["thread_id"] == thread_id, "Creator should be able to fetch their own thread"
55+
56+
@pytest.mark.asyncio
57+
async def test_other_user_cannot_get_thread(self) -> None:
58+
"""GET /threads/<id> returns 404 when a different user requests it."""
59+
alice_client = get_client_with_auth("alice")
60+
thread = await alice_client.threads.create(metadata={"isolation_test": "cross-user-get"})
61+
thread_id = thread["thread_id"]
62+
elog("Alice created thread", {"thread_id": thread_id})
63+
64+
bob_headers = get_auth_headers("bob")
65+
async with httpx.AsyncClient(
66+
base_url=get_server_url(), headers=bob_headers, timeout=30.0
67+
) as http:
68+
resp = await http.get(f"/threads/{thread_id}")
69+
70+
elog("Bob GET Alice's thread", {"status": resp.status_code})
71+
assert resp.status_code == 404, (
72+
f"Expected 404 when bob requests alice's thread, got {resp.status_code}: {resp.text}"
73+
)
74+
75+
76+
@pytest.mark.e2e
77+
@pytest.mark.manual_auth
78+
class TestThreadSearch:
79+
"""Thread search/list only returns threads belonging to the requesting user."""
80+
81+
@pytest.mark.asyncio
82+
async def test_search_returns_only_own_threads(self) -> None:
83+
"""POST /threads/search returns threads for the requesting user only."""
84+
tag = f"isolation-search-{uuid.uuid4().hex[:8]}"
85+
86+
alice_client = get_client_with_auth("alice")
87+
bob_client = get_client_with_auth("bob")
88+
89+
alice_thread = await alice_client.threads.create(metadata={"isolation_tag": tag})
90+
bob_thread = await bob_client.threads.create(metadata={"isolation_tag": tag})
91+
elog("Seeded threads", {"alice": alice_thread["thread_id"], "bob": bob_thread["thread_id"]})
92+
93+
alice_headers = get_auth_headers("alice")
94+
async with httpx.AsyncClient(
95+
base_url=get_server_url(), headers=alice_headers, timeout=30.0
96+
) as http:
97+
resp = await http.post(
98+
"/threads/search",
99+
json={"metadata": {"isolation_tag": tag}, "limit": 100},
100+
)
101+
assert resp.status_code == 200, resp.text
102+
thread_ids = {t["thread_id"] for t in resp.json()}
103+
elog("Alice search results", sorted(thread_ids))
104+
105+
assert alice_thread["thread_id"] in thread_ids, "Alice should see her own thread"
106+
assert bob_thread["thread_id"] not in thread_ids, "Alice must not see Bob's thread"
107+
108+
@pytest.mark.asyncio
109+
async def test_list_endpoint_returns_only_own_threads(self) -> None:
110+
"""GET /threads returns only threads owned by the requesting user."""
111+
tag = f"isolation-list-{uuid.uuid4().hex[:8]}"
112+
113+
alice_client = get_client_with_auth("alice")
114+
bob_client = get_client_with_auth("bob")
115+
116+
alice_thread = await alice_client.threads.create(metadata={"isolation_tag": tag})
117+
bob_thread = await bob_client.threads.create(metadata={"isolation_tag": tag})
118+
119+
alice_headers = get_auth_headers("alice")
120+
async with httpx.AsyncClient(
121+
base_url=get_server_url(), headers=alice_headers, timeout=30.0
122+
) as http:
123+
resp = await http.get("/threads", params={"limit": 1000})
124+
assert resp.status_code == 200, resp.text
125+
126+
data = resp.json()
127+
thread_ids = {t["thread_id"] for t in (data if isinstance(data, list) else data.get("threads", []))}
128+
elog("Alice list results", {"count": len(thread_ids)})
129+
130+
assert alice_thread["thread_id"] in thread_ids, "Alice should see her own thread"
131+
assert bob_thread["thread_id"] not in thread_ids, "Alice must not see Bob's thread"
132+
133+
134+
@pytest.mark.e2e
135+
@pytest.mark.manual_auth
136+
class TestThreadMutationIsolation:
137+
"""Users cannot mutate threads that belong to another user."""
138+
139+
@pytest.mark.asyncio
140+
async def test_other_user_cannot_update_thread(self) -> None:
141+
"""PATCH /threads/<id> returns 404 when a different user attempts to update."""
142+
alice_client = get_client_with_auth("alice")
143+
thread = await alice_client.threads.create(metadata={"isolation_test": "update"})
144+
thread_id = thread["thread_id"]
145+
146+
bob_headers = get_auth_headers("bob")
147+
async with httpx.AsyncClient(
148+
base_url=get_server_url(), headers=bob_headers, timeout=30.0
149+
) as http:
150+
resp = await http.patch(
151+
f"/threads/{thread_id}",
152+
json={"metadata": {"hijacked": True}},
153+
)
154+
elog("Bob PATCH Alice's thread", {"status": resp.status_code})
155+
assert resp.status_code == 404, (
156+
f"Expected 404 when bob patches alice's thread, got {resp.status_code}: {resp.text}"
157+
)
158+
159+
@pytest.mark.asyncio
160+
async def test_other_user_cannot_delete_thread(self) -> None:
161+
"""DELETE /threads/<id> returns 404 when a different user attempts to delete."""
162+
alice_client = get_client_with_auth("alice")
163+
thread = await alice_client.threads.create(metadata={"isolation_test": "delete"})
164+
thread_id = thread["thread_id"]
165+
166+
bob_headers = get_auth_headers("bob")
167+
async with httpx.AsyncClient(
168+
base_url=get_server_url(), headers=bob_headers, timeout=30.0
169+
) as http:
170+
resp = await http.delete(f"/threads/{thread_id}")
171+
elog("Bob DELETE Alice's thread", {"status": resp.status_code})
172+
assert resp.status_code == 404, (
173+
f"Expected 404 when bob deletes alice's thread, got {resp.status_code}: {resp.text}"
174+
)
175+
176+
# Thread still accessible by Alice after Bob's failed delete attempt
177+
fetched = await alice_client.threads.get(thread_id)
178+
assert fetched["thread_id"] == thread_id, "Thread must still exist after unauthorized delete attempt"
179+
180+
@pytest.mark.asyncio
181+
async def test_other_user_cannot_add_run_to_thread(self) -> None:
182+
"""POST /threads/<id>/runs returns 404 when a different user attempts to create a run."""
183+
alice_client = get_client_with_auth("alice")
184+
thread = await alice_client.threads.create(metadata={"isolation_test": "run"})
185+
thread_id = thread["thread_id"]
186+
187+
bob_headers = get_auth_headers("bob")
188+
async with httpx.AsyncClient(
189+
base_url=get_server_url(), headers=bob_headers, timeout=30.0
190+
) as http:
191+
resp = await http.post(
192+
f"/threads/{thread_id}/runs",
193+
json={"assistant_id": "agent", "input": {"messages": [{"role": "human", "content": "hi"}]}},
194+
)
195+
elog("Bob POST run to Alice's thread", {"status": resp.status_code})
196+
assert resp.status_code == 404, (
197+
f"Expected 404 when bob tries to run against alice's thread, got {resp.status_code}: {resp.text}"
198+
)
199+
200+
@pytest.mark.asyncio
201+
async def test_other_user_cannot_stream_run_on_thread(self) -> None:
202+
"""POST /threads/<id>/runs/stream returns 404 when a different user attempts to stream a run."""
203+
alice_client = get_client_with_auth("alice")
204+
thread = await alice_client.threads.create(metadata={"isolation_test": "stream-run"})
205+
thread_id = thread["thread_id"]
206+
207+
bob_headers = get_auth_headers("bob")
208+
async with httpx.AsyncClient(
209+
base_url=get_server_url(), headers=bob_headers, timeout=30.0
210+
) as http:
211+
resp = await http.post(
212+
f"/threads/{thread_id}/runs/stream",
213+
json={"assistant_id": "agent", "input": {"messages": [{"role": "human", "content": "hi"}]}},
214+
)
215+
elog("Bob POST stream run to Alice's thread", {"status": resp.status_code})
216+
assert resp.status_code == 404, (
217+
f"Expected 404 when bob streams against alice's thread, got {resp.status_code}: {resp.text}"
218+
)
219+
220+
@pytest.mark.asyncio
221+
async def test_other_user_cannot_wait_run_on_thread(self) -> None:
222+
"""POST /threads/<id>/runs/wait returns 404 when a different user attempts to create a run."""
223+
alice_client = get_client_with_auth("alice")
224+
thread = await alice_client.threads.create(metadata={"isolation_test": "wait-run"})
225+
thread_id = thread["thread_id"]
226+
227+
bob_headers = get_auth_headers("bob")
228+
async with httpx.AsyncClient(
229+
base_url=get_server_url(), headers=bob_headers, timeout=30.0
230+
) as http:
231+
resp = await http.post(
232+
f"/threads/{thread_id}/runs/wait",
233+
json={"assistant_id": "agent", "input": {"messages": [{"role": "human", "content": "hi"}]}},
234+
)
235+
elog("Bob POST wait run to Alice's thread", {"status": resp.status_code})
236+
assert resp.status_code == 404, (
237+
f"Expected 404 when bob waits against alice's thread, got {resp.status_code}: {resp.text}"
238+
)
239+
240+
241+
@pytest.mark.e2e
242+
@pytest.mark.manual_auth
243+
class TestUnauthenticatedAccess:
244+
"""Requests without a valid token are rejected entirely."""
245+
246+
def test_get_thread_without_auth_returns_401(self) -> None:
247+
"""GET /threads/<id> without Authorization header returns 401."""
248+
fake_thread_id = str(uuid.uuid4())
249+
resp = httpx.get(f"{get_server_url()}/threads/{fake_thread_id}", timeout=10.0)
250+
elog("Unauthenticated GET thread", {"status": resp.status_code})
251+
assert resp.status_code == 401, f"Expected 401, got {resp.status_code}"
252+
253+
def test_search_without_auth_returns_401(self) -> None:
254+
"""POST /threads/search without Authorization header returns 401."""
255+
resp = httpx.post(
256+
f"{get_server_url()}/threads/search",
257+
json={"limit": 10},
258+
timeout=10.0,
259+
)
260+
elog("Unauthenticated search", {"status": resp.status_code})
261+
assert resp.status_code == 401, f"Expected 401, got {resp.status_code}"

libs/aegra-api/tests/integration/test_api/test_runs_crud.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,9 +776,13 @@ def test_wait_for_run_timeout(self):
776776
run = _run_row(status="running")
777777
run.output = {"partial": "data"}
778778

779+
thread = _thread_row()
780+
779781
class Session(DummySessionBase):
780782
async def scalar(self, stmt):
781783
stmt_str = str(stmt).lower()
784+
if "from thread" in stmt_str:
785+
return thread
782786
if "from assistant" in stmt_str:
783787
return assistant
784788
if "from run" in stmt_str:

0 commit comments

Comments
 (0)