|
54 | 54 | ModelServer.TRITON,
|
55 | 55 | ModelServer.DJL_SERVING,
|
56 | 56 | ModelServer.TENSORFLOW_SERVING,
|
| 57 | + ModelServer.MMS, |
| 58 | + ModelServer.TGI, |
| 59 | + ModelServer.TEI, |
57 | 60 | }
|
58 | 61 |
|
59 | 62 | mock_session = MagicMock()
|
@@ -124,6 +127,120 @@ def test_save_model_throw_exception_when_none_of_model_and_inference_spec_is_set
|
124 | 127 | mock_session,
|
125 | 128 | )
|
126 | 129 |
|
| 130 | + @patch("sagemaker.serve.builder.model_builder._ServeSettings") |
| 131 | + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl") |
| 132 | + def test_model_server_override_djl_with_model(self, mock_build_for_djl, mock_serve_settings): |
| 133 | + mock_setting_object = mock_serve_settings.return_value |
| 134 | + mock_setting_object.role_arn = mock_role_arn |
| 135 | + mock_setting_object.s3_model_data_url = mock_s3_model_data_url |
| 136 | + |
| 137 | + builder = ModelBuilder(model_server=ModelServer.DJL_SERVING, model="gpt_llm_burt") |
| 138 | + builder.build(sagemaker_session=mock_session) |
| 139 | + |
| 140 | + mock_build_for_djl.assert_called_once() |
| 141 | + |
| 142 | + @patch("sagemaker.serve.builder.model_builder._ServeSettings") |
| 143 | + def test_model_server_override_djl_without_model_or_mlflow(self, mock_serve_settings): |
| 144 | + builder = ModelBuilder( |
| 145 | + model_server=ModelServer.DJL_SERVING, model=None, inference_spec=None |
| 146 | + ) |
| 147 | + self.assertRaisesRegex( |
| 148 | + Exception, |
| 149 | + "Missing required parameter `model` or 'ml_flow' path", |
| 150 | + builder.build, |
| 151 | + Mode.SAGEMAKER_ENDPOINT, |
| 152 | + mock_role_arn, |
| 153 | + mock_session, |
| 154 | + ) |
| 155 | + |
| 156 | + @patch("sagemaker.serve.builder.model_builder._ServeSettings") |
| 157 | + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_torchserve") |
| 158 | + def test_model_server_override_torchserve_with_model( |
| 159 | + self, mock_build_for_ts, mock_serve_settings |
| 160 | + ): |
| 161 | + mock_setting_object = mock_serve_settings.return_value |
| 162 | + mock_setting_object.role_arn = mock_role_arn |
| 163 | + mock_setting_object.s3_model_data_url = mock_s3_model_data_url |
| 164 | + |
| 165 | + builder = ModelBuilder(model_server=ModelServer.TORCHSERVE, model="gpt_llm_burt") |
| 166 | + builder.build(sagemaker_session=mock_session) |
| 167 | + |
| 168 | + mock_build_for_ts.assert_called_once() |
| 169 | + |
| 170 | + @patch("sagemaker.serve.builder.model_builder._ServeSettings") |
| 171 | + def test_model_server_override_torchserve_without_model_or_mlflow(self, mock_serve_settings): |
| 172 | + builder = ModelBuilder(model_server=ModelServer.TORCHSERVE) |
| 173 | + self.assertRaisesRegex( |
| 174 | + Exception, |
| 175 | + "Missing required parameter `model` or 'ml_flow' path", |
| 176 | + builder.build, |
| 177 | + Mode.SAGEMAKER_ENDPOINT, |
| 178 | + mock_role_arn, |
| 179 | + mock_session, |
| 180 | + ) |
| 181 | + |
| 182 | + @patch("sagemaker.serve.builder.model_builder._ServeSettings") |
| 183 | + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_triton") |
| 184 | + def test_model_server_override_triton_with_model(self, mock_build_for_ts, mock_serve_settings): |
| 185 | + mock_setting_object = mock_serve_settings.return_value |
| 186 | + mock_setting_object.role_arn = mock_role_arn |
| 187 | + mock_setting_object.s3_model_data_url = mock_s3_model_data_url |
| 188 | + |
| 189 | + builder = ModelBuilder(model_server=ModelServer.TRITON, model="gpt_llm_burt") |
| 190 | + builder.build(sagemaker_session=mock_session) |
| 191 | + |
| 192 | + mock_build_for_ts.assert_called_once() |
| 193 | + |
| 194 | + @patch("sagemaker.serve.builder.model_builder._ServeSettings") |
| 195 | + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tensorflow_serving") |
| 196 | + def test_model_server_override_tensor_with_model(self, mock_build_for_ts, mock_serve_settings): |
| 197 | + mock_setting_object = mock_serve_settings.return_value |
| 198 | + mock_setting_object.role_arn = mock_role_arn |
| 199 | + mock_setting_object.s3_model_data_url = mock_s3_model_data_url |
| 200 | + |
| 201 | + builder = ModelBuilder(model_server=ModelServer.TENSORFLOW_SERVING, model="gpt_llm_burt") |
| 202 | + builder.build(sagemaker_session=mock_session) |
| 203 | + |
| 204 | + mock_build_for_ts.assert_called_once() |
| 205 | + |
| 206 | + @patch("sagemaker.serve.builder.model_builder._ServeSettings") |
| 207 | + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tei") |
| 208 | + def test_model_server_override_tei_with_model(self, mock_build_for_ts, mock_serve_settings): |
| 209 | + mock_setting_object = mock_serve_settings.return_value |
| 210 | + mock_setting_object.role_arn = mock_role_arn |
| 211 | + mock_setting_object.s3_model_data_url = mock_s3_model_data_url |
| 212 | + |
| 213 | + builder = ModelBuilder(model_server=ModelServer.TEI, model="gpt_llm_burt") |
| 214 | + builder.build(sagemaker_session=mock_session) |
| 215 | + |
| 216 | + mock_build_for_ts.assert_called_once() |
| 217 | + |
| 218 | + @patch("sagemaker.serve.builder.model_builder._ServeSettings") |
| 219 | + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tgi") |
| 220 | + def test_model_server_override_tgi_with_model(self, mock_build_for_ts, mock_serve_settings): |
| 221 | + mock_setting_object = mock_serve_settings.return_value |
| 222 | + mock_setting_object.role_arn = mock_role_arn |
| 223 | + mock_setting_object.s3_model_data_url = mock_s3_model_data_url |
| 224 | + |
| 225 | + builder = ModelBuilder(model_server=ModelServer.TGI, model="gpt_llm_burt") |
| 226 | + builder.build(sagemaker_session=mock_session) |
| 227 | + |
| 228 | + mock_build_for_ts.assert_called_once() |
| 229 | + |
| 230 | + @patch("sagemaker.serve.builder.model_builder._ServeSettings") |
| 231 | + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers") |
| 232 | + def test_model_server_override_transformers_with_model( |
| 233 | + self, mock_build_for_ts, mock_serve_settings |
| 234 | + ): |
| 235 | + mock_setting_object = mock_serve_settings.return_value |
| 236 | + mock_setting_object.role_arn = mock_role_arn |
| 237 | + mock_setting_object.s3_model_data_url = mock_s3_model_data_url |
| 238 | + |
| 239 | + builder = ModelBuilder(model_server=ModelServer.MMS, model="gpt_llm_burt") |
| 240 | + builder.build(sagemaker_session=mock_session) |
| 241 | + |
| 242 | + mock_build_for_ts.assert_called_once() |
| 243 | + |
127 | 244 | @patch("os.makedirs", Mock())
|
128 | 245 | @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version")
|
129 | 246 | @patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve")
|
|
0 commit comments