Skip to content

Commit ec25b1b

Browse files
committed
feat: Refactor agent registry usage and add mock fixtures for testing
1 parent d0b13f1 commit ec25b1b

File tree

7 files changed

+71
-20
lines changed

7 files changed

+71
-20
lines changed

src/ai/backend/manager/repositories/model_serving/repository.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,7 @@
5858
# Layer-specific decorator for model_serving repository
5959
repository_decorator = create_layer_aware_repository_decorator(LayerType.MODEL_SERVING)
6060

61-
if TYPE_CHECKING:
62-
from ai.backend.manager.registry import AgentRegistry
61+
from ai.backend.manager.registry import AgentRegistry
6362

6463

6564
class ModelServingRepository:
@@ -727,7 +726,7 @@ async def resolve_image_for_endpoint_creation(
727726
async def modify_endpoint(
728727
self,
729728
action: ModifyEndpointAction,
730-
agent_registry: "AgentRegistry",
729+
agent_registry: AgentRegistry,
731730
legacy_etcd_config_loader: LegacyEtcdLoader,
732731
storage_manager: StorageSessionManager,
733732
) -> MutationResult:

tests/manager/repositories/model_serving/test_admin_repository.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import uuid
2+
from unittest.mock import AsyncMock
23

34
import pytest
45
from sqlalchemy.orm.exc import NoResultFound
@@ -9,6 +10,13 @@
910
from .conftest import assert_update_query_executed
1011

1112

13+
@pytest.fixture
14+
def mock_valkey_live():
15+
mock = AsyncMock()
16+
mock.store_live_data = AsyncMock()
17+
return mock
18+
19+
1220
@pytest.mark.asyncio
1321
async def test_get_endpoint_by_id_force_success(
1422
admin_model_serving_repository,
@@ -150,6 +158,7 @@ async def test_update_route_traffic_force_success(
150158
sample_endpoint,
151159
patch_routing_get,
152160
patch_endpoint_get,
161+
mock_valkey_live,
153162
):
154163
"""Test admin force update of route traffic ratio."""
155164
# Arrange
@@ -161,7 +170,7 @@ async def test_update_route_traffic_force_success(
161170

162171
# Act
163172
result = await admin_model_serving_repository.update_route_traffic_force(
164-
route_id, service_id, new_traffic_ratio
173+
mock_valkey_live, route_id, service_id, new_traffic_ratio
165174
)
166175

167176
# Assert

tests/manager/repositories/model_serving/test_create_endpoint_validated.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from unittest.mock import AsyncMock, MagicMock
2+
13
import pytest
24

35
from .conftest import (
@@ -8,15 +10,24 @@
810
)
911

1012

13+
@pytest.fixture
14+
def mock_agent_registry():
15+
"""Mock agent registry for testing."""
16+
mock = MagicMock()
17+
mock.create_appproxy_endpoint = AsyncMock(return_value="https://test-endpoint.example.com")
18+
return mock
19+
20+
1121
@pytest.mark.asyncio
1222
async def test_create_endpoint_validated_success(
1323
model_serving_repository,
1424
setup_writable_session,
1525
sample_endpoint,
26+
mock_agent_registry,
1627
):
1728
"""Test successful creation of an endpoint."""
1829
# Act
19-
result = await model_serving_repository.create_endpoint_validated(sample_endpoint)
30+
result = await model_serving_repository.create_endpoint_validated(sample_endpoint, mock_agent_registry)
2031

2132
# Assert
2233
assert_basic_endpoint_result(result, sample_endpoint)
@@ -48,6 +59,7 @@ async def test_create_endpoint_validated_with_configurations(
4859
sample_vfolder,
4960
endpoint_config,
5061
expected_attrs,
62+
mock_agent_registry,
5163
):
5264
"""Test creation of endpoints with different configurations."""
5365
# Arrange
@@ -57,7 +69,7 @@ async def test_create_endpoint_validated_with_configurations(
5769
endpoint_row = create_full_featured_endpoint(sample_user, sample_image, sample_vfolder)
5870

5971
# Act
60-
result = await model_serving_repository.create_endpoint_validated(endpoint_row)
72+
result = await model_serving_repository.create_endpoint_validated(endpoint_row, mock_agent_registry)
6173

6274
# Assert
6375
assert_basic_endpoint_result(result, endpoint_row)
@@ -75,13 +87,14 @@ async def test_create_endpoint_validated_transaction_handling(
7587
mock_db_engine,
7688
mock_session,
7789
sample_endpoint,
90+
mock_agent_registry,
7891
):
7992
"""Test that creation properly handles database transactions."""
8093
# Arrange
8194
setup_db_session_mock(mock_db_engine, mock_session)
8295

8396
# Act
84-
result = await model_serving_repository.create_endpoint_validated(sample_endpoint)
97+
result = await model_serving_repository.create_endpoint_validated(sample_endpoint, mock_agent_registry)
8598

8699
# Assert
87100
assert_basic_endpoint_result(result, sample_endpoint)

tests/manager/repositories/model_serving/test_get_endpoint_by_id_validated.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import uuid
2+
from unittest.mock import AsyncMock
23

34
import pytest
45
from sqlalchemy.orm.exc import NoResultFound
@@ -9,6 +10,13 @@
910
from .conftest import assert_update_query_executed
1011

1112

13+
@pytest.fixture
14+
def mock_valkey_live():
15+
mock = AsyncMock()
16+
mock.store_live_data = AsyncMock()
17+
return mock
18+
19+
1220
@pytest.mark.asyncio
1321
async def test_get_endpoint_by_id_force_success(
1422
admin_model_serving_repository,
@@ -150,6 +158,7 @@ async def test_update_route_traffic_force_success(
150158
sample_endpoint,
151159
patch_routing_get,
152160
patch_endpoint_get,
161+
mock_valkey_live,
153162
):
154163
"""Test admin force update of route traffic ratio."""
155164
# Arrange
@@ -161,7 +170,7 @@ async def test_update_route_traffic_force_success(
161170

162171
# Act
163172
result = await admin_model_serving_repository.update_route_traffic_force(
164-
route_id, service_id, new_traffic_ratio
173+
mock_valkey_live, route_id, service_id, new_traffic_ratio
165174
)
166175

167176
# Assert

tests/manager/services/model_serving/actions/test_update_route.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import aiohttp
55
import pytest
66

7-
from ai.backend.manager.errors.service import RouteNotFound
7+
from ai.backend.manager.errors.service import ModelServiceNotFound
88
from ai.backend.manager.models.user import UserRole
99
from ai.backend.manager.services.model_serving.actions.update_route import (
1010
UpdateRouteAction,
@@ -60,10 +60,10 @@ def mock_get_endpoint_for_appproxy_update(mocker, mock_repositories):
6060

6161

6262
@pytest.fixture
63-
def mock_update_appproxy_endpoint_routes(mocker, mock_repositories):
63+
def mock_notify_endpoint_route_update_to_appproxy(mocker, mock_agent_registry):
6464
mock = mocker.patch.object(
65-
mock_repositories.repository,
66-
"update_appproxy_endpoint_routes",
65+
mock_agent_registry,
66+
"notify_endpoint_route_update_to_appproxy",
6767
new_callable=AsyncMock,
6868
)
6969
mock.return_value = None
@@ -130,7 +130,7 @@ async def test_update_route(
130130
mock_update_route_traffic_force,
131131
mock_update_route_traffic_validated,
132132
mock_get_endpoint_for_appproxy_update,
133-
mock_update_appproxy_endpoint_routes,
133+
mock_notify_endpoint_route_update_to_appproxy,
134134
):
135135
# Mock endpoint data for route update
136136
mock_endpoint_data = MagicMock(
@@ -182,7 +182,7 @@ async def update_route(action: UpdateRouteAction):
182182
route_id=uuid.UUID("99999999-9999-9999-9999-999999999999"),
183183
traffic_ratio=0.5,
184184
),
185-
RouteNotFound,
185+
ModelServiceNotFound,
186186
),
187187
],
188188
)
@@ -209,7 +209,7 @@ async def test_update_route_appproxy_failure(
209209
mock_check_requester_access_update_route,
210210
mock_update_route_traffic_validated,
211211
mock_get_endpoint_for_appproxy_update,
212-
mock_update_appproxy_endpoint_routes,
212+
mock_notify_endpoint_route_update_to_appproxy,
213213
):
214214
action = UpdateRouteAction(
215215
requester_ctx=RequesterCtx(
@@ -228,9 +228,8 @@ async def test_update_route_appproxy_failure(
228228
mock_get_endpoint_for_appproxy_update.return_value = MagicMock(id=action.service_id)
229229

230230
# Mock AppProxy communication failure
231-
mock_update_appproxy_endpoint_routes.side_effect = aiohttp.ClientError("Connection failed")
232-
233-
result = await model_serving_processors.update_route.wait_for_complete(action)
231+
mock_notify_endpoint_route_update_to_appproxy.side_effect = aiohttp.ClientError("Connection failed")
234232

235-
# Should still return success despite AppProxy failure
236-
assert result.success is True
233+
# AppProxy failure should propagate as exception
234+
with pytest.raises(aiohttp.ClientError, match="Connection failed"):
235+
await model_serving_processors.update_route.wait_for_complete(action)

tests/manager/services/model_serving/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,15 @@ def mock_background_task_manager():
8787
return mock_background_task_manager
8888

8989

90+
@pytest.fixture
91+
def mock_valkey_live():
92+
mock = MagicMock()
93+
mock.store_live_data = AsyncMock()
94+
mock.get_live_data = AsyncMock()
95+
mock.delete_live_data = AsyncMock()
96+
return mock
97+
98+
9099
@pytest.fixture
91100
def model_serving_service(
92101
database_fixture,
@@ -96,6 +105,7 @@ def model_serving_service(
96105
mock_agent_registry,
97106
mock_background_task_manager,
98107
mock_config_provider,
108+
mock_valkey_live,
99109
mock_repositories,
100110
) -> ModelServingService:
101111
return ModelServingService(
@@ -104,6 +114,7 @@ def model_serving_service(
104114
event_dispatcher=mock_event_dispatcher,
105115
storage_manager=mock_storage_manager,
106116
config_provider=mock_config_provider,
117+
valkey_live=mock_valkey_live,
107118
repository=mock_repositories.repository,
108119
admin_repository=mock_repositories.admin_repository,
109120
)

tests/manager/services/model_serving/fixtures.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,15 @@ def mock_background_task_manager():
8080
return mock_background_task_manager
8181

8282

83+
@pytest.fixture
84+
def mock_valkey_live():
85+
mock = MagicMock()
86+
mock.store_live_data = AsyncMock()
87+
mock.get_live_data = AsyncMock()
88+
mock.delete_live_data = AsyncMock()
89+
return mock
90+
91+
8392
@pytest.fixture
8493
def mock_service(
8594
database_fixture, # noqa: ARG001
@@ -89,6 +98,7 @@ def mock_service(
8998
mock_agent_registry,
9099
mock_background_task_manager,
91100
mock_config_provider,
101+
mock_valkey_live,
92102
mock_repositories,
93103
) -> ModelServingService:
94104
return ModelServingService(
@@ -97,6 +107,7 @@ def mock_service(
97107
event_dispatcher=mock_event_dispatcher,
98108
storage_manager=mock_storage_manager,
99109
config_provider=mock_config_provider,
110+
valkey_live=mock_valkey_live,
100111
repository=mock_repositories.repository,
101112
admin_repository=mock_repositories.admin_repository,
102113
)

0 commit comments

Comments
 (0)