Skip to content

Commit c63617d

Browse files
committed
Update: Pull latest tei container for sentence similiarity models on HuggingFace hub
1 parent 65cc586 commit c63617d

File tree

2 files changed

+67
-3
lines changed

2 files changed

+67
-3
lines changed

src/sagemaker/serve/builder/transformers_builder.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
_get_nb_instance,
2424
)
2525
from sagemaker.djl_inference.model import _get_model_config_properties_from_hf
26-
from sagemaker.huggingface import HuggingFaceModel
26+
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
2727
from sagemaker.serve.model_server.multi_model_server.prepare import (
2828
_create_dir_structure,
2929
)
@@ -99,7 +99,22 @@ def _create_transformers_model(self) -> Type[Model]:
9999
if hf_model_md is None:
100100
raise ValueError("Could not fetch HF metadata")
101101

102-
if "pytorch" in hf_model_md.get("tags"):
102+
model_task = hf_model_md.get("pipeline_tag")
103+
104+
if model_task == "sentence-similarity" and not self.image_uri:
105+
self.image_uri = \
106+
get_huggingface_llm_image_uri("huggingface-tei", session=self.sagemaker_session)
107+
108+
logger.info("Auto detected %s. Proceeding with the the deployment.", self.image_uri)
109+
110+
pysdk_model = HuggingFaceModel(
111+
env=self.env_vars,
112+
role=self.role_arn,
113+
sagemaker_session=self.sagemaker_session,
114+
image_uri=self.image_uri,
115+
vpc_config=self.vpc_config,
116+
)
117+
elif "pytorch" in hf_model_md.get("tags"):
103118
self.pytorch_version = self._get_supported_version(
104119
hf_config, base_hf_version, "pytorch"
105120
)

tests/unit/sagemaker/serve/builder/test_transformers_builder.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def test_build_deploy_for_transformers_local_container_and_remote_container(
110110
return_value="ml.g5.24xlarge",
111111
)
112112
@patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None)
113-
def test_image_uri(
113+
def test_image_uri_override(
114114
self,
115115
mock_get_nb_instance,
116116
mock_telemetry,
@@ -144,3 +144,52 @@ def test_image_uri(
144144

145145
with self.assertRaises(ValueError) as _:
146146
model.deploy(mode=Mode.IN_PROCESS)
147+
148+
@patch(
149+
"sagemaker.serve.builder.transformers_builder._get_nb_instance",
150+
return_value="ml.g5.24xlarge",
151+
)
152+
@patch(
153+
"sagemaker.huggingface.llm_utils.get_huggingface_model_metadata",
154+
return_value="sentence-similarity",
155+
)
156+
@patch(
157+
"from sagemaker.huggingface.get_huggingface_llm_image_uri",
158+
return_value=MOCK_IMAGE_CONFIG
159+
)
160+
@patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None)
161+
def test_sentence_similarity_support(
162+
self,
163+
mock_get_nb_instance,
164+
mock_task,
165+
mock_image,
166+
mock_telemetry,
167+
):
168+
builder = ModelBuilder(
169+
model=mock_model_id,
170+
schema_builder=mock_schema_builder,
171+
mode=Mode.LOCAL_CONTAINER,
172+
)
173+
174+
builder._prepare_for_mode = MagicMock()
175+
builder._prepare_for_mode.side_effect = None
176+
177+
model = builder.build()
178+
builder.serve_settings.telemetry_opt_out = True
179+
180+
builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock()
181+
predictor = model.deploy(model_data_download_timeout=1800)
182+
183+
assert builder.image_uri == MOCK_IMAGE_CONFIG
184+
assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800"
185+
assert isinstance(predictor, TransformersLocalModePredictor)
186+
187+
assert builder.nb_instance_type == "ml.g5.24xlarge"
188+
189+
builder._original_deploy = MagicMock()
190+
builder._prepare_for_mode.return_value = (None, {})
191+
predictor = model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn")
192+
assert "HF_MODEL_ID" in model.env
193+
194+
with self.assertRaises(ValueError) as _:
195+
model.deploy(mode=Mode.IN_PROCESS)

0 commit comments

Comments
 (0)