Skip to content
103 changes: 100 additions & 3 deletions src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

from __future__ import absolute_import

from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, Any
import pandas as pd
from botocore.exceptions import ClientError

from sagemaker import payloads
Expand All @@ -36,14 +37,26 @@
get_init_kwargs,
get_register_kwargs,
)
from sagemaker.jumpstart.types import JumpStartSerializablePayload
from sagemaker.jumpstart.types import (
JumpStartSerializablePayload,
DeploymentConfigMetadata,
JumpStartBenchmarkStat,
JumpStartMetadataConfig
)
from sagemaker.jumpstart.utils import (
validate_model_id_and_get_type,
verify_model_region_and_return_specs,
get_jumpstart_configs,
extract_metrics_from_deployment_configs,
)
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.utils import stringify_object, format_tags, Tags
from sagemaker.utils import (
stringify_object,
format_tags,
Tags,
get_instance_rate_per_hour
)
from sagemaker.model import (
Model,
ModelPackage,
Expand Down Expand Up @@ -281,6 +294,10 @@ def __init__(
ValueError: If the model ID is not recognized by JumpStart.
"""

self._deployment_configs = None
self._metadata_configs = None
self._benchmark_metrics = None

def _validate_model_id_and_type():
return validate_model_id_and_get_type(
model_id=model_id,
Expand Down Expand Up @@ -786,6 +803,86 @@ def register_deploy_wrapper(*args, **kwargs):

return model_package

@property
def benchmark_metrics(self) -> pd.DataFrame:
    """Pandas DataFrame of benchmark metrics for this model's deployment configs.

    The DataFrame is built once from ``list_deployment_configs()`` and cached;
    subsequent accesses return the cached object.

    Returns:
        pd.DataFrame: Benchmark metrics table (one row per deployment config).
    """
    # NOTE: ``bool(DataFrame)`` always raises ValueError ("The truth value of
    # a DataFrame is ambiguous"), so the cache check must be an explicit
    # ``is None`` comparison — the original truthiness test crashed on the
    # second access.
    if self._benchmark_metrics is None:
        data = extract_metrics_from_deployment_configs(
            deployment_configs=self.list_deployment_configs(),
            config_name=self.config_name,
        )
        self._benchmark_metrics = pd.DataFrame(data)
    return self._benchmark_metrics

def display_benchmark_metrics(self):
    """Print the deployment-config benchmark metrics as a Markdown table."""
    # Render the cached/computed DataFrame to Markdown, then emit it.
    markdown_table = self.benchmark_metrics.to_markdown()
    print(markdown_table)

def list_deployment_configs(self) -> List[Dict[str, Any]]:
    """List deployment configs for this model in the current region.

    Both the raw JumpStart metadata configs and the converted deployment-config
    metadata are computed lazily and cached on first use; repeated calls return
    the cached list.

    Returns:
        List[Dict[str, Any]]: One JSON-style metadata dict per deployment config.
    """
    if self._metadata_configs is None:
        self._metadata_configs = get_jumpstart_configs(
            region=self.region,
            model_id=self.model_id,
            model_version=self.model_version,
            sagemaker_session=self.sagemaker_session,
        )

    # Cache the converted configs too: conversion calls the pricing API per
    # config (via get_instance_rate_per_hour), so recomputing on every call —
    # as the original did despite initializing ``_deployment_configs`` — is
    # both slow and wasteful.
    if self._deployment_configs is None:
        self._deployment_configs = [
            self._convert_to_deployment_config_metadata(config_name, config)
            for config_name, config in self._metadata_configs.items()
        ]

    return self._deployment_configs

def _convert_to_deployment_config_metadata(
    self, config_name: str, metadata_config: JumpStartMetadataConfig
) -> Dict[str, Any]:
    """Build the deployment config metadata dict for a single config.

    Args:
        config_name (str): Name of deployment config.
        metadata_config (JumpStartMetadataConfig): Metadata config for deployment config.
    Returns:
        Dict[str, Any]: A JSON-style deployment metadata config for config name.
    """
    default_inference_instance_type = metadata_config.resolved_config.get(
        "default_inference_instance_type"
    )

    instance_rate = get_instance_rate_per_hour(
        instance_type=default_inference_instance_type, region=self.region
    )

    # NOTE(review): assumes ``metadata_config.benchmark_metrics`` is always a
    # dict — confirm it cannot be None for configs without benchmark data.
    benchmark_metrics = metadata_config.benchmark_metrics.get(default_inference_instance_type)
    if instance_rate is not None:
        rate_stat = JumpStartBenchmarkStat(instance_rate)
        # Build a NEW list rather than appending in place: the original
        # ``benchmark_metrics.append(...)`` mutated the list owned by
        # ``metadata_config``, so repeated conversions accumulated duplicate
        # instance-rate entries in the shared metadata.
        if benchmark_metrics is not None:
            benchmark_metrics = list(benchmark_metrics) + [rate_stat]
        else:
            benchmark_metrics = [rate_stat]

    init_kwargs = get_init_kwargs(
        model_id=self.model_id,
        instance_type=default_inference_instance_type,
        sagemaker_session=self.sagemaker_session,
    )
    deploy_kwargs = get_deploy_kwargs(
        model_id=self.model_id,
        instance_type=default_inference_instance_type,
        sagemaker_session=self.sagemaker_session,
    )

    deployment_config_metadata = DeploymentConfigMetadata(
        config_name, benchmark_metrics, init_kwargs, deploy_kwargs
    )

    return deployment_config_metadata.to_json()

def __str__(self) -> str:
    """Return a human-readable string representation of this model."""
    # Delegate formatting to the shared helper used across the SDK.
    readable = stringify_object(self)
    return readable
90 changes: 90 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2154,3 +2154,93 @@ def __init__(
self.data_input_configuration = data_input_configuration
self.skip_model_validation = skip_model_validation
self.source_uri = source_uri


class BaseDeploymentConfigDataHolder(JumpStartDataHolderType):
    """Base class for Deployment Config Data.

    Provides JSON serialization that renames ``snake_case`` slot names to
    ``PascalCase`` keys and recursively serializes nested
    ``JumpStartDataHolderType`` values.
    """

    def _convert_to_pascal_case(self, attr_name: str) -> str:
        """Convert a snake_case attribute name into a PascalCase string."""
        # Docstring fixed: this produces PascalCase (e.g. "image_uri" ->
        # "ImageUri"), not camelCase as previously documented.
        return attr_name.replace("_", " ").title().replace(" ", "")

    def _serialize_value(self, cur_val: Any) -> Any:
        """Serialize one attribute value for ``to_json``.

        Extracted as a helper to keep ``to_json`` flat (the nested
        if/elif/for ladder was hard to read).
        """
        if issubclass(type(cur_val), JumpStartDataHolderType):
            return cur_val.to_json()
        if isinstance(cur_val, list):
            # Lists may mix data holders and plain values; only data holders
            # are converted.
            return [
                obj.to_json() if issubclass(type(obj), JumpStartDataHolderType) else obj
                for obj in cur_val
            ]
        if isinstance(cur_val, dict):
            # Keys are PascalCased only when their value is a data holder,
            # matching the original behavior.
            converted = {}
            for key, val in cur_val.items():
                if issubclass(type(val), JumpStartDataHolderType):
                    converted[self._convert_to_pascal_case(key)] = val.to_json()
                else:
                    converted[key] = val
            return converted
        return cur_val

    def to_json(self) -> Dict[str, Any]:
        """Represent this object as a JSON-serializable dict.

        Unset slots are skipped (``hasattr`` guard), so optional fields simply
        do not appear in the output.
        """
        json_obj = {}
        for att in self.__slots__:
            if hasattr(self, att):
                json_obj[self._convert_to_pascal_case(att)] = self._serialize_value(
                    getattr(self, att)
                )
        return json_obj


class DeploymentConfig(BaseDeploymentConfigDataHolder):
    """Dataclass representing a Deployment Config."""

    __slots__ = [
        "model_data_download_timeout",
        "container_startup_health_check_timeout",
        "image_uri",
        "model_data",
        "instance_type",
        "environment",
        "compute_resource_requirements",
    ]

    def __init__(
        self, init_kwargs: JumpStartModelInitKwargs, deploy_kwargs: JumpStartModelDeployKwargs
    ):
        """Instantiates DeploymentConfig object.

        Slots whose source kwargs object is absent are left unset; the base
        class ``to_json`` skips unset slots.
        """
        if init_kwargs is not None:
            self._copy_init_fields(init_kwargs)
        if deploy_kwargs is not None:
            self._copy_deploy_fields(deploy_kwargs)

    def _copy_init_fields(self, init_kwargs: JumpStartModelInitKwargs) -> None:
        """Copy model-construction fields from the init kwargs."""
        self.image_uri = init_kwargs.image_uri
        self.model_data = init_kwargs.model_data
        self.instance_type = init_kwargs.instance_type
        self.environment = init_kwargs.env
        resources = init_kwargs.resources
        if resources is not None:
            self.compute_resource_requirements = resources.get_compute_resource_requirements()

    def _copy_deploy_fields(self, deploy_kwargs: JumpStartModelDeployKwargs) -> None:
        """Copy deploy-time timeout fields from the deploy kwargs."""
        self.model_data_download_timeout = deploy_kwargs.model_data_download_timeout
        self.container_startup_health_check_timeout = (
            deploy_kwargs.container_startup_health_check_timeout
        )


class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder):
    """Dataclass representing a Deployment Config Metadata"""

    __slots__ = [
        "config_name",
        "benchmark_metrics",
        "deployment_config",
    ]

    def __init__(
        self,
        config_name: str,
        benchmark_metrics: List[JumpStartBenchmarkStat],
        init_kwargs: JumpStartModelInitKwargs,
        deploy_kwargs: JumpStartModelDeployKwargs,
    ):
        """Instantiates DeploymentConfigMetadata object.

        The init/deploy kwargs are not stored as-is; they are folded into a
        nested, serializable ``DeploymentConfig``.
        """
        self.deployment_config = DeploymentConfig(init_kwargs, deploy_kwargs)
        self.config_name = config_name
        self.benchmark_metrics = benchmark_metrics
40 changes: 40 additions & 0 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,3 +982,43 @@ def get_jumpstart_configs(
if metadata_configs
else {}
)


def extract_metrics_from_deployment_configs(
    deployment_configs: List[Dict[str, Any]], config_name: str
) -> Dict[str, List[Any]]:
    """Extracts metrics from deployment configs into columnar form.

    Args:
        deployment_configs (List[Dict[str, Any]]): List of deployment configs,
            as produced by ``JumpStartModel.list_deployment_configs``.
        config_name (str): The name of the deployment config used by the model;
            its row is marked "Yes" in the "Selected" column.
    Returns:
        Dict[str, List[Any]]: Column-name -> column-values mapping in which
        every column has the same length (safe to pass to ``pandas.DataFrame``).
    """
    fixed_columns = ("Config Name", "Instance Type", "Selected")
    data: Dict[str, List[Any]] = {col: [] for col in fixed_columns}

    for deployment_config in deployment_configs:
        inner_config = deployment_config.get("DeploymentConfig")
        benchmark_metrics = deployment_config.get("BenchmarkMetrics")
        # Skip entries with no deployment config or no benchmark data.
        if inner_config is None or benchmark_metrics is None:
            continue

        current_name = deployment_config.get("ConfigName")
        data["Config Name"].append(current_name)
        data["Instance Type"].append(inner_config.get("InstanceType"))
        data["Selected"].append(
            "Yes" if (config_name is not None and config_name == current_name) else "No"
        )

        row_metrics = {
            f"{metric.get('name')} ({metric.get('unit')})": metric.get("value")
            for metric in benchmark_metrics
        }

        # Metric columns are created lazily and back-filled with None so all
        # columns stay equal-length. The original only created columns when
        # ``index == 0``; if the first entry was skipped (or configs reported
        # different metric sets) the dict became ragged and DataFrame
        # construction raised ValueError.
        rows_so_far = len(data["Config Name"])
        for column_name in row_metrics:
            if column_name not in data:
                data[column_name] = [None] * (rows_so_far - 1)
        for column_name in data:
            if column_name in fixed_columns:
                continue
            data[column_name].append(row_metrics.get(column_name))

    return data
13 changes: 12 additions & 1 deletion src/sagemaker/serve/builder/jumpstart_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import copy
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from typing import Type
from typing import Type, Any
import logging

from sagemaker.model import Model
Expand Down Expand Up @@ -431,6 +431,17 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration
)

def display_benchmark_metrics(self):
    """Display Markdown Benchmark Metrics for deployment configs."""
    # Pure delegation to the wrapped JumpStart pysdk model.
    model = self.pysdk_model
    model.display_benchmark_metrics()

def list_deployment_configs(self) -> list[dict[str, Any]]:
    """List deployment configs for the wrapped model in the current region.

    Returns:
        A list of deployment configs (List[Dict[str, Any]]).
    """
    # Pure delegation to the wrapped JumpStart pysdk model.
    model = self.pysdk_model
    return model.list_deployment_configs()

def _build_for_jumpstart(self):
"""Placeholder docstring"""
# we do not pickle for jumpstart. set to none
Expand Down
53 changes: 53 additions & 0 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from os.path import abspath, realpath, dirname, normpath, join as joinpath

from importlib import import_module

import boto3
import botocore
from botocore.utils import merge_dicts
from six.moves.urllib import parse
Expand Down Expand Up @@ -1655,3 +1657,54 @@ def deep_override_dict(
)
flattened_dict1.update(flattened_dict2)
return unflatten_dict(flattened_dict1) if flattened_dict1 else {}


def get_instance_rate_per_hour(
    instance_type: str,
    region: str,
    pricing_client=None,
) -> Union[Dict[str, str], None]:
    """Gets instance rate per hour for the given instance type.

    Args:
        instance_type (str): The instance type, e.g. ``ml.g5.xlarge``.
        region (str): The region code the instance runs in.
        pricing_client (Optional[boto3.client]): The pricing client. When not
            supplied, a ``pricing`` client pinned to us-east-1 is created
            lazily. (The previous version built the client as a default
            argument, which ran credential/endpoint setup at module import
            time and shared a single client across every caller.)
    Returns:
        Union[Dict[str, str], None]: Instance rate per hour, or None when the
        rate cannot be determined.
        Example: {'name': 'Instance Rate', 'unit': 'USD/Hrs', 'value': '1.1250000000'}.
    """
    try:
        if pricing_client is None:
            # The AWS Price List ("pricing") API is only served from us-east-1,
            # regardless of the region being priced.
            pricing_client = boto3.client("pricing", region_name="us-east-1")

        res = pricing_client.get_products(
            ServiceCode="AmazonSageMaker",
            Filters=[
                {"Type": "TERM_MATCH", "Field": "instanceName", "Value": instance_type},
                {"Type": "TERM_MATCH", "Field": "locationType", "Value": "AWS Region"},
                {"Type": "TERM_MATCH", "Field": "regionCode", "Value": region},
            ],
        )

        price_list = res.get("PriceList", [])
        if not price_list:
            return None

        # PriceList entries may arrive as JSON strings; normalize to a dict.
        price_data = price_list[0]
        if isinstance(price_data, str):
            price_data = json.loads(price_data)

        # Return the first currency of the first price dimension of the first
        # OnDemand term (same selection as the original triple-break ladder).
        for dimension in price_data.get("terms", {}).get("OnDemand", {}).values():
            for price in dimension.get("priceDimensions", {}).values():
                for currency, value in price.get("pricePerUnit", {}).items():
                    return {
                        "unit": f"{currency}/{price.get('unit', 'Hrs')}",
                        "value": value,
                        "name": "Instance Rate",
                    }
    except Exception as e:  # pylint: disable=W0703
        # Best-effort lookup: pricing failures are logged, never raised.
        logging.exception("Error getting instance rate: %s", e)

    return None
Loading