Skip to content

Commit 0627e97

Browse files
authored
feat(BA-1846): Separate repository layer from domain service (#5099)
1 parent d809d0f commit 0627e97

File tree

13 files changed

+837
-320
lines changed

13 files changed

+837
-320
lines changed

changes/5099.enhance.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Separate repository layer from domain service

src/ai/backend/manager/errors/exceptions.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1266,3 +1266,21 @@ def error_code(cls) -> ErrorCode:
12661266
operation=ErrorOperation.READ,
12671267
error_detail=ErrorDetail.NOT_FOUND,
12681268
)
1269+
1270+
1271+
class DomainDataProcessingError(BackendError, web.HTTPInternalServerError):
1272+
"""
1273+
Error that occurs when processing domain data fails.
1274+
This includes failures in converting database rows to domain data objects.
1275+
"""
1276+
1277+
error_type = "https://api.backend.ai/probs/domain-data-processing-error"
1278+
error_title = "Failed to process domain data."
1279+
1280+
@classmethod
1281+
def error_code(cls) -> ErrorCode:
1282+
return ErrorCode(
1283+
domain=ErrorDomain.DOMAIN,
1284+
operation=ErrorOperation.GENERIC,
1285+
error_detail=ErrorDetail.INTERNAL_ERROR,
1286+
)

src/ai/backend/manager/models/gql_models/domain.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ class DomainInput(graphene.InputObjectType):
626626
)
627627
integration_id = graphene.String(required=False, default_value=None)
628628

629-
def to_action(self, domain_name: str) -> CreateDomainAction:
629+
def to_action(self, domain_name: str, user_info: UserInfo) -> CreateDomainAction:
630630
def value_or_none(value):
631631
return value if value is not Undefined else None
632632

@@ -640,6 +640,7 @@ def value_or_none(value):
640640
allowed_docker_registries=value_or_none(self.allowed_docker_registries),
641641
integration_id=value_or_none(self.integration_id),
642642
),
643+
user_info=user_info,
643644
)
644645

645646

