@@ -72,7 +72,7 @@ def test_create_pipeline(self, fxt_pipeline_service, fxt_project_id, fxt_db_proj
7272 and pipeline .source is None
7373 and pipeline .sink_id is None
7474 and pipeline .sink is None
75- and pipeline .model_revision_id is None
75+ and pipeline .model_id is None
7676 and pipeline .model_revision is None
7777 and pipeline .status == PipelineStatus .IDLE
7878 and pipeline .data_collection_policies == []
@@ -91,7 +91,7 @@ def test_get_pipeline(self, fxt_pipeline_service, fxt_project_id, fxt_project_wi
9191 assert pipeline .status == PipelineStatus .IDLE
9292 assert pipeline .sink .name == db_pipeline .sink .name
9393 assert pipeline .source .name == db_pipeline .source .name
94- assert str (pipeline .model_revision_id ) == db_pipeline .model_revision_id
94+ assert str (pipeline .model_id ) == db_pipeline .model_revision_id
9595 assert pipeline .data_collection_policies == [FixedRateDataCollectionPolicy (rate = 0.1 )]
9696
9797 def test_get_active_pipeline (self , fxt_pipeline_service , fxt_project_with_pipeline , db_session ):
@@ -142,6 +142,27 @@ def test_reconfigure_running_pipeline(
142142 assert str (getattr (updated , pipeline_attr )) == item_id
143143 assert str (getattr (updated , pipeline_attr )) == getattr (db_updated , pipeline_attr )
144144
145+ @pytest .mark .parametrize ("model_attr" , ["model_id" , "model_revision_id" ])
146+ def test_switch_model (
147+ self ,
148+ model_attr ,
149+ fxt_project_with_pipeline ,
150+ fxt_db_models ,
151+ fxt_pipeline_service ,
152+ fxt_event_bus ,
153+ db_session ,
154+ ):
155+ """Test updating a pipeline by ID."""
156+ _ , db_pipeline = fxt_project_with_pipeline (is_running = True )
157+
158+ model_id = fxt_db_models [1 ].id
159+ updated = fxt_pipeline_service .update_pipeline (db_pipeline .project_id , {model_attr : model_id })
160+
161+ fxt_event_bus .emit_event .assert_called_once_with (EventType .MODEL_CHANGED )
162+ db_updated = db_session .get (PipelineDB , db_pipeline .project_id )
163+ assert str (updated .model_id ) == model_id
164+ assert str (updated .model_id ) == db_updated .model_revision_id
165+
145166 @pytest .mark .parametrize ("pipeline_status" , [PipelineStatus .IDLE , PipelineStatus .RUNNING ])
146167 def test_enable_disable_pipeline (
147168 self ,
0 commit comments