Skip to content

Commit 281653c

Browse files
committed
fix: Model server override logic
1 parent 382fde1 commit 281653c

File tree

2 files changed

+168
-27
lines changed

2 files changed

+168
-27
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 51 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,15 @@
8888

8989
logger = logging.getLogger(__name__)
9090

91+
# Any new server type should be added here
9192
supported_model_server = {
9293
ModelServer.TORCHSERVE,
9394
ModelServer.TRITON,
9495
ModelServer.DJL_SERVING,
9596
ModelServer.TENSORFLOW_SERVING,
97+
ModelServer.MMS,
98+
ModelServer.TGI,
99+
ModelServer.TEI,
96100
}
97101

98102

@@ -281,31 +285,6 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
281285
},
282286
)
283287

284-
def _build_validations(self):
285-
"""Placeholder docstring"""
286-
# TODO: Beta validations - remove after the launch
287-
if self.mode == Mode.IN_PROCESS:
288-
raise ValueError("IN_PROCESS mode is not supported yet!")
289-
290-
if self.inference_spec and self.model:
291-
raise ValueError("Cannot have both the Model and Inference spec in the builder")
292-
293-
if self.image_uri and not is_1p_image_uri(self.image_uri) and self.model_server is None:
294-
raise ValueError(
295-
"Model_server must be set when non-first-party image_uri is set. "
296-
+ "Supported model servers: %s" % supported_model_server
297-
)
298-
299-
# Set TorchServe as default model server
300-
if not self.model_server:
301-
self.model_server = ModelServer.TORCHSERVE
302-
303-
if self.model_server not in supported_model_server:
304-
raise ValueError(
305-
"%s is not supported yet! Supported model servers: %s"
306-
% (self.model_server, supported_model_server)
307-
)
308-
309288
def _save_model_inference_spec(self):
310289
"""Placeholder docstring"""
311290
# check if path exists and create if not
@@ -748,6 +727,11 @@ def build( # pylint: disable=R0911
748727
self._initialize_for_mlflow()
749728
_validate_input_for_mlflow(self.model_server, self.env_vars.get("MLFLOW_MODEL_FLAVOR"))
750729

730+
self._build_validations()
731+
732+
if self.model_server:
733+
return self._build_for_model_server()
734+
751735
if isinstance(self.model, str):
752736
model_task = None
753737
if self.model_metadata:
@@ -779,7 +763,37 @@ def build( # pylint: disable=R0911
779763
else:
780764
return self._build_for_transformers()
781765

782-
self._build_validations()
766+
# Set TorchServe as default model server
767+
if not self.model_server:
768+
self.model_server = ModelServer.TORCHSERVE
769+
return self._build_for_torchserve()
770+
771+
raise ValueError("%s model server is not supported" % self.model_server)
772+
773+
def _build_validations(self):
774+
"""Validations needed for model server overrides, or auto-detection or fallback"""
775+
if self.mode == Mode.IN_PROCESS:
776+
raise ValueError("IN_PROCESS mode is not supported yet!")
777+
778+
if self.inference_spec and self.model:
779+
raise ValueError("Cannot have both the Model and Inference spec in the builder")
780+
781+
if self.image_uri and not is_1p_image_uri(self.image_uri) and self.model_server is None:
782+
raise ValueError(
783+
"Model_server must be set when non-first-party image_uri is set. "
784+
+ "Supported model servers: %s" % supported_model_server
785+
)
786+
787+
def _build_for_model_server(self): # pylint: disable=R0911, R1710
788+
"""Model server overrides"""
789+
if self.model_server not in supported_model_server:
790+
raise ValueError(
791+
"%s is not supported yet! Supported model servers: %s"
792+
% (self.model_server, supported_model_server)
793+
)
794+
795+
if not self.model and not self.model_metadata:
796+
raise ValueError("Missing required parameter `model` or 'ml_flow' path")
783797

784798
if self.model_server == ModelServer.TORCHSERVE:
785799
return self._build_for_torchserve()
@@ -790,7 +804,17 @@ def build( # pylint: disable=R0911
790804
if self.model_server == ModelServer.TENSORFLOW_SERVING:
791805
return self._build_for_tensorflow_serving()
792806

793-
raise ValueError("%s model server is not supported" % self.model_server)
807+
if self.model_server == ModelServer.DJL_SERVING:
808+
return self._build_for_djl()
809+
810+
if self.model_server == ModelServer.TEI:
811+
return self._build_for_tei()
812+
813+
if self.model_server == ModelServer.TGI:
814+
return self._build_for_tgi()
815+
816+
if self.model_server == ModelServer.MMS:
817+
return self._build_for_transformers()
794818

795819
def save(
796820
self,

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@
5454
ModelServer.TRITON,
5555
ModelServer.DJL_SERVING,
5656
ModelServer.TENSORFLOW_SERVING,
57+
ModelServer.MMS,
58+
ModelServer.TGI,
59+
ModelServer.TEI,
5760
}
5861

5962
mock_session = MagicMock()
@@ -124,6 +127,120 @@ def test_save_model_throw_exception_when_none_of_model_and_inference_spec_is_set
124127
mock_session,
125128
)
126129

130+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
131+
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl")
132+
def test_model_server_override_djl_with_model(self, mock_build_for_djl, mock_serve_settings):
133+
mock_setting_object = mock_serve_settings.return_value
134+
mock_setting_object.role_arn = mock_role_arn
135+
mock_setting_object.s3_model_data_url = mock_s3_model_data_url
136+
137+
builder = ModelBuilder(model_server=ModelServer.DJL_SERVING, model="gpt_llm_burt")
138+
builder.build(sagemaker_session=mock_session)
139+
140+
mock_build_for_djl.assert_called_once()
141+
142+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
143+
def test_model_server_override_djl_without_model_or_mlflow(self, mock_serve_settings):
144+
builder = ModelBuilder(
145+
model_server=ModelServer.DJL_SERVING, model=None, inference_spec=None
146+
)
147+
self.assertRaisesRegex(
148+
Exception,
149+
"Missing required parameter `model` or 'ml_flow' path",
150+
builder.build,
151+
Mode.SAGEMAKER_ENDPOINT,
152+
mock_role_arn,
153+
mock_session,
154+
)
155+
156+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
157+
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_torchserve")
158+
def test_model_server_override_torchserve_with_model(
159+
self, mock_build_for_ts, mock_serve_settings
160+
):
161+
mock_setting_object = mock_serve_settings.return_value
162+
mock_setting_object.role_arn = mock_role_arn
163+
mock_setting_object.s3_model_data_url = mock_s3_model_data_url
164+
165+
builder = ModelBuilder(model_server=ModelServer.TORCHSERVE, model="gpt_llm_burt")
166+
builder.build(sagemaker_session=mock_session)
167+
168+
mock_build_for_ts.assert_called_once()
169+
170+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
171+
def test_model_server_override_torchserve_without_model_or_mlflow(self, mock_serve_settings):
172+
builder = ModelBuilder(model_server=ModelServer.TORCHSERVE)
173+
self.assertRaisesRegex(
174+
Exception,
175+
"Missing required parameter `model` or 'ml_flow' path",
176+
builder.build,
177+
Mode.SAGEMAKER_ENDPOINT,
178+
mock_role_arn,
179+
mock_session,
180+
)
181+
182+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
183+
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_triton")
184+
def test_model_server_override_triton_with_model(self, mock_build_for_ts, mock_serve_settings):
185+
mock_setting_object = mock_serve_settings.return_value
186+
mock_setting_object.role_arn = mock_role_arn
187+
mock_setting_object.s3_model_data_url = mock_s3_model_data_url
188+
189+
builder = ModelBuilder(model_server=ModelServer.TRITON, model="gpt_llm_burt")
190+
builder.build(sagemaker_session=mock_session)
191+
192+
mock_build_for_ts.assert_called_once()
193+
194+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
195+
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tensorflow_serving")
196+
def test_model_server_override_tensor_with_model(self, mock_build_for_ts, mock_serve_settings):
197+
mock_setting_object = mock_serve_settings.return_value
198+
mock_setting_object.role_arn = mock_role_arn
199+
mock_setting_object.s3_model_data_url = mock_s3_model_data_url
200+
201+
builder = ModelBuilder(model_server=ModelServer.TENSORFLOW_SERVING, model="gpt_llm_burt")
202+
builder.build(sagemaker_session=mock_session)
203+
204+
mock_build_for_ts.assert_called_once()
205+
206+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
207+
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tei")
208+
def test_model_server_override_tei_with_model(self, mock_build_for_ts, mock_serve_settings):
209+
mock_setting_object = mock_serve_settings.return_value
210+
mock_setting_object.role_arn = mock_role_arn
211+
mock_setting_object.s3_model_data_url = mock_s3_model_data_url
212+
213+
builder = ModelBuilder(model_server=ModelServer.TEI, model="gpt_llm_burt")
214+
builder.build(sagemaker_session=mock_session)
215+
216+
mock_build_for_ts.assert_called_once()
217+
218+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
219+
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tgi")
220+
def test_model_server_override_tgi_with_model(self, mock_build_for_ts, mock_serve_settings):
221+
mock_setting_object = mock_serve_settings.return_value
222+
mock_setting_object.role_arn = mock_role_arn
223+
mock_setting_object.s3_model_data_url = mock_s3_model_data_url
224+
225+
builder = ModelBuilder(model_server=ModelServer.TGI, model="gpt_llm_burt")
226+
builder.build(sagemaker_session=mock_session)
227+
228+
mock_build_for_ts.assert_called_once()
229+
230+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
231+
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers")
232+
def test_model_server_override_transformers_with_model(
233+
self, mock_build_for_ts, mock_serve_settings
234+
):
235+
mock_setting_object = mock_serve_settings.return_value
236+
mock_setting_object.role_arn = mock_role_arn
237+
mock_setting_object.s3_model_data_url = mock_s3_model_data_url
238+
239+
builder = ModelBuilder(model_server=ModelServer.MMS, model="gpt_llm_burt")
240+
builder.build(sagemaker_session=mock_session)
241+
242+
mock_build_for_ts.assert_called_once()
243+
127244
@patch("os.makedirs", Mock())
128245
@patch("sagemaker.serve.builder.model_builder._detect_framework_and_version")
129246
@patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve")

0 commit comments

Comments (0)