Skip to content

Commit ff150ba

Browse files
authored
Implement device fallback to CPU for preconfigured pipelines (#5091)
1 parent f5940f8 commit ff150ba

File tree

9 files changed

+95
-24
lines changed

9 files changed

+95
-24
lines changed

application/backend/app/api/dependencies.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,18 @@ def get_source_update_service(
107107
return SourceUpdateService(event_bus=event_bus, db_session=db)
108108

109109

110+
def get_system_service() -> SystemService:
111+
"""Provides a SystemService instance for system-level operations."""
112+
return SystemService()
113+
114+
110115
def get_pipeline_service(
111116
event_bus: Annotated[EventBus, Depends(get_event_bus)],
112117
db: Annotated[Session, Depends(get_db)],
118+
system_service: Annotated[SystemService, Depends(get_system_service)],
113119
) -> PipelineService:
114120
"""Provides a PipelineService instance ."""
115-
return PipelineService(event_bus=event_bus, db_session=db)
121+
return PipelineService(event_bus=event_bus, db_session=db, system_service=system_service)
116122

117123

118124
def get_pipeline_metrics_service(
@@ -126,11 +132,6 @@ def get_pipeline_metrics_service(
126132
)
127133

128134

129-
def get_system_service() -> SystemService:
130-
"""Provides a SystemService instance for system-level operations."""
131-
return SystemService()
132-
133-
134135
def get_model_service(
135136
data_dir: Annotated[Path, Depends(get_data_dir)],
136137
db: Annotated[Session, Depends(get_db)],

application/backend/app/services/data_collect/data_collector.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,12 @@ def __init__(self, data_dir: Path, event_bus: EventBus) -> None:
104104
)
105105

106106
def _load_pipeline(self) -> None:
107-
from app.services import LabelService, PipelineService, ProjectService
107+
from app.services import LabelService, PipelineService, ProjectService, SystemService
108108

109109
with get_db_session() as db:
110110
label_service = LabelService(db_session=db)
111-
pipeline_service = PipelineService(event_bus=self.event_bus, db_session=db)
111+
system_service = SystemService()
112+
pipeline_service = PipelineService(event_bus=self.event_bus, db_session=db, system_service=system_service)
112113
pipeline = pipeline_service.get_active_pipeline()
113114
if pipeline is None:
114115
logger.info("No active pipeline found, disabling data collection")

application/backend/app/services/pipeline_service.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from uuid import UUID
55

6+
from loguru import logger
67
from sqlalchemy.orm import Session
78

89
from app.db.schema import PipelineDB
@@ -12,13 +13,16 @@
1213
from app.services.event.event_bus import EventBus, EventType
1314
from app.services.parent_process_guard import parent_process_only
1415

16+
from .system_service import DEFAULT_DEVICE, SystemService
17+
1518
MSG_ERR_DELETE_RUNNING_PIPELINE = "Cannot delete a running pipeline."
1619

1720

1821
class PipelineService:
19-
def __init__(self, event_bus: EventBus, db_session: Session) -> None:
22+
def __init__(self, event_bus: EventBus, db_session: Session, system_service: SystemService) -> None:
2023
self._event_bus: EventBus = event_bus
2124
self._db_session: Session = db_session
25+
self._system_service: SystemService = system_service
2226

2327
def create_pipeline(self, project_id: UUID) -> Pipeline:
2428
pipeline_repo = PipelineRepository(self._db_session)
@@ -31,8 +35,19 @@ def create_pipeline(self, project_id: UUID) -> Pipeline:
3135
def get_active_pipeline(self) -> Pipeline | None:
3236
"""Retrieve an active pipeline."""
3337
pipeline_repo = PipelineRepository(self._db_session)
34-
pipeline = pipeline_repo.get_active_pipeline()
35-
return Pipeline.model_validate(pipeline) if pipeline is not None else None
38+
pipeline_db = pipeline_repo.get_active_pipeline()
39+
if pipeline_db is None:
40+
return None
41+
42+
if not self._system_service.validate_device(pipeline_db.device):
43+
logger.warning(
44+
"The configured device '{}' is not available for pipeline '{}'. Falling back to 'cpu'.",
45+
pipeline_db.device,
46+
pipeline_db.project_id,
47+
)
48+
pipeline_db.device = DEFAULT_DEVICE
49+
pipeline_repo.update(pipeline_db)
50+
return Pipeline.model_validate(pipeline_db)
3651

3752
def get_pipeline_by_id(self, project_id: UUID) -> Pipeline:
3853
"""Retrieve a pipeline by project ID."""

application/backend/app/services/system_service.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from app.schemas.system import DeviceInfo, DeviceType
1010

1111
DEVICE_PATTERN = re.compile(r"^(cpu|xpu|cuda)(-(\d+))?$")
12+
DEFAULT_DEVICE = "cpu"
1213

1314

1415
class SystemService:

application/backend/tests/integration/project_factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def with_pipeline(
7373
model_id: str | None = None,
7474
source_id: str | None = None,
7575
sink_id: str | None = None,
76+
device: str = "cpu",
7677
) -> "ProjectTestDataFactory":
7778
"""Add a pipeline to the project."""
7879
if not self._project:
@@ -84,6 +85,7 @@ def with_pipeline(
8485
model_revision_id=model_id,
8586
source_id=source_id,
8687
sink_id=sink_id,
88+
device=device,
8789
)
8890
return self
8991

application/backend/tests/integration/services/test_dataset_revision_service.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,14 @@
99

1010
from app.db.schema import DatasetItemDB, DatasetRevisionDB, PipelineDB
1111
from app.models import DatasetItemAnnotationStatus, DatasetItemSubset, Pipeline, Project
12-
from app.services import DatasetRevisionService, DatasetService, LabelService, PipelineService, ProjectService
12+
from app.services import (
13+
DatasetRevisionService,
14+
DatasetService,
15+
LabelService,
16+
PipelineService,
17+
ProjectService,
18+
SystemService,
19+
)
1320
from app.services.base import ResourceNotFoundError, ResourceType
1421
from app.services.event.event_bus import EventBus
1522

@@ -21,9 +28,17 @@ def fxt_event_bus() -> EventBus:
2128

2229

2330
@pytest.fixture
24-
def fxt_pipeline_service(fxt_event_bus: EventBus, db_session: Session) -> PipelineService:
31+
def fxt_system_service() -> SystemService:
32+
"""Fixture to create a SystemService instance."""
33+
return SystemService()
34+
35+
36+
@pytest.fixture
37+
def fxt_pipeline_service(
38+
fxt_event_bus: EventBus, db_session: Session, fxt_system_service: SystemService
39+
) -> PipelineService:
2540
"""Fixture to create a PipelineService instance."""
26-
return PipelineService(event_bus=fxt_event_bus, db_session=db_session)
41+
return PipelineService(event_bus=fxt_event_bus, db_session=db_session, system_service=fxt_system_service)
2742

2843

2944
@pytest.fixture

application/backend/tests/integration/services/test_dataset_service.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
Project,
2323
Rectangle,
2424
)
25-
from app.services import LabelService, PipelineService, ProjectService
25+
from app.services import LabelService, PipelineService, ProjectService, SystemService
2626
from app.services.base import ResourceNotFoundError, ResourceType
2727
from app.services.dataset_service import (
2828
DatasetItemFilters,
@@ -40,9 +40,17 @@ def fxt_event_bus() -> EventBus:
4040

4141

4242
@pytest.fixture
43-
def fxt_pipeline_service(fxt_event_bus: EventBus, db_session: Session) -> PipelineService:
43+
def fxt_system_service() -> SystemService:
44+
"""Fixture to create a SystemService instance."""
45+
return SystemService()
46+
47+
48+
@pytest.fixture
49+
def fxt_pipeline_service(
50+
fxt_event_bus: EventBus, db_session: Session, fxt_system_service: SystemService
51+
) -> PipelineService:
4452
"""Fixture to create a PipelineService instance."""
45-
return PipelineService(event_bus=fxt_event_bus, db_session=db_session)
53+
return PipelineService(event_bus=fxt_event_bus, db_session=db_session, system_service=fxt_system_service)
4654

4755

4856
@pytest.fixture

application/backend/tests/integration/services/test_pipeline_service.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from app.db.schema import PipelineDB, ProjectDB
1111
from app.models import PipelineStatus
1212
from app.models.data_collection_policy import FixedRateDataCollectionPolicy
13-
from app.services import PipelineService, ResourceNotFoundError, ResourceType
13+
from app.services import PipelineService, ResourceNotFoundError, ResourceType, SystemService
1414
from app.services.event.event_bus import EventType
1515
from tests.integration.project_factory import ProjectTestDataFactory
1616

@@ -21,9 +21,15 @@ class PipelineField(StrEnum):
2121

2222

2323
@pytest.fixture
24-
def fxt_pipeline_service(fxt_event_bus, db_session) -> PipelineService:
24+
def fxt_system_service() -> SystemService:
25+
"""Fixture to create a SystemService instance."""
26+
return SystemService()
27+
28+
29+
@pytest.fixture
30+
def fxt_pipeline_service(fxt_event_bus, db_session, fxt_system_service) -> PipelineService:
2531
"""Fixture to create a PipelineService instance with mocked dependencies."""
26-
return PipelineService(fxt_event_bus, db_session)
32+
return PipelineService(fxt_event_bus, db_session, fxt_system_service)
2733

2834

2935
@pytest.fixture
@@ -33,7 +39,7 @@ def fxt_project_with_pipeline(
3339
"""Fixture to create a ProjectDB with an associated PipelineDB."""
3440

3541
def _create_project_with_pipeline(
36-
is_running: bool, data_policies: list[dict] | None = None
42+
is_running: bool, data_policies: list[dict] | None = None, device: str = "cpu"
3743
) -> tuple[ProjectDB, PipelineDB]:
3844
db_session.add_all(fxt_db_sources)
3945
db_session.add_all(fxt_db_sinks)
@@ -46,6 +52,7 @@ def _create_project_with_pipeline(
4652
model_id=fxt_db_models[0].id,
4753
source_id=fxt_db_sources[0].id,
4854
sink_id=fxt_db_sinks[0].id,
55+
device=device,
4956
)
5057
.with_models(fxt_db_models)
5158
.with_data_policies(data_policies if data_policies else [])
@@ -104,6 +111,19 @@ def test_get_active_pipeline(self, fxt_pipeline_service, fxt_project_with_pipeli
104111
assert active_pipeline is not None
105112
assert active_pipeline.project_id == project_id
106113

114+
def test_get_active_pipeline_device_change(self, fxt_pipeline_service, fxt_project_with_pipeline, db_session):
115+
"""Test retrieving a pipeline when its original device is no longer available."""
116+
db_project, db_pipeline = fxt_project_with_pipeline(is_running=True, data_policies=[], device="xpu-99")
117+
118+
assert db_pipeline.device == "xpu-99"
119+
120+
project_id = UUID(db_project.id)
121+
active_pipeline = fxt_pipeline_service.get_active_pipeline()
122+
123+
assert active_pipeline is not None
124+
assert active_pipeline.project_id == project_id
125+
assert active_pipeline.device == "cpu"
126+
107127
def test_get_non_existent_pipeline(self, fxt_pipeline_service):
108128
"""Test retrieving a non-existent pipeline raises error."""
109129
pipeline_id = uuid4()

application/backend/tests/integration/services/test_project_service.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from app.db.schema import DatasetItemDB, LabelDB, PipelineDB, ProjectDB
1010
from app.models import Label, Task, TaskType
11-
from app.services import LabelService, PipelineService, ResourceWithIdAlreadyExistsError
11+
from app.services import LabelService, PipelineService, ResourceWithIdAlreadyExistsError, SystemService
1212
from app.services.base import ResourceInUseError, ResourceNotFoundError, ResourceType
1313
from app.services.event.event_bus import EventBus
1414
from app.services.label_service import DuplicateLabelsError
@@ -22,9 +22,17 @@ def fxt_event_bus() -> EventBus:
2222

2323

2424
@pytest.fixture
25-
def fxt_pipeline_service(fxt_event_bus: EventBus, db_session: Session) -> PipelineService:
25+
def fxt_system_service() -> SystemService:
26+
"""Fixture to create a SystemService instance."""
27+
return SystemService()
28+
29+
30+
@pytest.fixture
31+
def fxt_pipeline_service(
32+
fxt_event_bus: EventBus, db_session: Session, fxt_system_service: SystemService
33+
) -> PipelineService:
2634
"""Fixture to create a PipelineService instance."""
27-
return PipelineService(event_bus=fxt_event_bus, db_session=db_session)
35+
return PipelineService(event_bus=fxt_event_bus, db_session=db_session, system_service=fxt_system_service)
2836

2937

3038
@pytest.fixture

0 commit comments

Comments
 (0)