Skip to content

feat(BA-1843): Separate repository layer from container registry service #5095

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/5095.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Separate repository layer from container registry service
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Separate repository layer from container registry service
Separate admin repository layer from container registry service

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the current message is appropriate because the important thing is not separating the admin, but separating the repository layer from the service.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't the container repository already separated, and isn't this PR about separating the admin container repository layer?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't consider it separated because, although there was a repository class, the actual implementation wasn't separated.

Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import Optional

from ai.backend.manager.data.container_registry.types import ContainerRegistryData
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine

from .repository import ContainerRegistryRepository


class AdminContainerRegistryRepository:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be good to add a permission check hook that runs before the method is executed.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This action will all be removed when RBAC is applied, so for now, it is being left as is.

_repository: ContainerRegistryRepository

def __init__(self, db: ExtendedAsyncSAEngine) -> None:
self._repository = ContainerRegistryRepository(db)

async def clear_images_force(
self,
registry_name: str,
project: Optional[str] = None,
) -> ContainerRegistryData:
"""
Forcefully clear images from a container registry without any validation.
This is an admin-only operation that should be used with caution.
"""
return await self._repository.clear_images(registry_name, project)

async def get_by_registry_and_project_force(
self,
registry_name: str,
project: Optional[str] = None,
) -> ContainerRegistryData:
"""
Forcefully get container registry by name and project without any validation.
This is an admin-only operation that should be used with caution.
"""
return await self._repository.get_by_registry_and_project(registry_name, project)

async def get_by_registry_name_force(self, registry_name: str) -> list[ContainerRegistryData]:
"""
Forcefully get container registries by name without any validation.
This is an admin-only operation that should be used with caution.
"""
return await self._repository.get_by_registry_name(registry_name)

async def get_all_force(self) -> list[ContainerRegistryData]:
"""
Forcefully get all container registries without any validation.
This is an admin-only operation that should be used with caution.
"""
return await self._repository.get_all()
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from dataclasses import dataclass
from typing import Self

from ai.backend.manager.repositories.container_registry.admin_repository import (
AdminContainerRegistryRepository,
)
from ai.backend.manager.repositories.container_registry.repository import (
ContainerRegistryRepository,
)
Expand All @@ -10,11 +13,14 @@
@dataclass
class ContainerRegistryRepositories:
repository: ContainerRegistryRepository
admin_repository: AdminContainerRegistryRepository

@classmethod
def create(cls, args: RepositoryArgs) -> Self:
repository = ContainerRegistryRepository(args.db)
admin_repository = AdminContainerRegistryRepository(args.db)

return cls(
repository=repository,
admin_repository=admin_repository,
)
127 changes: 126 additions & 1 deletion src/ai/backend/manager/repositories/container_registry/repository.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,133 @@
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine
from typing import Optional

import sqlalchemy as sa

from ai.backend.manager.data.container_registry.types import ContainerRegistryData
from ai.backend.manager.data.image.types import ImageStatus
from ai.backend.manager.errors.exceptions import ContainerRegistryNotFound
from ai.backend.manager.models.container_registry import ContainerRegistryRow
from ai.backend.manager.models.image import ImageRow
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine, SASession


class ContainerRegistryRepository:
_db: ExtendedAsyncSAEngine

def __init__(self, db: ExtendedAsyncSAEngine) -> None:
self._db = db

async def get_by_registry_and_project(
self,
registry_name: str,
project: Optional[str] = None,
) -> ContainerRegistryData:
async with self._db.begin_readonly_session() as session:
result = await self._get_by_registry_and_project(session, registry_name, project)
if not result:
raise ContainerRegistryNotFound()
return result

async def get_by_registry_name(self, registry_name: str) -> list[ContainerRegistryData]:
async with self._db.begin_readonly_session() as session:
return await self._get_by_registry_name(session, registry_name)

async def get_all(self) -> list[ContainerRegistryData]:
async with self._db.begin_readonly_session() as session:
return await self._get_all(session)

async def clear_images(
self,
registry_name: str,
project: Optional[str] = None,
) -> ContainerRegistryData:
async with self._db.begin_session() as session:
# Clear images
update_stmt = (
sa.update(ImageRow)
.where(ImageRow.registry == registry_name)
.where(ImageRow.status != ImageStatus.DELETED)
.values(status=ImageStatus.DELETED)
)
if project:
update_stmt = update_stmt.where(ImageRow.project == project)

