diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 4a0b33a271..d6cea3cf09 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -258,6 +258,7 @@ def get_deploy_kwargs( deserializer: Optional[BaseDeserializer] = None, accelerator_type: Optional[str] = None, endpoint_name: Optional[str] = None, + inference_component_name: Optional[str] = None, tags: Optional[Tags] = None, kms_key: Optional[str] = None, wait: Optional[bool] = None, @@ -302,6 +303,7 @@ def get_deploy_kwargs( deserializer=deserializer, accelerator_type=accelerator_type, endpoint_name=endpoint_name, + inference_component_name=inference_component_name, tags=format_tags(tags), kms_key=kms_key, wait=wait, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index b4bfd8a348..380afbb433 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -542,6 +542,7 @@ def get_deploy_kwargs( deserializer: Optional[BaseDeserializer] = None, accelerator_type: Optional[str] = None, endpoint_name: Optional[str] = None, + inference_component_name: Optional[str] = None, tags: Optional[Tags] = None, kms_key: Optional[str] = None, wait: Optional[bool] = None, @@ -576,6 +577,7 @@ def get_deploy_kwargs( deserializer=deserializer, accelerator_type=accelerator_type, endpoint_name=endpoint_name, + inference_component_name=inference_component_name, tags=format_tags(tags), kms_key=kms_key, wait=wait, diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 205b3bb08d..d2a09345a1 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -36,12 +36,13 @@ get_init_kwargs, get_register_kwargs, ) +from sagemaker.jumpstart.session_utils import get_model_id_version_from_endpoint from sagemaker.jumpstart.types import JumpStartSerializablePayload from sagemaker.jumpstart.utils import ( validate_model_id_and_get_type, verify_model_region_and_return_specs, ) -from sagemaker.jumpstart.constants import JUMPSTART_LOGGER +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.model_card import ( ModelCard, @@ -406,6 +407,45 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload: sagemaker_session=self.sagemaker_session, ) + @classmethod + def attach( + cls, + endpoint_name: str, + inference_component_name: Optional[str] = None, + model_id: Optional[str] = None, + model_version: Optional[str] = None, + sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + ) -> "JumpStartModel": + """Attaches a JumpStartModel object to an existing SageMaker Endpoint. + + The model id, version (and inference component name) can be inferred from the tags. + """ + + inferred_model_id = inferred_model_version = inferred_inference_component_name = None + + if inference_component_name is None or model_id is None or model_version is None: + inferred_model_id, inferred_model_version, inferred_inference_component_name = ( + get_model_id_version_from_endpoint( + endpoint_name=endpoint_name, + inference_component_name=inference_component_name, + sagemaker_session=sagemaker_session, + ) + ) + + model_id = model_id or inferred_model_id + model_version = model_version or inferred_model_version or "*" + inference_component_name = inference_component_name or inferred_inference_component_name + + model = JumpStartModel( + model_id=model_id, + model_version=model_version, + sagemaker_session=sagemaker_session, + ) + model.endpoint_name = endpoint_name + model.inference_component_name = inference_component_name + + return model + def _create_sagemaker_model( self, instance_type=None, @@ -484,6 +524,7 @@ def deploy( deserializer: Optional[BaseDeserializer] = None, accelerator_type: Optional[str] = None, endpoint_name: Optional[str] = None, + inference_component_name: Optional[str] = None, tags: Optional[Tags] = None, kms_key: Optional[str] = None, wait: Optional[bool] = True, @@ -614,6 +655,7 @@ def deploy( deserializer=deserializer, accelerator_type=accelerator_type, endpoint_name=endpoint_name, + inference_component_name=inference_component_name, tags=format_tags(tags), kms_key=kms_key, wait=wait, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 5754704632..88e25f8a94 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1596,6 +1596,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "deserializer", "accelerator_type", "endpoint_name", + "inference_component_name", "tags", "kms_key", "wait", @@ -1641,6 +1642,7 @@ def __init__( deserializer: Optional[Any] = None, accelerator_type: Optional[str] = None, endpoint_name: Optional[str] = None, + inference_component_name: Optional[str] = None, tags: Optional[Tags] = None, kms_key: Optional[str] = None, wait: Optional[bool] = None, @@ -1674,6 +1676,7 @@ def __init__( self.deserializer = deserializer self.accelerator_type = accelerator_type self.endpoint_name = endpoint_name + self.inference_component_name = inference_component_name self.tags = format_tags(tags) self.kms_key = kms_key self.wait = wait diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 5c5156c84a..b6848800dd 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -358,6 +358,7 @@ def __init__( sagemaker_config=self._sagemaker_config, ) self.endpoint_name = None + self.inference_component_name = None self._is_compiled_model = False self._compilation_job_name = None self._is_edge_packaged_model = False @@ -405,6 +406,16 @@ def __init__( self.response_types = None self.accept_eula = None + @classmethod + def attach( + cls, + endpoint_name: str, + inference_component_name: Optional[str] = None, + sagemaker_session=None, + ) -> "Model": + """Attaches a Model object to an existing SageMaker Endpoint.""" + raise NotImplementedError + @runnable_by_pipeline def register( self, @@ -1318,6 +1329,7 @@ def deploy( resources: Optional[ResourceRequirements] = None, endpoint_type: EndpointType = EndpointType.MODEL_BASED, managed_instance_scaling: Optional[str] = None, + inference_component_name=None, routing_config: Optional[Dict[str, Any]] = None, **kwargs, ): @@ -1602,11 +1614,15 @@ def deploy( "ComputeResourceRequirements": resources.get_compute_resource_requirements(), } runtime_config = {"CopyCount": resources.copy_count} - inference_component_name = unique_name_from_base(self.name) + self.inference_component_name = ( + inference_component_name + or self.inference_component_name + or unique_name_from_base(self.name) + ) # [TODO]: Add endpoint_logging support self.sagemaker_session.create_inference_component( - inference_component_name=inference_component_name, + inference_component_name=self.inference_component_name, endpoint_name=self.endpoint_name, variant_name="AllTraffic", # default variant name specification=inference_component_spec, @@ -1619,7 +1635,7 @@ def deploy( predictor = self.predictor_cls( self.endpoint_name, self.sagemaker_session, - component_name=inference_component_name, + component_name=self.inference_component_name, ) if serializer: predictor.serializer = serializer diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 5205765e2f..96ee82883e 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -219,6 +219,11 @@ def test_jumpstart_gated_model_inference_component_enabled(setup): assert response is not None + model = JumpStartModel.attach(predictor.endpoint_name, sagemaker_session=get_sm_session()) + assert model.model_id == model_id + assert model.endpoint_name == predictor.endpoint_name + assert model.inference_component_name == predictor.component_name + @mock.patch("sagemaker.jumpstart.cache.JUMPSTART_LOGGER.warning") def test_instatiating_model(mock_warning_logger, setup): diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 140b839937..58c08f5b3d 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -1324,6 +1324,54 @@ def test_model_artifact_variant_model( enable_network_isolation=True, ) + @mock.patch("sagemaker.jumpstart.model.get_model_id_version_from_endpoint") + @mock.patch("sagemaker.jumpstart.model.JumpStartModel.__init__") + def test_attach( + self, + mock_js_model_init, + mock_get_model_id_version_from_endpoint, + ): + mock_js_model_init.return_value = None + mock_get_model_id_version_from_endpoint.return_value = "model-id", "model-version", None + val = JumpStartModel.attach("some-endpoint") + mock_get_model_id_version_from_endpoint.assert_called_once_with( + endpoint_name="some-endpoint", + inference_component_name=None, + sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + ) + mock_js_model_init.assert_called_once_with( + model_id="model-id", + model_version="model-version", + sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + ) + assert isinstance(val, JumpStartModel) + + mock_get_model_id_version_from_endpoint.reset_mock() + JumpStartModel.attach("some-endpoint", model_id="some-id") + mock_get_model_id_version_from_endpoint.assert_called_once_with( + endpoint_name="some-endpoint", + inference_component_name=None, + sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + ) + + mock_get_model_id_version_from_endpoint.reset_mock() + JumpStartModel.attach("some-endpoint", model_id="some-id", model_version="some-version") + mock_get_model_id_version_from_endpoint.assert_called_once_with( + endpoint_name="some-endpoint", + inference_component_name=None, + sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + ) + + # providing model id, version, and ic name should bypass check with endpoint tags + mock_get_model_id_version_from_endpoint.reset_mock() + JumpStartModel.attach( + "some-endpoint", + model_id="some-id", + model_version="some-version", + inference_component_name="some-ic-name", + ) + mock_get_model_id_version_from_endpoint.assert_not_called() + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix"