Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,11 @@ def get_deploy_kwargs(
model_id=model_id,
model_from_estimator=True,
model_version=model_version,
instance_type=model_deploy_kwargs.instance_type if training_instance_type is None else None,
instance_type=(
model_deploy_kwargs.instance_type
if training_instance_type is None or instance_type is not None
else None
),
region=region,
image_uri=image_uri,
source_dir=source_dir,
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/sagemaker/jumpstart/estimator/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1532,6 +1532,9 @@ def test_estimator_sets_different_inference_instance_depending_on_training_insta
estimator.deploy(image_uri="blah")
assert mock_estimator_deploy.call_args[1]["instance_type"] == "ml.p4de.24xlarge"

estimator.deploy(image_uri="blah", instance_type="ml.quantum.large")
assert mock_estimator_deploy.call_args[1]["instance_type"] == "ml.quantum.large"

@mock.patch("sagemaker.utils.sagemaker_timestamp")
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
@mock.patch(
Expand Down