await session.execute(update_stmt)

# Return registry data
result = await self._get_by_registry_and_project(session, registry_name, project)
if not result:
raise ContainerRegistryNotFound()
return result

async def get_known_registries(self) -> dict[str, str]:
async with self._db.begin_readonly_session() as session:
from ai.backend.manager.models.container_registry import ContainerRegistryRow

known_registries_map = await ContainerRegistryRow.get_known_container_registries(
session
)

known_registries = {}
for project, registries in known_registries_map.items():
for registry_name, url in registries.items():
if project:
key = f"{project}/{registry_name}"
else:
key = registry_name
known_registries[key] = url.human_repr()

return known_registries

async def get_registry_row_for_scanner(
self,
registry_name: str,
project: Optional[str] = None,
) -> ContainerRegistryRow:
"""
Get the raw ContainerRegistryRow object needed for container registry scanner.
Raises ContainerRegistryNotFound if registry is not found.
"""
async with self._db.begin_readonly_session() as session:
stmt = sa.select(ContainerRegistryRow).where(
ContainerRegistryRow.registry_name == registry_name,
)
if project:
stmt = stmt.where(ContainerRegistryRow.project == project)

row = await session.scalar(stmt)
if not row:
raise ContainerRegistryNotFound()
return row

async def _get_by_registry_and_project(
self,
session: SASession,
registry_name: str,
project: Optional[str] = None,
) -> Optional[ContainerRegistryData]:
stmt = sa.select(ContainerRegistryRow).where(
ContainerRegistryRow.registry_name == registry_name,
)
if project:
stmt = stmt.where(ContainerRegistryRow.project == project)

row = await session.scalar(stmt)
return row.to_dataclass() if row else None

async def _get_by_registry_name(
self,
session: SASession,
registry_name: str,
) -> list[ContainerRegistryData]:
stmt = sa.select(ContainerRegistryRow).where(
ContainerRegistryRow.registry_name == registry_name
)
result = await session.execute(stmt)
rows = result.scalars().all()
return [row.to_dataclass() for row in rows]

async def _get_all(self, session: SASession) -> list[ContainerRegistryData]:
stmt = sa.select(ContainerRegistryRow)
result = await session.execute(stmt)
rows = result.scalars().all()
return [row.to_dataclass() for row in rows]
112 changes: 42 additions & 70 deletions src/ai/backend/manager/services/container_registry/service.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import sqlalchemy as sa

from ai.backend.manager.container_registry import get_container_registry_cls
from ai.backend.manager.models.container_registry import ContainerRegistryRow
from ai.backend.manager.models.image import ImageRow
from ai.backend.manager.errors.exceptions import ContainerRegistryNotFound
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine
from ai.backend.manager.repositories.container_registry.admin_repository import (
AdminContainerRegistryRepository,
)
from ai.backend.manager.repositories.container_registry.repository import (
ContainerRegistryRepository,
)
Expand All @@ -28,105 +28,77 @@
RescanImagesActionResult,
)

from ...data.image.types import ImageStatus


class ContainerRegistryService:
_db: ExtendedAsyncSAEngine
_container_registry_repository: ContainerRegistryRepository
_admin_container_registry_repository: AdminContainerRegistryRepository

def __init__(
self,
db: ExtendedAsyncSAEngine,
container_registry_repository: ContainerRegistryRepository,
admin_container_registry_repository: AdminContainerRegistryRepository,
) -> None:
self._db = db
self._container_registry_repository = container_registry_repository
self._admin_container_registry_repository = admin_container_registry_repository

async def rescan_images(self, action: RescanImagesAction) -> RescanImagesActionResult:
registry_name = action.registry
project = action.project

async with self._db.begin_readonly_session() as db_session:
stmt = sa.select(ContainerRegistryRow).where(
ContainerRegistryRow.registry_name == registry_name,
)
if project:
stmt = stmt.where(ContainerRegistryRow.project == project)
# TODO: Raise exception if registry not found or two or more registries found
registry_row: ContainerRegistryRow = await db_session.scalar(stmt)
registry_data = await self._container_registry_repository.get_by_registry_and_project(
registry_name, project
)