@@ -661,9 +662,10 @@ def _convert_field(
661662
return converter(field_value)
662663
return field_value
663664

664-
def to_action(self, domain_name: str) -> ModifyDomainAction:
665+
def to_action(self, domain_name: str, user_info: UserInfo) -> ModifyDomainAction:
665666
return ModifyDomainAction(
666667
domain_name=domain_name,
668+
user_info=user_info,
667669
modifier=DomainModifier(
668670
name=OptionalState[str].from_graphql(self.name),
669671
description=TriState[str].from_graphql(
@@ -709,7 +711,13 @@ async def mutate(
709711
) -> CreateDomain:
710712
ctx: GraphQueryContext = info.context
711713

712-
action: CreateDomainAction = props.to_action(name)
714+
user_info: UserInfo = UserInfo(
715+
id=ctx.user["uuid"],
716+
role=ctx.user["role"],
717+
domain_name=ctx.user["domain_name"],
718+
)
719+
720+
action: CreateDomainAction = props.to_action(name, user_info)
713721
res = await ctx.processors.domain.create_domain.wait_for_complete(action)
714722

715723
domain_data: Optional[DomainData] = res.domain_data
@@ -742,7 +750,13 @@ async def mutate(
742750
) -> ModifyDomain:
743751
ctx: GraphQueryContext = info.context
744752

745-
action = props.to_action(name)
753+
user_info: UserInfo = UserInfo(
754+
id=ctx.user["uuid"],
755+
role=ctx.user["role"],
756+
domain_name=ctx.user["domain_name"],
757+
)
758+
759+
action = props.to_action(name, user_info)
746760
res = await ctx.processors.domain.modify_domain.wait_for_complete(action)
747761

748762
domain_data: Optional[DomainData] = res.domain_data
@@ -771,7 +785,13 @@ class Arguments:
771785
async def mutate(cls, root, info: graphene.ResolveInfo, name: str) -> DeleteDomain:
772786
ctx: GraphQueryContext = info.context
773787

774-
action = DeleteDomainAction(name)
788+
user_info: UserInfo = UserInfo(
789+
id=ctx.user["uuid"],
790+
role=ctx.user["role"],
791+
domain_name=ctx.user["domain_name"],
792+
)
793+
794+
action = DeleteDomainAction(name, user_info)
775795
res = await ctx.processors.domain.delete_domain.wait_for_complete(action)
776796

777797
return cls(ok=res.success, msg=res.description)
@@ -797,7 +817,13 @@ class Arguments:
797817
async def mutate(cls, root, info: graphene.ResolveInfo, name: str) -> PurgeDomain:
798818
ctx: GraphQueryContext = info.context
799819

800-
action = PurgeDomainAction(name)
820+
user_info: UserInfo = UserInfo(
821+
id=ctx.user["uuid"],
822+
role=ctx.user["role"],
823+
domain_name=ctx.user["domain_name"],
824+
)
825+
826+
action = PurgeDomainAction(name, user_info)
801827
res = await ctx.processors.domain.purge_domain.wait_for_complete(action)
802828

803829
return cls(ok=res.success, msg=res.description)
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
from typing import Optional
2+
3+
import sqlalchemy as sa
4+
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
5+
from sqlalchemy.ext.asyncio import AsyncSession as SASession
6+
7+
from ai.backend.manager.models import groups
8+
from ai.backend.manager.models.domain import DomainRow, domains
9+
from ai.backend.manager.models.group import ProjectType
10+
from ai.backend.manager.models.kernel import kernels
11+
from ai.backend.manager.models.scaling_group import ScalingGroupForDomainRow
12+
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine, execute_with_txn_retry
13+
from ai.backend.manager.services.domain.types import (
14+
DomainCreator,
15+
DomainData,
16+
DomainModifier,
17+
UserInfo,
18+
)
19+
20+
21+
class AdminDomainRepository:
22+
"""
23+
Repository for admin-specific domain operations that bypass ownership checks.
24+
This should only be used by superadmin users.
25+
"""
26+
27+
_db: ExtendedAsyncSAEngine
28+
29+
def __init__(self, db: ExtendedAsyncSAEngine) -> None:
30+
self._db = db
31+
32+
async def create_domain_force(self, creator: DomainCreator) -> DomainData:
33+
"""
34+
Creates a new domain with model-store group without permission checks.
35+
For superadmin use only.
36+
"""
37+
async with self._db.begin() as conn:
38+
data = creator.fields_to_store()
39+
insert_query = sa.insert(domains).values(data).returning(domains)
40+
result = await conn.execute(insert_query)
41+
row = result.first()
42+
43+
# Create model-store group for the domain
44+
await self._create_model_store_group(conn, creator.name)
45+
46+
if result.rowcount != 1 or row is None:
47+
raise RuntimeError(f"No domain created. rowcount: {result.rowcount}, data: {data}")
48+
49+
assert row is not None
50+
result = DomainData.from_row(row)
51+
assert result is not None
52+
return result
53+
54+
async def modify_domain_force(
55+
self, domain_name: str, modifier: DomainModifier
56+
) -> Optional[DomainData]:
57+
"""
58+
Modifies an existing domain without permission checks.
59+
For superadmin use only.
60+
"""
61+
async with self._db.begin() as conn:
62+
data = modifier.fields_to_update()
63+
update_query = (
64+
sa.update(domains)
65+
.values(data)
66+
.where(domains.c.name == domain_name)
67+
.returning(domains)
68+
)
69+
result = await conn.execute(update_query)
70+
row = result.first()
71+
72+
if result.rowcount == 0:
73+
return None
74+
75+
return DomainData.from_row(row)
76+
77+
async def soft_delete_domain_force(self, domain_name: str) -> bool:
78+
"""
79+
Soft deletes a domain by setting is_active to False without permission checks.
80+
For superadmin use only.
81+
"""
82+
async with self._db.begin() as conn:
83+
update_query = (
84+
sa.update(domains).values({"is_active": False}).where(domains.c.name == domain_name)
85+
)
86+
result = await conn.execute(update_query)
87+
return result.rowcount > 0
88+
89+
async def purge_domain_force(self, domain_name: str) -> bool:
90+
"""
91+
Permanently deletes a domain without validation checks.
92+
For superadmin use only - bypasses all safety checks.
93+
"""
94+
async with self._db.begin() as conn:
95+
# Force clean up kernels
96+
await self._delete_kernels(conn, domain_name)
97+
98+
# Delete domain
99+
delete_query = sa.delete(domains).where(domains.c.name == domain_name)
100+
result = await conn.execute(delete_query)
101+
return result.rowcount > 0
102+
103+
async def create_domain_node_force(
104+
self, creator: DomainCreator, scaling_groups: Optional[list[str]] = None
105+
) -> DomainData:
106+
"""
107+
Creates a domain node with scaling groups without permission checks.
108+
For superadmin use only.
109+
"""
110+
async with self._db.begin_session() as session:
111+
data = creator.fields_to_store()
112+
insert_and_returning = sa.select(DomainRow).from_statement(
113+
sa.insert(DomainRow).values(data).returning(DomainRow)
114+
)
115+
domain_row = await session.scalar(insert_and_returning)
116+
117+
if scaling_groups is not None:
118+
await session.execute(
119+
sa.insert(ScalingGroupForDomainRow),
120+
[
121+
{"scaling_group": sgroup_name, "domain": creator.name}
122+
for sgroup_name in scaling_groups
123+
],
124+
)
125+
126+
await session.commit()
127+
if domain_row is None:
128+
raise RuntimeError(f"Failed to create domain node: {creator.name}")
129+
assert domain_row is not None
130+
result = DomainData.from_row(domain_row)
131+
assert result is not None
132+
return result
133+
134+
async def modify_domain_node_force(
135+
self,
136+
domain_name: str,
137+
modifier_fields: dict,
138+
sgroups_to_add: Optional[set[str]] = None,
139+
sgroups_to_remove: Optional[set[str]] = None,
140+
) -> Optional[DomainData]:
141+
"""
142+
Modifies a domain node with scaling group changes without permission checks.
143+
For superadmin use only.
144+
"""
145+
async with self._db.begin_session() as session:
146+
if sgroups_to_add is not None:
147+
await session.execute(
148+
sa.insert(ScalingGroupForDomainRow),
149+
[
150+
{"scaling_group": sgroup_name, "domain": domain_name}
151+
for sgroup_name in sgroups_to_add
152+
],
153+
)
154+
155+
if sgroups_to_remove is not None:
156+
await session.execute(
157+
sa.delete(ScalingGroupForDomainRow).where(
158+
(ScalingGroupForDomainRow.domain == domain_name)
159+
& (ScalingGroupForDomainRow.scaling_group.in_(sgroups_to_remove))
160+
),
161+
)
162+
163+
update_stmt = (
164+
sa.update(DomainRow)
165+
.where(DomainRow.name == domain_name)
166+
.values(modifier_fields)
167+
.returning(DomainRow)
168+
)
169+
await session.execute(update_stmt)
170+
171+
domain_row = await session.scalar(
172+
sa.select(DomainRow).where(DomainRow.name == domain_name)
173+
)
174+
175+
await session.commit()
176+
return DomainData.from_row(domain_row) if domain_row else None
177+
178+
async def _create_model_store_group(self, conn: SAConnection, domain_name: str) -> None:
179+
"""
180+
Private method to create model-store group for a domain.
181+
"""
182+
model_store_insert_query = sa.insert(groups).values({
183+
"name": "model-store",
184+
"description": "Model Store",
185+
"is_active": True,
186+
"domain_name": domain_name,
187+
"total_resource_slots": {},
188+
"allowed_vfolder_hosts": {},
189+
"integration_id": None,
190+
"resource_policy": "default",
191+
"type": ProjectType.MODEL_STORE,
192+
})
193+
await conn.execute(model_store_insert_query)
194+
195+
async def _delete_kernels(self, conn: SAConnection, domain_name: str) -> int:
196+
"""
197+
Private method to delete all kernels for a domain.
198+
"""
199+
delete_query = sa.delete(kernels).where(kernels.c.domain_name == domain_name)
200+
result = await conn.execute(delete_query)
201+
return result.rowcount
202+
203+
async def create_domain_node_with_permissions_force(
204+
self,
205+
creator: DomainCreator,
206+
user_info: UserInfo,
207+
scaling_groups: Optional[list[str]] = None,
208+
) -> DomainData:
209+
"""
210+
Creates a domain node with scaling groups without permission checks.
211+
For superadmin use only.
212+
"""
213+
214+
async def _insert(db_session: SASession) -> DomainData:
215+
return await self.create_domain_node_force(creator, scaling_groups)
216+
217+
async with self._db.connect() as db_conn:
218+
return await execute_with_txn_retry(_insert, self._db.begin_session, db_conn)
219+
220+
async def modify_domain_node_with_permissions_force(
221+
self,
222+
domain_name: str,
223+
modifier_fields: dict,
224+
user_info: UserInfo,
225+
sgroups_to_add: Optional[set[str]] = None,
226+
sgroups_to_remove: Optional[set[str]] = None,
227+
) -> Optional[DomainData]:
228+
"""
229+
Modifies a domain node with scaling group changes without permission checks.
230+
For superadmin use only.
231+
"""
232+
233+
async def _update(db_session: SASession) -> Optional[DomainData]:
234+
return await self.modify_domain_node_force(
235+
domain_name,
236+
modifier_fields,
237+
sgroups_to_add,
238+
sgroups_to_remove,
239+
)
240+
241+
async with self._db.connect() as db_conn:
242+
return await execute_with_txn_retry(_update, self._db.begin_session, db_conn)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
from dataclasses import dataclass
22
from typing import Self
33

4+
from ai.backend.manager.repositories.domain.admin_repository import AdminDomainRepository
45
from ai.backend.manager.repositories.domain.repository import DomainRepository
56
from ai.backend.manager.repositories.image.repositories import RepositoryArgs
67

78

89
@dataclass
910
class DomainRepositories:
1011
repository: DomainRepository
12+
admin_repository: AdminDomainRepository
1113

1214
@classmethod
1315
def create(cls, args: RepositoryArgs) -> Self:
1416
repository = DomainRepository(args.db)
17+
admin_repository = AdminDomainRepository(args.db)
1518

1619
return cls(
1720
repository=repository,
21+
admin_repository=admin_repository,
1822
)

0 commit comments

Comments
 (0)