Skip to content

Commit b7a7060

Browse files
kyujin-choHyeockJinKim
authored andcommitted
feat(BA-992): Offload health check capability to AppProxy (#5134)
Co-authored-by: HyeockJinKim <[email protected]>
1 parent aa655d7 commit b7a7060

File tree

28 files changed

+538
-459
lines changed

28 files changed

+538
-459
lines changed

changes/5134.breaking.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
- Health check capability temporarily broken on OSS AppProxy due to architectural changes
2+
- Users must disable health check feature in `model-definition.yaml` to use model services on Open Source Backend.AI
3+
- OSS AppProxy support will be restored in future releases

changes/5134.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
- Offload model service health check architecture to AppProxy with Redis-based route management for improved scalability and real-time endpoint monitoring

src/ai/backend/agent/agent.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2934,15 +2934,11 @@ async def start_and_monitor_model_service_health(
29342934
# if everything went well then krunner itself will report the status via zmq
29352935
await self.anycast_and_broadcast_event(
29362936
ModelServiceStatusAnycastEvent(
2937-
kernel_obj.kernel_id,
29382937
kernel_obj.session_id,
2939-
model["name"],
29402938
ModelServiceStatus.UNHEALTHY,
29412939
),
29422940
ModelServiceStatusBroadcastEvent(
2943-
kernel_obj.kernel_id,
29442941
kernel_obj.session_id,
2945-
model["name"],
29462942
ModelServiceStatus.UNHEALTHY,
29472943
),
29482944
)

src/ai/backend/agent/kernel.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,11 @@
5050
from ai.backend.common.events.event_types.kernel.types import (
5151
KernelLifecycleEventReason,
5252
)
53-
from ai.backend.common.events.event_types.model_serving.anycast import (
54-
ModelServiceStatusAnycastEvent,
55-
)
5653
from ai.backend.common.json import load_json
5754
from ai.backend.common.types import (
5855
AgentId,
5956
CommitStatus,
6057
KernelId,
61-
ModelServiceStatus,
6258
ServicePort,
6359
SessionId,
6460
SessionTypes,
@@ -1129,18 +1125,8 @@ async def read_output(self) -> None:
11291125
case b"model-service-result":
11301126
await self.model_service_queue.put(msg_data)
11311127
case b"model-service-status":
1132-
response = load_json(msg_data)
1133-
event = ModelServiceStatusAnycastEvent(
1134-
self.kernel_id,
1135-
self.session_id,
1136-
response["model_name"],
1137-
(
1138-
ModelServiceStatus.HEALTHY
1139-
if response["is_healthy"]
1140-
else ModelServiceStatus.UNHEALTHY
1141-
),
1142-
)
1143-
await self.event_producer.anycast_event(event)
1128+
# no-op
1129+
pass
11441130
case b"apps-result":
11451131
await self.service_apps_info_queue.put(msg_data)
11461132
case b"stdout":

src/ai/backend/common/data/config/types.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from dataclasses import dataclass
2-
from typing import Optional
2+
from typing import Annotated, Optional
3+
4+
from pydantic import BaseModel, Field
35

46
from ai.backend.common.typed_validators import HostPortPair
57

@@ -10,3 +12,15 @@ class EtcdConfigData:
1012
addr: HostPortPair
1113
user: Optional[str]
1214
password: Optional[str]
15+
16+
17+
class HealthCheckConfig(BaseModel):
18+
"""
19+
Health check configuration matching model-definition.yaml schema
20+
"""
21+
22+
interval: Annotated[float, Field(default=10.0, ge=0)] = 10.0
23+
path: str
24+
max_retries: Annotated[int, Field(default=10, ge=1)] = 10
25+
max_wait_time: Annotated[float, Field(default=15.0, ge=0)] = 15.0
26+
expected_status_code: Annotated[int, Field(default=200, ge=100, le=599)] = 200
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import uuid
2+
from dataclasses import dataclass
3+
from typing import Optional, override
4+
5+
from ai.backend.common.events.types import AbstractEvent, EventDomain
6+
from ai.backend.common.events.user_event.user_event import UserEvent
7+
from ai.backend.common.types import ModelServiceStatus, SessionId
8+
9+
10+
@dataclass
11+
class ModelServiceStatusEventArgs(AbstractEvent):
12+
session_id: SessionId
13+
new_status: ModelServiceStatus
14+
15+
def serialize(self) -> tuple:
16+
return (
17+
str(self.session_id),
18+
self.new_status.value,
19+
)
20+
21+
@classmethod
22+
def deserialize(cls, value: tuple):
23+
return cls(
24+
session_id=SessionId(uuid.UUID(value[0])),
25+
new_status=ModelServiceStatus(value[1]),
26+
)
27+
28+
@classmethod
29+
@override
30+
def event_domain(cls) -> EventDomain:
31+
return EventDomain.MODEL_SERVING
32+
33+
@override
34+
def domain_id(self) -> Optional[str]:
35+
return None
36+
37+
@override
38+
def user_event(self) -> Optional[UserEvent]:
39+
return None

src/ai/backend/common/events/event_types/model_serving/anycast.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,81 +4,81 @@
44

55
from ai.backend.common.events.types import AbstractAnycastEvent, EventDomain
66
from ai.backend.common.events.user_event.user_event import UserEvent
7-
from ai.backend.common.types import KernelId, ModelServiceStatus, SessionId
7+
8+
from . import ModelServiceStatusEventArgs
9+
10+
11+
class ModelServiceStatusAnycastEvent(ModelServiceStatusEventArgs, AbstractAnycastEvent):
12+
@classmethod
13+
@override
14+
def event_name(cls) -> str:
15+
return "model_service_status_updated"
816

917

1018
@dataclass
11-
class ModelServiceStatusEventArgs(AbstractAnycastEvent):
12-
kernel_id: KernelId
13-
session_id: SessionId
14-
model_name: str
15-
new_status: ModelServiceStatus
19+
class RouteCreationEvent(AbstractAnycastEvent):
20+
route_id: uuid.UUID
1621

1722
def serialize(self) -> tuple:
18-
return (
19-
str(self.kernel_id),
20-
str(self.session_id),
21-
self.model_name,
22-
self.new_status.value,
23-
)
23+
return (str(self.route_id),)
2424

2525
@classmethod
2626
def deserialize(cls, value: tuple):
27-
return cls(
28-
kernel_id=KernelId(uuid.UUID(value[0])),
29-
session_id=SessionId(uuid.UUID(value[1])),
30-
model_name=value[2],
31-
new_status=ModelServiceStatus(value[3]),
32-
)
27+
return cls(uuid.UUID(value[0]))
3328

3429
@classmethod
3530
@override
3631
def event_domain(cls) -> EventDomain:
37-
return EventDomain.MODEL_SERVING
32+
return EventDomain.MODEL_ROUTE
3833

3934
@override
4035
def domain_id(self) -> Optional[str]:
41-
return None
36+
return str(self.route_id)
4237

4338
@override
4439
def user_event(self) -> Optional[UserEvent]:
4540
return None
4641

4742

48-
class ModelServiceStatusAnycastEvent(ModelServiceStatusEventArgs):
43+
class RouteCreatedAnycastEvent(RouteCreationEvent):
4944
@classmethod
5045
@override
5146
def event_name(cls) -> str:
52-
return "model_service_status_updated"
47+
return "route_created"
48+
49+
50+
class RouteTerminatingEvent(RouteCreationEvent):
51+
@classmethod
52+
@override
53+
def event_name(cls) -> str:
54+
return "route_terminating"
5355

5456

5557
@dataclass
56-
class RouteCreationEvent(AbstractAnycastEvent):
57-
route_id: uuid.UUID
58+
class EndpointRouteListUpdatedEvent(AbstractAnycastEvent):
59+
endpoint_id: uuid.UUID
5860

5961
def serialize(self) -> tuple:
60-
return (str(self.route_id),)
62+
return (str(self.endpoint_id),)
6163

6264
@classmethod
6365
def deserialize(cls, value: tuple):
6466
return cls(uuid.UUID(value[0]))
6567

68+
@classmethod
69+
@override
70+
def event_name(cls) -> str:
71+
return "endpoint_route_list_updated"
72+
6673
@classmethod
6774
@override
6875
def event_domain(cls) -> EventDomain:
6976
return EventDomain.MODEL_ROUTE
7077

7178
@override
7279
def domain_id(self) -> Optional[str]:
73-
return str(self.route_id)
80+
return str(self.endpoint_id)
7481

7582
@override
7683
def user_event(self) -> Optional[UserEvent]:
7784
return None
78-
79-
80-
class RouteCreatedAnycastEvent(RouteCreationEvent):
81-
@classmethod
82-
@override
83-
def event_name(cls) -> str:
84-
return "route_created"

src/ai/backend/common/events/event_types/model_serving/broadcast.py

Lines changed: 4 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,11 @@
1-
import uuid
2-
from dataclasses import dataclass
3-
from typing import Optional, override
1+
from typing import override
42

5-
from ai.backend.common.events.types import AbstractBroadcastEvent, EventDomain
6-
from ai.backend.common.events.user_event.user_event import UserEvent
7-
from ai.backend.common.types import KernelId, ModelServiceStatus, SessionId
3+
from ai.backend.common.events.types import AbstractBroadcastEvent
84

5+
from . import ModelServiceStatusEventArgs
96

10-
@dataclass
11-
class ModelServiceStatusEventArgs(AbstractBroadcastEvent):
12-
kernel_id: KernelId
13-
session_id: SessionId
14-
model_name: str
15-
new_status: ModelServiceStatus
167

17-
def serialize(self) -> tuple:
18-
return (
19-
str(self.kernel_id),
20-
str(self.session_id),
21-
self.model_name,
22-
self.new_status.value,
23-
)
24-
25-
@classmethod
26-
def deserialize(cls, value: tuple):
27-
return cls(
28-
kernel_id=KernelId(uuid.UUID(value[0])),
29-
session_id=SessionId(uuid.UUID(value[1])),
30-
model_name=value[2],
31-
new_status=ModelServiceStatus(value[3]),
32-
)
33-
34-
@classmethod
35-
@override
36-
def event_domain(cls) -> EventDomain:
37-
return EventDomain.MODEL_SERVING
38-
39-
@override
40-
def domain_id(self) -> Optional[str]:
41-
return None
42-
43-
@override
44-
def user_event(self) -> Optional[UserEvent]:
45-
return None
46-
47-
48-
class ModelServiceStatusBroadcastEvent(ModelServiceStatusEventArgs):
8+
class ModelServiceStatusBroadcastEvent(ModelServiceStatusEventArgs, AbstractBroadcastEvent):
499
@classmethod
5010
@override
5111
def event_name(cls) -> str:

src/ai/backend/manager/api/service.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
AccessKey,
3131
ClusterMode,
3232
RuntimeVariant,
33+
VFolderID,
3334
VFolderMount,
3435
VFolderUsageMode,
3536
)
@@ -74,7 +75,7 @@
7475
from ..errors.api import InvalidAPIParameters
7576
from ..errors.storage import VFolderNotFound
7677
from ..models import (
77-
ModelServicePredicateChecker,
78+
ModelServiceHelper,
7879
UserRole,
7980
UserRow,
8081
query_accessible_vfolders,
@@ -521,7 +522,7 @@ async def _validate(request: web.Request, params: NewServiceRequestModel) -> Val
521522
raise InvalidAPIParameters(f"Cannot spawn more than {_m} sessions for a single service")
522523

523524
async with root_ctx.db.begin_readonly() as conn:
524-
checked_scaling_group = await ModelServicePredicateChecker.check_scaling_group(
525+
checked_scaling_group = await ModelServiceHelper.check_scaling_group(
525526
conn,
526527
params.config.scaling_group,
527528
owner_access_key,
@@ -579,7 +580,7 @@ async def _validate(request: web.Request, params: NewServiceRequestModel) -> Val
579580

580581
model_id = folder_row["id"]
581582

582-
vfolder_mounts = await ModelServicePredicateChecker.check_extra_mounts(
583+
vfolder_mounts = await ModelServiceHelper.check_extra_mounts(
583584
conn,
584585
root_ctx.config_provider.legacy_etcd_config_loader,
585586
root_ctx.storage_manager,
@@ -596,11 +597,19 @@ async def _validate(request: web.Request, params: NewServiceRequestModel) -> Val
596597
)
597598

598599
if params.runtime_variant == RuntimeVariant.CUSTOM:
599-
yaml_path = await ModelServicePredicateChecker.validate_model_definition(
600+
vfid = VFolderID(folder_row["quota_scope_id"], folder_row["id"])
601+
yaml_path = await ModelServiceHelper.validate_model_definition_file_exists(
600602
root_ctx.storage_manager,
601-
folder_row,
603+
folder_row["host"],
604+
vfid,
602605
params.config.model_definition_path,
603606
)
607+
await ModelServiceHelper.validate_model_definition(
608+
root_ctx.storage_manager,
609+
folder_row["host"],
610+
vfid,
611+
yaml_path,
612+
)
604613
else:
605614
if (
606615
params.runtime_variant != RuntimeVariant.CMD

src/ai/backend/manager/clients/wsproxy/client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Any, Mapping
2+
from uuid import UUID
23

34
import aiohttp
45

@@ -10,7 +11,7 @@ def __init__(self, address: str, token: str) -> None:
1011

1112
async def create_endpoint(
1213
self,
13-
endpoint_id: str,
14+
endpoint_id: UUID,
1415
body: Mapping[str, Any],
1516
) -> dict[str, Any]:
1617
async with aiohttp.ClientSession() as session:
@@ -26,7 +27,7 @@ async def create_endpoint(
2627

2728
async def delete_endpoint(
2829
self,
29-
endpoint_id: str,
30+
endpoint_id: UUID,
3031
) -> None:
3132
async with aiohttp.ClientSession() as session:
3233
async with session.delete(

0 commit comments

Comments
 (0)