Skip to content

Commit f5940f8

Browse files
authored
Fix inconsistent key names in PATCH pipeline endpoint (#5089)
1 parent 554bd0c commit f5940f8

File tree

10 files changed

+59
-20
lines changed

10 files changed

+59
-20
lines changed

application/backend/app/api/routers/pipelines.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,21 @@
2020
router = APIRouter(prefix="/api/projects/{project_id}/pipeline", tags=["Pipelines"])
2121

2222
UPDATE_PIPELINE_BODY_DESCRIPTION = """
23-
Partial pipeline configuration update. May contain any subset of fields including 'name', 'source_id',
24-
'sink_id', or 'model_id'. Fields not included in the request will remain unchanged.
23+
Partial pipeline configuration update. May contain any subset of fields including 'device', 'data_collection_policies',
24+
'source_id', 'sink_id', or 'model_id'. Fields not included in the request will remain unchanged.
2525
"""
2626
UPDATE_PIPELINE_BODY_EXAMPLES = {
2727
"switch_model": Example(
2828
summary="Switch active model",
29-
description="Change the active model for the pipeline",
29+
description="Change the active model of the pipeline",
3030
value={
3131
"model_id": "c1feaabc-da2b-442e-9b3e-55c11c2c2ff3",
3232
},
3333
),
3434
"reconfigure": Example(
3535
summary="Reconfigure pipeline",
36-
description="Change the name, source and sink of the pipeline",
36+
description="Change the source and the sink of the pipeline",
3737
value={
38-
"name": "Updated Production Pipeline",
3938
"source_id": "e3cbd8d0-17b8-463e-85a2-4aaed031674e",
4039
"sink_id": "c6787c06-964b-4097-8eca-238b8cf79fc9",
4140
},

application/backend/app/models/pipeline.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import Any
66
from uuid import UUID
77

8-
from pydantic import Field, model_validator
8+
from pydantic import AliasChoices, Field, model_validator
99

1010
from .base import BaseEntity
1111
from .data_collection_policy import DataCollectionPolicy
@@ -41,7 +41,7 @@ class Pipeline(BaseEntity):
4141
model_revision: The model revision to use for processing, None if no model is selected.
4242
source_id: UUID reference to the source entity.
4343
sink_id: UUID reference to the sink entity.
44-
model_revision_id: UUID reference to the model revision entity.
44+
model_id: UUID reference to the model revision entity.
4545
status: Current operational status of the pipeline (IDLE or RUNNING).
4646
data_collection_policies: List of policies governing data collection behavior during pipeline execution.
4747
device: The device used for model inference (e.g., 'cpu', 'xpu', 'cuda', 'xpu-1', etc.).
@@ -56,7 +56,7 @@ class Pipeline(BaseEntity):
5656
model_revision: ModelRevision | None = None
5757
source_id: UUID | None = None
5858
sink_id: UUID | None = None
59-
model_revision_id: UUID | None = None
59+
model_id: UUID | None = Field(default=None, validation_alias=AliasChoices("model_revision_id", "model_id"))
6060
status: PipelineStatus = PipelineStatus.IDLE
6161
data_collection_policies: list[DataCollectionPolicy] = Field(default_factory=list)
6262
device: str = Field(default="cpu", pattern=r"^(cpu|xpu|cuda)(-\d+)?$")
@@ -73,7 +73,7 @@ def set_status_from_is_running(cls, data: Any) -> Any:
7373
@model_validator(mode="after")
7474
def validate_running_status(self) -> "Pipeline":
7575
if self.status == PipelineStatus.RUNNING and any(
76-
x is None for x in (self.source_id, self.sink_id, self.model_revision_id)
76+
x is None for x in (self.source_id, self.sink_id, self.model_id)
7777
):
7878
raise ValueError("Pipeline cannot be in 'running' state when source, sink, or model is not configured.")
7979
return self

application/backend/app/services/event/event_bus.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
class EventType(StrEnum):
1111
SOURCE_CHANGED = "SOURCE_CHANGED"
1212
SINK_CHANGED = "SINK_CHANGED"
13+
MODEL_CHANGED = "MODEL_CHANGED"
1314
PIPELINE_DATASET_COLLECTION_POLICIES_CHANGED = "PIPELINE_DATASET_COLLECTION_POLICIES_CHANGED"
1415
PIPELINE_STATUS_CHANGED = "PIPELINE_STATUS_CHANGED"
1516
INFERENCE_DEVICE_CHANGED = "INFERENCE_DEVICE_CHANGED"
@@ -51,6 +52,9 @@ def _should_notify_source(self, event_type: EventType) -> bool:
5152
def _should_notify_sink(self, event_type: EventType) -> bool:
5253
return event_type in (EventType.SINK_CHANGED, EventType.PIPELINE_STATUS_CHANGED)
5354

55+
def _should_notify_model(self, event_type: EventType) -> bool:
56+
return event_type in (EventType.MODEL_CHANGED, EventType.PIPELINE_STATUS_CHANGED)
57+
5458
def emit_event(self, event_type: EventType) -> None:
5559
super().emit_event(event_type)
5660

@@ -60,5 +64,5 @@ def emit_event(self, event_type: EventType) -> None:
6064
if self._should_notify_sink(event_type):
6165
self._notify_all(self._sink_changed_condition)
6266

63-
if event_type == EventType.PIPELINE_STATUS_CHANGED and self._model_reload_event:
67+
if self._should_notify_model(event_type) and self._model_reload_event:
6468
self._model_reload_event.set()

application/backend/app/services/pipeline_metrics_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def get_pipeline_metrics(self, pipeline_id: UUID, time_window: int = 60) -> Pipe
3636

3737
# Get actual latency measurements from the metrics service
3838
latency_samples = self._metrics_service.get_latency_measurements(
39-
model_id=pipeline.model_revision_id, # type: ignore[arg-type] # model is always there for running pipeline
39+
model_id=pipeline.model_id, # type: ignore[arg-type] # model is always there for running pipeline
4040
time_window=time_window,
4141
)
4242

@@ -55,7 +55,7 @@ def get_pipeline_metrics(self, pipeline_id: UUID, time_window: int = 60) -> Pipe
5555

5656
# Get throughput measurements from the metrics service
5757
total_requests, throughput_data = self._metrics_service.get_throughput_measurements(
58-
model_id=pipeline.model_revision_id, # type: ignore[arg-type]
58+
model_id=pipeline.model_id, # type: ignore[arg-type]
5959
time_window=time_window,
6060
)
6161
if total_requests:

application/backend/app/services/pipeline_service.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,14 @@ def is_running(self, project_id: UUID) -> bool:
5151
def update_pipeline(self, project_id: UUID, partial_config: dict) -> Pipeline:
5252
"""Update an existing pipeline."""
5353
pipeline = self.get_pipeline_by_id(project_id)
54-
to_update = type(pipeline).model_validate(pipeline.model_copy(update=partial_config))
54+
base = pipeline.model_dump()
55+
to_update = type(pipeline).model_validate({**base, **partial_config})
5556
pipeline_repo = PipelineRepository(self._db_session)
5657
to_update_db = PipelineDB(
5758
project_id=str(to_update.project_id),
5859
source_id=str(to_update.source_id) if to_update.source_id else None,
5960
sink_id=str(to_update.sink_id) if to_update.sink_id else None,
60-
model_revision_id=str(to_update.model_revision_id) if to_update.model_revision_id else None,
61+
model_revision_id=str(to_update.model_id) if to_update.model_id else None,
6162
is_running=to_update.status.as_bool,
6263
data_collection_policies=[obj.model_dump() for obj in to_update.data_collection_policies],
6364
device=to_update.device,
@@ -74,6 +75,8 @@ def update_pipeline(self, project_id: UUID, partial_config: dict) -> Pipeline:
7475
self._event_bus.emit_event(EventType.PIPELINE_DATASET_COLLECTION_POLICIES_CHANGED)
7576
if pipeline.device != updated.device:
7677
self._event_bus.emit_event(EventType.INFERENCE_DEVICE_CHANGED)
78+
if pipeline.model_id != updated.model_revision.id: # type: ignore[union-attr] # model_revision is always there for running pipeline
79+
self._event_bus.emit_event(EventType.MODEL_CHANGED)
7780
elif pipeline.status != updated.status:
7881
# If the pipeline is being activated or stopped
7982
self._event_bus.emit_event(EventType.PIPELINE_STATUS_CHANGED)

application/backend/tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def fxt_running_pipeline(fxt_webcam_source, fxt_mqtt_sink, fxt_model) -> Pipelin
7777
project_id=uuid4(),
7878
source_id=fxt_webcam_source.id,
7979
sink_id=fxt_mqtt_sink.id,
80-
model_revision_id=fxt_model.id,
80+
model_id=fxt_model.id,
8181
status=PipelineStatus.RUNNING,
8282
)
8383

application/backend/tests/integration/services/test_dataset_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ def test_create_dataset_item(
493493
data=image,
494494
user_reviewed=user_reviewed,
495495
source_id=pipeline.source_id if use_pipeline_source else None,
496-
prediction_model_id=pipeline.model_revision_id if use_pipeline_model else None,
496+
prediction_model_id=pipeline.model_id if use_pipeline_model else None,
497497
annotations=fxt_annotations(label_id) if not user_reviewed else None,
498498
)
499499

@@ -515,7 +515,7 @@ def test_create_dataset_item(
515515
else:
516516
assert dataset_item.source_id is None
517517
if use_pipeline_model:
518-
assert dataset_item.prediction_model_id == str(pipeline.model_revision_id)
518+
assert dataset_item.prediction_model_id == str(pipeline.model_id)
519519
else:
520520
assert dataset_item.prediction_model_id is None
521521
if not user_reviewed:

application/backend/tests/integration/services/test_pipeline_service.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def test_create_pipeline(self, fxt_pipeline_service, fxt_project_id, fxt_db_proj
7272
and pipeline.source is None
7373
and pipeline.sink_id is None
7474
and pipeline.sink is None
75-
and pipeline.model_revision_id is None
75+
and pipeline.model_id is None
7676
and pipeline.model_revision is None
7777
and pipeline.status == PipelineStatus.IDLE
7878
and pipeline.data_collection_policies == []
@@ -91,7 +91,7 @@ def test_get_pipeline(self, fxt_pipeline_service, fxt_project_id, fxt_project_wi
9191
assert pipeline.status == PipelineStatus.IDLE
9292
assert pipeline.sink.name == db_pipeline.sink.name
9393
assert pipeline.source.name == db_pipeline.source.name
94-
assert str(pipeline.model_revision_id) == db_pipeline.model_revision_id
94+
assert str(pipeline.model_id) == db_pipeline.model_revision_id
9595
assert pipeline.data_collection_policies == [FixedRateDataCollectionPolicy(rate=0.1)]
9696

9797
def test_get_active_pipeline(self, fxt_pipeline_service, fxt_project_with_pipeline, db_session):
@@ -142,6 +142,27 @@ def test_reconfigure_running_pipeline(
142142
assert str(getattr(updated, pipeline_attr)) == item_id
143143
assert str(getattr(updated, pipeline_attr)) == getattr(db_updated, pipeline_attr)
144144

145+
@pytest.mark.parametrize("model_attr", ["model_id", "model_revision_id"])
146+
def test_switch_model(
147+
self,
148+
model_attr,
149+
fxt_project_with_pipeline,
150+
fxt_db_models,
151+
fxt_pipeline_service,
152+
fxt_event_bus,
153+
db_session,
154+
):
155+
"""Test updating a pipeline by ID."""
156+
_, db_pipeline = fxt_project_with_pipeline(is_running=True)
157+
158+
model_id = fxt_db_models[1].id
159+
updated = fxt_pipeline_service.update_pipeline(db_pipeline.project_id, {model_attr: model_id})
160+
161+
fxt_event_bus.emit_event.assert_called_once_with(EventType.MODEL_CHANGED)
162+
db_updated = db_session.get(PipelineDB, db_pipeline.project_id)
163+
assert str(updated.model_id) == model_id
164+
assert str(updated.model_id) == db_updated.model_revision_id
165+
145166
@pytest.mark.parametrize("pipeline_status", [PipelineStatus.IDLE, PipelineStatus.RUNNING])
146167
def test_enable_disable_pipeline(
147168
self,

application/backend/tests/unit/services/event/test_event_bus.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,18 @@ def test_sink_changed(self, fxt_event_bus: EventBusFactory) -> None:
6969
notified = sink_changed_condition.acquire()
7070
assert notified
7171

72+
def test_model_changed(self, fxt_event_bus: EventBusFactory) -> None:
73+
"""Test model changed"""
74+
handler = MagicMock(spec=Callable)
75+
model_reload_event = mp.Event()
76+
event_bus = fxt_event_bus(None, None, model_reload_event)
77+
event_bus.subscribe(event_types=[EventType.MODEL_CHANGED], handler=handler)
78+
79+
event_bus.emit_event(EventType.MODEL_CHANGED)
80+
81+
handler.assert_called_once_with()
82+
assert model_reload_event.is_set()
83+
7284
def test_pipeline_dataset_collection_policies_changed(self, fxt_event_bus: EventBusFactory) -> None:
7385
"""Test pipeline dataset collection policies changed"""
7486
handler = MagicMock(spec=Callable)

application/docs/models.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ The lifecycle of a model revision in Geti Tune consists of several key stages:
6262

6363
When users want to create a new model revision, they can choose any architecture compatible with the task, as well as
6464
the base model to fine-tune from. Such weights can be either from a public pre-trained model or from an existing
65-
model revision in the same project. In the latter case, an parent-child relationship is established between the two
65+
model revision in the same project. In the latter case, a parent-child relationship is established between the two
6666
model revisions; these links form a versioning chain that allows users to track the evolution of models over time.
6767
Revisions trained from scratch, namely from the pre-trained weights, do not have a parent model. A model revision may
6868
be used as a base for multiple new model revisions, allowing users to experiment with different training configurations.

0 commit comments

Comments
 (0)