From bc84f20f9bb5a06e3e782463e823bdb75bc6b0eb Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 10 Jun 2024 20:44:12 +0000 Subject: [PATCH 1/2] fix: estimator.deploy not respecting instance type --- src/sagemaker/jumpstart/factory/estimator.py | 6 +++++- tests/unit/sagemaker/jumpstart/estimator/test_estimator.py | 3 +++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 70b205bc74..60861ea4ac 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -322,7 +322,11 @@ def get_deploy_kwargs( model_id=model_id, model_from_estimator=True, model_version=model_version, - instance_type=model_deploy_kwargs.instance_type if training_instance_type is None else None, + instance_type=( + model_deploy_kwargs.instance_type + if training_instance_type is None or instance_type is not None + else None + ), region=region, image_uri=image_uri, source_dir=source_dir, diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 2e8dc1e9a2..96b00793b8 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1532,6 +1532,9 @@ def test_estimator_sets_different_inference_instance_depending_on_training_insta estimator.deploy(image_uri="blah") assert mock_estimator_deploy.call_args[1]["instance_type"] == "ml.p4de.24xlarge" + estimator.deploy(image_uri="blah", instance_type="ml.quantum.large") + assert mock_estimator_deploy.call_args[1]["instance_type"] == "ml.quantum.large" + @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch( From f895afbea5a8ec342e76c0d8183ef82f01f1c309 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 10 Jun 2024 20:51:08 +0000 Subject: [PATCH 2/2] chore: add inline comment about using user supplied instance type --- src/sagemaker/jumpstart/factory/estimator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 60861ea4ac..4a0b33a271 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -324,7 +324,8 @@ def get_deploy_kwargs( model_version=model_version, instance_type=( model_deploy_kwargs.instance_type - if training_instance_type is None or instance_type is not None + if training_instance_type is None + or instance_type is not None # always use supplied inference instance type else None ), region=region,