Skip to content

Commit 7de2c53

Browse files
committed
fix: register jumpstart models on model registry
1 parent 58bb448 commit 7de2c53

File tree

5 files changed

+28
-27
lines changed

5 files changed

+28
-27
lines changed

src/sagemaker/jumpstart/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,6 @@ def register(
761761
source_uri: Optional[Union[str, PipelineVariable]] = None,
762762
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
763763
accept_eula: Optional[bool] = None,
764-
765764
):
766765
"""Creates a model package for creating SageMaker models or listing on Marketplace.
767766
@@ -820,8 +819,9 @@ def register(
820819
A `sagemaker.model.ModelPackage` instance.
821820
"""
822821

823-
if model_package_group_name is None and self.model_type is JumpStartModelType.PROPRIETARY:
822+
if model_package_group_name is None:
824823
model_package_group_name = self.model_id
824+
if self.model_type is JumpStartModelType.PROPRIETARY:
825825
source_uri = self.model_package_arn
826826

827827
register_kwargs = get_register_kwargs(

src/sagemaker/jumpstart/types.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2410,7 +2410,6 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
24102410
"model_version",
24112411
"hub_arn",
24122412
"sagemaker_session",
2413-
"model_type",
24142413
}
24152414

24162415
def __init__(

src/sagemaker/model.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,7 @@ def register(
450450
source_uri: Optional[Union[str, PipelineVariable]] = None,
451451
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
452452
accept_eula: Optional[bool] = None,
453+
model_type: Optional[JumpStartModelType] = None,
453454
):
454455
"""Creates a model package for creating SageMaker models or listing on Marketplace.
455456
@@ -517,22 +518,22 @@ def register(
517518

518519
if image_uri is not None:
519520
self.image_uri = image_uri
520-
if self.model_type is not JumpStartModelType.PROPRIETARY:
521-
if model_package_group_name is None and model_package_name is None:
522-
# If model package group and model package name is not set
523-
# then register to auto-generated model package group
524-
model_package_group_name = utils.base_name_from_image(
525-
self.image_uri, default_base_name=ModelPackage.__name__
526-
)
527-
if model_package_group_name is not None:
528-
container_def = self.prepare_container_def(accept_eula=accept_eula)
529-
container_def = update_container_with_inference_params(
530-
framework=framework,
531-
framework_version=framework_version,
532-
nearest_model_name=nearest_model_name,
533-
data_input_configuration=data_input_configuration,
534-
container_def=container_def,
535-
)
521+
522+
if model_package_group_name is None and model_package_name is None:
523+
# If model package group and model package name is not set
524+
# then register to auto-generated model package group
525+
model_package_group_name = utils.base_name_from_image(
526+
self.image_uri, default_base_name=ModelPackage.__name__
527+
)
528+
if model_package_group_name is not None:
529+
container_def = self.prepare_container_def(accept_eula=accept_eula)
530+
container_def = update_container_with_inference_params(
531+
framework=framework,
532+
framework_version=framework_version,
533+
nearest_model_name=nearest_model_name,
534+
data_input_configuration=data_input_configuration,
535+
container_def=container_def,
536+
)
536537
else:
537538
container_def = {
538539
"Image": self.image_uri,
@@ -547,10 +548,6 @@ def register(
547548
if self.model_data is not None:
548549
container_def["ModelDataUrl"] = self.model_data
549550

550-
if self.model_type is JumpStartModelType.PROPRIETARY:
551-
source_uri = self.model_package_arn
552-
model_package_group_name = self.model_id
553-
554551
model_pkg_args = sagemaker.get_model_package_args(
555552
self.content_types,
556553
self.response_types,

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ def test_proprietary_jumpstart_model(setup):
294294

295295
assert response is not None
296296

297+
297298
@pytest.mark.skipif(
298299
True,
299300
reason="Only enable if test account is subscribed to the proprietary model",
@@ -309,7 +310,6 @@ def test_register_proprietary_jumpstart_model(setup):
309310
sagemaker_session=get_sm_session(),
310311
)
311312
model_package = model.register()
312-
313313

314314
predictor = model_package.deploy(
315315
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}]
@@ -329,7 +329,7 @@ def test_register_proprietary_jumpstart_model(setup):
329329
)
330330
def test_register_gated_jumpstart_model(setup):
331331

332-
model_id="meta-textgenerationneuron-llama-2-7b"
332+
model_id = "meta-textgenerationneuron-llama-2-7b"
333333
model = JumpStartModel(
334334
model_id=model_id,
335335
model_version="1.1.0",
@@ -339,7 +339,8 @@ def test_register_gated_jumpstart_model(setup):
339339
model_package = model.register(accept_eula=True)
340340

341341
predictor = model_package.deploy(
342-
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], accept_eula=True
342+
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
343+
accept_eula=True,
343344
)
344345
payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1}
345346

@@ -348,4 +349,3 @@ def test_register_gated_jumpstart_model(setup):
348349
predictor.delete_predictor()
349350

350351
assert response is not None
351-

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,8 +513,11 @@ def test_proprietary_model_endpoint(
513513
model.deploy()
514514

515515
mock_model_register.assert_called_once_with(
516+
model_type=JumpStartModelType.PROPRIETARY,
516517
content_types=["application/json"],
517518
response_types=["application/json"],
519+
model_package_group_name=model_id,
520+
source_uri=model.model_package_arn,
518521
)
519522

520523
mock_model_deploy.assert_called_once_with(
@@ -1416,8 +1419,10 @@ def test_model_registry_accept_and_response_types(
14161419
model.register()
14171420

14181421
mock_model_register.assert_called_once_with(
1422+
model_type=JumpStartModelType.OPEN_WEIGHTS,
14191423
content_types=["application/x-text"],
14201424
response_types=["application/json;verbose", "application/json"],
1425+
model_package_group_name=model.model_id,
14211426
)
14221427

14231428
@mock.patch("sagemaker.jumpstart.model.get_default_predictor")

0 commit comments

Comments
 (0)