Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 3c3eaa8

Browse files
committed
fix database queries
1 parent bba154c commit 3c3eaa8

File tree

4 files changed

+24
-44
lines changed

4 files changed

+24
-44
lines changed

src/codegate/api/v1.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -392,34 +392,31 @@ async def get_workspace_alerts(
392392
except Exception:
393393
logger.exception("Error while getting workspace")
394394
raise HTTPException(status_code=500, detail="Internal server error")
395-
396-
total_alerts = 0
397-
fetched_alerts = []
395+
398396
offset = (page - 1) * page_size
399-
batch_size = page_size * 2 # fetch more alerts per batch to allow deduplication
397+
fetched_alerts = []
400398

401399
while len(fetched_alerts) < page_size:
402-
alerts_batch, total_alerts = await dbreader.get_alerts_by_workspace(
400+
alerts_batch = await dbreader.get_alerts_by_workspace(
403401
ws.id, AlertSeverity.CRITICAL.value, page_size, offset
404402
)
405403
if not alerts_batch:
406404
break
407405

408406
dedup_alerts = await v1_processing.remove_duplicate_alerts(alerts_batch)
409407
fetched_alerts.extend(dedup_alerts)
410-
offset += batch_size
408+
offset += page_size
411409

412410
final_alerts = fetched_alerts[:page_size]
411+
total_alerts = len(fetched_alerts)
412+
413413
prompt_ids = list({alert.prompt_id for alert in final_alerts if alert.prompt_id})
414414
prompts_outputs = await dbreader.get_prompts_with_output(prompt_ids)
415415
alert_conversations = await v1_processing.parse_get_alert_conversation(
416416
final_alerts, prompts_outputs
417417
)
418418
return {
419419
"page": page,
420-
"page_size": page_size,
421-
"total_alerts": total_alerts,
422-
"total_pages": (total_alerts + page_size - 1) // page_size,
423420
"alerts": alert_conversations,
424421
}
425422

src/codegate/db/connection.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from alembic import command as alembic_command
99
from alembic.config import Config as AlembicConfig
1010
from pydantic import BaseModel
11-
from sqlalchemy import CursorResult, TextClause, event, text
11+
from sqlalchemy import CursorResult, TextClause, bindparam, event, text
1212
from sqlalchemy.engine import Engine
1313
from sqlalchemy.exc import IntegrityError, OperationalError
1414
from sqlalchemy.ext.asyncio import create_async_engine
@@ -587,11 +587,12 @@ async def get_prompts_with_output(self, prompt_ids: List[str]) -> List[GetPrompt
587587
o.output_cost
588588
FROM prompts p
589589
LEFT JOIN outputs o ON p.id = o.prompt_id
590-
WHERE p.id IN :prompt_ids
590+
WHERE (p.id IN :prompt_ids)
591591
ORDER BY o.timestamp DESC
592592
"""
593-
)
594-
conditions = {"prompt_ids": tuple(prompt_ids)}
593+
).bindparams(bindparam("prompt_ids", expanding=True))
594+
595+
conditions = {"prompt_ids": prompt_ids if prompt_ids else None}
595596
prompts = await self._exec_select_conditions_to_pydantic(
596597
GetPromptWithOutputsRow, sql, conditions, should_raise=True
597598
)
@@ -659,13 +660,23 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id(
659660

660661
return list(prompts_dict.values())
661662

663+
async def _exec_select_count(self, sql_command: str, conditions: dict) -> int:
664+
"""Executes a COUNT SQL command and returns an integer result."""
665+
async with self._async_db_engine.begin() as conn:
666+
try:
667+
result = await conn.execute(text(sql_command), conditions)
668+
return result.scalar_one() # Ensures it returns exactly one integer value
669+
except Exception as e:
670+
logger.error(f"Failed to execute COUNT query.", error=str(e))
671+
return 0 # Return 0 in case of failure to avoid crashes
672+
662673
async def get_alerts_by_workspace(
663674
self,
664675
workspace_id: str,
665676
trigger_category: Optional[str] = None,
666677
limit: int = API_DEFAULT_PAGE_SIZE,
667678
offset: int = 0,
668-
) -> Tuple[List[Alert], int]:
679+
) -> List[Alert]:
669680
sql = text(
670681
"""
671682
SELECT
@@ -691,25 +702,10 @@ async def get_alerts_by_workspace(
691702
conditions["limit"] = limit
692703
conditions["offset"] = offset
693704

694-
alerts = await self._exec_select_conditions_to_pydantic(
705+
return await self._exec_select_conditions_to_pydantic(
695706
Alert, sql, conditions, should_raise=True
696707
)
697708

698-
# Count total alerts for pagination
699-
count_sql = text(
700-
"""
701-
SELECT COUNT(*)
702-
FROM alerts a
703-
INNER JOIN prompts p ON p.id = a.prompt_id
704-
WHERE p.workspace_id = :workspace_id
705-
"""
706-
)
707-
if trigger_category:
708-
count_sql = text(count_sql.text + " AND a.trigger_category = :trigger_category")
709-
710-
total_alerts = await self._exec_select_count(count_sql, conditions)
711-
return alerts, total_alerts
712-
713709
async def get_workspaces(self) -> List[WorkspaceWithSessionInfo]:
714710
sql = text(
715711
"""

src/codegate/workspaces/crud.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -213,12 +213,8 @@ async def hard_delete_workspace(self, workspace_name: str):
213213
return
214214

215215
async def get_workspace_by_name(self, workspace_name: str) -> db_models.WorkspaceRow:
216-
print("i get by name")
217216
workspace = await self._db_reader.get_workspace_by_name(workspace_name)
218-
print("workspace is")
219-
print(workspace)
220217
if not workspace:
221-
print("in not exist")
222218
raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.")
223219
return workspace
224220

tests/api/test_v1_api.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -109,16 +109,13 @@ async def test_get_workspace_alerts_empty(mock_ws):
109109
"""Test when no alerts are found (empty list)"""
110110
with (
111111
patch("codegate.workspaces.crud.WorkspaceCrud.get_workspace_by_name", return_value=mock_ws),
112-
patch("codegate.db.connection.DbReader.get_alerts_by_workspace", return_value=([], 0)),
112+
patch("codegate.db.connection.DbReader.get_alerts_by_workspace", return_value=[]),
113113
):
114114

115115
response = client.get("/workspaces/test_workspace/alerts?page=1&page_size=10")
116116
assert response.status_code == 200
117117
assert response.json() == {
118118
"page": 1,
119-
"page_size": 10,
120-
"total_alerts": 0,
121-
"total_pages": 0,
122119
"alerts": [],
123120
}
124121

@@ -141,9 +138,6 @@ async def test_get_workspace_alerts_with_results(mock_ws, mock_alerts, mock_prom
141138
assert response.status_code == 200
142139
data = response.json()
143140
assert data["page"] == 1
144-
assert data["page_size"] == 2
145-
assert data["total_alerts"] == 2
146-
assert data["total_pages"] == 1
147141
assert len(data["alerts"]) == 2
148142

149143

@@ -167,7 +161,4 @@ async def test_get_workspace_alerts_deduplication(mock_ws, mock_alerts, mock_pro
167161
assert response.status_code == 200
168162
data = response.json()
169163
assert data["page"] == 1
170-
assert data["page_size"] == 2
171-
assert data["total_alerts"] == 2 # Total alerts remain the same
172-
assert data["total_pages"] == 1
173164
assert len(data["alerts"]) == 1 # Only one alert left after deduplication

0 commit comments

Comments
 (0)