diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 70b205bc74..4a0b33a271 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -322,7 +322,12 @@ def get_deploy_kwargs( model_id=model_id, model_from_estimator=True, model_version=model_version, - instance_type=model_deploy_kwargs.instance_type if training_instance_type is None else None, + instance_type=( + model_deploy_kwargs.instance_type + if training_instance_type is None + or instance_type is not None # always use supplied inference instance type + else None + ), region=region, image_uri=image_uri, source_dir=source_dir, diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 2e8dc1e9a2..96b00793b8 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1532,6 +1532,9 @@ def test_estimator_sets_different_inference_instance_depending_on_training_insta estimator.deploy(image_uri="blah") assert mock_estimator_deploy.call_args[1]["instance_type"] == "ml.p4de.24xlarge" + estimator.deploy(image_uri="blah", instance_type="ml.quantum.large") + assert mock_estimator_deploy.call_args[1]["instance_type"] == "ml.quantum.large" + @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch(