Skip to content

Commit 734a0b0

Browse files
samrudsbenieric
authored and committed
fix: Model server override logic (aws#4733)
* fix: Model server override logic * Fix formatting --------- Co-authored-by: Erick Benitez-Ramos <[email protected]>
1 parent 637ad2c commit 734a0b0

File tree

2 files changed

+177
-32
lines changed

2 files changed

+177
-32
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 56 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,16 @@
9595

9696
logger = logging.getLogger(__name__)
9797

98-
supported_model_server = {
98+
# Any new server type should be added here
99+
supported_model_servers = {
99100
ModelServer.TORCHSERVE,
100101
ModelServer.TRITON,
101102
ModelServer.DJL_SERVING,
102103
ModelServer.FASTAPI,
103104
ModelServer.TENSORFLOW_SERVING,
105+
ModelServer.MMS,
106+
ModelServer.TGI,
107+
ModelServer.TEI,
104108
}
105109

106110

@@ -290,31 +294,6 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
290294
},
291295
)
292296

293-
def _build_validations(self):
294-
"""Placeholder docstring"""
295-
# TODO: Beta validations - remove after the launch
296-
if self.mode == Mode.IN_PROCESS:
297-
raise ValueError("IN_PROCESS mode is not supported yet!")
298-
299-
if self.inference_spec and self.model:
300-
raise ValueError("Cannot have both the Model and Inference spec in the builder")
301-
302-
if self.image_uri and not is_1p_image_uri(self.image_uri) and self.model_server is None:
303-
raise ValueError(
304-
"Model_server must be set when non-first-party image_uri is set. "
305-
+ "Supported model servers: %s" % supported_model_server
306-
)
307-
308-
# Set TorchServe as default model server
309-
if not self.model_server:
310-
self.model_server = ModelServer.TORCHSERVE
311-
312-
if self.model_server not in supported_model_server:
313-
raise ValueError(
314-
"%s is not supported yet! Supported model servers: %s"
315-
% (self.model_server, supported_model_server)
316-
)
317-
318297
def _save_model_inference_spec(self):
319298
"""Placeholder docstring"""
320299
# check if path exists and create if not
@@ -841,6 +820,11 @@ def build( # pylint: disable=R0911
841820

842821
self._handle_mlflow_input()
843822

823+
self._build_validations()
824+
825+
if self.model_server:
826+
return self._build_for_model_server()
827+
844828
if isinstance(self.model, str):
845829
model_task = None
846830
if self.model_metadata:
@@ -872,7 +856,41 @@ def build( # pylint: disable=R0911
872856
else:
873857
return self._build_for_transformers()
874858

875-
self._build_validations()
859+
# Set TorchServe as default model server
860+
if not self.model_server:
861+
self.model_server = ModelServer.TORCHSERVE
862+
return self._build_for_torchserve()
863+
864+
raise ValueError("%s model server is not supported" % self.model_server)
865+
866+
def _build_validations(self):
867+
"""Validations needed for model server overrides, or auto-detection or fallback"""
868+
if self.mode == Mode.IN_PROCESS:
869+
raise ValueError("IN_PROCESS mode is not supported yet!")
870+
871+
if self.inference_spec and self.model:
872+
raise ValueError("Can only set one of the following: model, inference_spec.")
873+
874+
if self.image_uri and not is_1p_image_uri(self.image_uri) and self.model_server is None:
875+
raise ValueError(
876+
"Model_server must be set when non-first-party image_uri is set. "
877+
+ "Supported model servers: %s" % supported_model_servers
878+
)
879+
880+
def _build_for_model_server(self): # pylint: disable=R0911, R1710
881+
"""Model server overrides"""
882+
if self.model_server not in supported_model_servers:
883+
raise ValueError(
884+
"%s is not supported yet! Supported model servers: %s"
885+
% (self.model_server, supported_model_servers)
886+
)
887+
888+
mlflow_path = None
889+
if self.model_metadata:
890+
mlflow_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
891+
892+
if not self.model and not mlflow_path:
893+
raise ValueError("Missing required parameter `model` or 'ml_flow' path")
876894

877895
from langchain_core.runnables import RunnableSerializable
878896

@@ -889,7 +907,17 @@ def build( # pylint: disable=R0911
889907
if self.model_server == ModelServer.TENSORFLOW_SERVING:
890908
return self._build_for_tensorflow_serving()
891909

892-
raise ValueError("%s model server is not supported" % self.model_server)
910+
if self.model_server == ModelServer.DJL_SERVING:
911+
return self._build_for_djl()
912+
913+
if self.model_server == ModelServer.TEI:
914+
return self._build_for_tei()
915+
916+
if self.model_server == ModelServer.TGI:
917+
return self._build_for_tgi()
918+
919+
if self.model_server == ModelServer.MMS:
920+
return self._build_for_transformers()
893921

894922
def save(
895923
self,

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 121 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,14 @@
5050
mock_secret_key = "mock_secret_key"
5151
mock_instance_type = "mock instance type"
5252

53-
supported_model_server = {
53+
supported_model_servers = {
5454
ModelServer.TORCHSERVE,
5555
ModelServer.TRITON,
5656
ModelServer.DJL_SERVING,
5757
ModelServer.TENSORFLOW_SERVING,
58+
ModelServer.MMS,
59+
ModelServer.TGI,
60+
ModelServer.TEI,
5861
}
5962

6063
mock_session = MagicMock()
@@ -78,7 +81,7 @@ def test_validation_cannot_set_both_model_and_inference_spec(self, mock_serveSet
7881
builder = ModelBuilder(inference_spec="some value", model=Mock(spec=object))
7982
self.assertRaisesRegex(
8083
Exception,
81-
"Cannot have both the Model and Inference spec in the builder",
84+
"Can only set one of the following: model, inference_spec.",
8285
builder.build,
8386
Mode.SAGEMAKER_ENDPOINT,
8487
mock_role_arn,
@@ -91,7 +94,7 @@ def test_validation_unsupported_model_server_type(self, mock_serveSettings):
9194
self.assertRaisesRegex(
9295
Exception,
9396
"%s is not supported yet! Supported model servers: %s"
94-
% (builder.model_server, supported_model_server),
97+
% (builder.model_server, supported_model_servers),
9598
builder.build,
9699
Mode.SAGEMAKER_ENDPOINT,
97100
mock_role_arn,
@@ -104,7 +107,7 @@ def test_validation_model_server_not_set_with_image_uri(self, mock_serveSettings
104107
self.assertRaisesRegex(
105108
Exception,
106109
"Model_server must be set when non-first-party image_uri is set. "
107-
+ "Supported model servers: %s" % supported_model_server,
110+
+ "Supported model servers: %s" % supported_model_servers,
108111
builder.build,
109112
Mode.SAGEMAKER_ENDPOINT,
110113
mock_role_arn,
@@ -125,6 +128,120 @@ def test_save_model_throw_exception_when_none_of_model_and_inference_spec_is_set
125128
mock_session,
126129
)
127130

131+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
132+
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl")
133+
def test_model_server_override_djl_with_model(self, mock_build_for_djl, mock_serve_settings):
134+
mock_setting_object = mock_serve_settings.return_value
135+
mock_setting_object.role_arn = mock_role_arn
136+
mock_setting_object.s3_model_data_url = mock_s3_model_data_url
137+
138+
builder = ModelBuilder(model_server=ModelServer.DJL_SERVING, model="gpt_llm_burt")
139+
builder.build(sagemaker_session=mock_session)
140+
141+
mock_build_for_djl.assert_called_once()
142+
143+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
144+
def test_model_server_override_djl_without_model_or_mlflow(self, mock_serve_settings):
145+
builder = ModelBuilder(
146+
model_server=ModelServer.DJL_SERVING, model=None, inference_spec=None
147+
)
148+
self.assertRaisesRegex(
149+
Exception,
150+
"Missing required parameter `model` or 'ml_flow' path",
151+
builder.build,
152+
Mode.SAGEMAKER_ENDPOINT,
153+
mock_role_arn,
154+
mock_session,
155+
)
156+
157+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
158+
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_torchserve")
159+
def test_model_server_override_torchserve_with_model(
160+
self, mock_build_for_ts, mock_serve_settings
161+
):
162+
mock_setting_object = mock_serve_settings.return_value
163+
mock_setting_object.role_arn = mock_role_arn
164+
mock_setting_object.s3_model_data_url = mock_s3_model_data_url
165+
166+
builder = ModelBuilder(model_server=ModelServer.TORCHSERVE, model="gpt_llm_burt")
167+
builder.build(sagemaker_session=mock_session)
168+
169+
mock_build_for_ts.assert_called_once()
170+
171+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
172+
def test_model_server_override_torchserve_without_model_or_mlflow(self, mock_serve_settings):
173+
builder = ModelBuilder(model_server=ModelServer.TORCHSERVE)
174+
self.assertRaisesRegex(
175+
Exception,
176+
"Missing required parameter `model` or 'ml_flow' path",
177+
builder.build,
178+
Mode.SAGEMAKER_ENDPOINT,
179+
mock_role_arn,
180+
mock_session,
181+
)
182+
183+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
184+
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_triton")
185+
def test_model_server_override_triton_with_model(self, mock_build_for_ts, mock_serve_settings):
186+
mock_setting_object = mock_serve_settings.return_value
187+
mock_setting_object.role_arn = mock_role_arn
188+
mock_setting_object.s3_model_data_url = mock_s3_model_data_url
189+
190+
builder = ModelBuilder(model_server=ModelServer.TRITON, model="gpt_llm_burt")
191+
builder.build(sagemaker_session=mock_session)
192+
193+
mock_build_for_ts.assert_called_once()
194+
195+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
196+
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tensorflow_serving")
197+
def test_model_server_override_tensor_with_model(self, mock_build_for_ts, mock_serve_settings):
198+
mock_setting_object = mock_serve_settings.return_value
199+
mock_setting_object.role_arn = mock_role_arn
200+
mock_setting_object.s3_model_data_url = mock_s3_model_data_url
201+
202+
builder = ModelBuilder(model_server=ModelServer.TENSORFLOW_SERVING, model="gpt_llm_burt")
203+
builder.build(sagemaker_session=mock_session)
204+
205+
mock_build_for_ts.assert_called_once()
206+
207+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
208+
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tei")
209+
def test_model_server_override_tei_with_model(self, mock_build_for_ts, mock_serve_settings):
210+
mock_setting_object = mock_serve_settings.return_value
211+
mock_setting_object.role_arn = mock_role_arn
212+
mock_setting_object.s3_model_data_url = mock_s3_model_data_url
213+
214+
builder = ModelBuilder(model_server=ModelServer.TEI, model="gpt_llm_burt")
215+
builder.build(sagemaker_session=mock_session)
216+
217+
mock_build_for_ts.assert_called_once()
218+
219+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
220+
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tgi")
221+
def test_model_server_override_tgi_with_model(self, mock_build_for_ts, mock_serve_settings):
222+
mock_setting_object = mock_serve_settings.return_value
223+
mock_setting_object.role_arn = mock_role_arn
224+
mock_setting_object.s3_model_data_url = mock_s3_model_data_url
225+
226+
builder = ModelBuilder(model_server=ModelServer.TGI, model="gpt_llm_burt")
227+
builder.build(sagemaker_session=mock_session)
228+
229+
mock_build_for_ts.assert_called_once()
230+
231+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
232+
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers")
233+
def test_model_server_override_transformers_with_model(
234+
self, mock_build_for_ts, mock_serve_settings
235+
):
236+
mock_setting_object = mock_serve_settings.return_value
237+
mock_setting_object.role_arn = mock_role_arn
238+
mock_setting_object.s3_model_data_url = mock_s3_model_data_url
239+
240+
builder = ModelBuilder(model_server=ModelServer.MMS, model="gpt_llm_burt")
241+
builder.build(sagemaker_session=mock_session)
242+
243+
mock_build_for_ts.assert_called_once()
244+
128245
@patch("os.makedirs", Mock())
129246
@patch("sagemaker.serve.builder.model_builder._detect_framework_and_version")
130247
@patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve")

0 commit comments

Comments
 (0)