scanner_cls = get_container_registry_cls(registry_row)
scanner = scanner_cls(self._db, registry_name, registry_row)
result = await scanner.rescan_single_registry(action.progress_reporter)
registry_row = await self._container_registry_repository.get_registry_row_for_scanner(
registry_name, project
)

scanner_cls = get_container_registry_cls(registry_row)
scanner = scanner_cls(self._db, registry_name, registry_row)
result = await scanner.rescan_single_registry(action.progress_reporter)

return RescanImagesActionResult(
images=result.images, errors=result.errors, registry=registry_row.to_dataclass()
images=result.images, errors=result.errors, registry=registry_data
)

async def clear_images(self, action: ClearImagesAction) -> ClearImagesActionResult:
async with self._db.begin_session() as session:
update_stmt = (
sa.update(ImageRow)
.where(ImageRow.registry == action.registry)
.where(ImageRow.status != ImageStatus.DELETED)
.values(status=ImageStatus.DELETED)
)
if action.project:
update_stmt = update_stmt.where(ImageRow.project == action.project)

await session.execute(update_stmt)

get_registry_row_stmt = sa.select(ContainerRegistryRow).where(
ContainerRegistryRow.registry_name == action.registry,
)
if action.project:
get_registry_row_stmt = get_registry_row_stmt.where(
ContainerRegistryRow.project == action.project
)

registry_row: ContainerRegistryRow = await session.scalar(get_registry_row_stmt)
registry_data = await self._admin_container_registry_repository.clear_images_force(
action.registry, action.project
)

return ClearImagesActionResult(registry=registry_row.to_dataclass())
return ClearImagesActionResult(registry=registry_data)

async def load_container_registries(
self, action: LoadContainerRegistriesAction
) -> LoadContainerRegistriesActionResult:
project = action.project

async with self._db.begin_readonly_session() as db_session:
query = sa.select(ContainerRegistryRow).where(
ContainerRegistryRow.registry_name == action.registry
if action.project is not None:
try:
registry_data = (
await self._container_registry_repository.get_by_registry_and_project(
action.registry, action.project
)
)
registries = [registry_data]
except ContainerRegistryNotFound:
registries = []
else:
registries = await self._container_registry_repository.get_by_registry_name(
action.registry
)
if project is not None:
query = query.where(ContainerRegistryRow.project == project)
result = await db_session.execute(query)
registries = result.scalars().all()

return LoadContainerRegistriesActionResult(
registries=[registry.to_dataclass() for registry in registries]
)
return LoadContainerRegistriesActionResult(registries=registries)

async def load_all_container_registries(
self, action: LoadAllContainerRegistriesAction
self, _action: LoadAllContainerRegistriesAction
) -> LoadAllContainerRegistriesActionResult:
async with self._db.begin_readonly_session() as db_session:
query = sa.select(ContainerRegistryRow)
result = await db_session.execute(query)
registries: list[ContainerRegistryRow] = result.scalars().all()
return LoadAllContainerRegistriesActionResult(
registries=[registry.to_dataclass() for registry in registries]
)
registries = await self._container_registry_repository.get_all()
return LoadAllContainerRegistriesActionResult(registries=registries)

async def get_container_registries(
self, action: GetContainerRegistriesAction
self, _action: GetContainerRegistriesAction
) -> GetContainerRegistriesActionResult:
async with self._db.begin_session() as session:
_registries = await ContainerRegistryRow.get_known_container_registries(session)

known_registries = {}
for project, registries in _registries.items():
for registry_name, url in registries.items():
if project not in known_registries:
known_registries[f"{project}/{registry_name}"] = url.human_repr()
return GetContainerRegistriesActionResult(registries=known_registries)
registries = await self._container_registry_repository.get_known_registries()
return GetContainerRegistriesActionResult(registries=registries)
4 changes: 3 additions & 1 deletion src/ai/backend/manager/services/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ def create(cls, args: ServiceArgs) -> Self:
args.agent_registry, repositories.image.repository, repositories.image.admin_repository
)
container_registry_service = ContainerRegistryService(
args.db, repositories.container_registry.repository
args.db,
repositories.container_registry.repository,
repositories.container_registry.admin_repository,
)
vfolder_service = VFolderService(
args.db,
Expand Down
Loading
Loading