diff --git a/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py b/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py
index 92857434..936345be 100644
--- a/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py
+++ b/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py
@@ -15,7 +15,7 @@
 import os
 import textwrap
 
-import torch
+import torch, torcheia
 from sagemaker_inference import (
     content_types,
     decoder,
@@ -28,6 +28,9 @@
 INFERENCE_ACCELERATOR_PRESENT_ENV = "SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT"
 DEFAULT_MODEL_FILENAME = "model.pt"
 
+torch._C._jit_set_profiling_executor(False)
+device = torch.device("cpu")
+
 
 class DefaultPytorchInferenceHandler(default_inference_handler.DefaultInferenceHandler):
     VALID_CONTENT_TYPES = (content_types.JSON, content_types.NPY)
@@ -47,7 +50,11 @@ def default_model_fn(self, model_dir):
                 raise FileNotFoundError("Failed to load model with default model_fn: missing file {}."
                                         .format(DEFAULT_MODEL_FILENAME))
             # Client-framework is CPU only. But model will run in Elastic Inference server with CUDA.
-            return torch.jit.load(model_path, map_location=torch.device('cpu'))
+            model = torch.jit.load(model_path, map_location=torch.device('cpu'))
+            # attach_eia() is introduced in PyTorch Elastic Inference 1.5.1
+            # by default attach to the 0th device
+            model = torcheia.jit.attach_eia(model, 0)
+            return model
         else:
             raise NotImplementedError(textwrap.dedent("""
             Please provide a model_fn implementation.
@@ -86,8 +93,8 @@ def default_predict_fn(self, data, model):
             model = model.to(device)
             input_data = data.to(device)
             model.eval()
-            with torch.jit.optimized_execution(True, {"target_device": "eia:0"}):
-                output = model(input_data)
+            with torch.jit.optimized_execution(True):
+                output = model.forward(input_data)
         else:
             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
             model = model.to(device)
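
The sketch below walks through the new Elastic Inference code path end to end, outside the handler class, to make the control flow of the diff easier to review. It is a minimal sketch only: it assumes an EI-enabled container where the `torcheia` package is installed, an accelerator attached as device 0, and a TorchScript artifact saved at `model.pt` (a hypothetical path); the input tensor shape is likewise an arbitrary example.

```python
import os

import torch
import torcheia

# The handler only takes the EI branch when this variable is "true".
os.environ["SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT"] = "true"

# Mirrors the module-level setup added by the diff: disable the
# profiling executor before any TorchScript model is run.
torch._C._jit_set_profiling_executor(False)
device = torch.device("cpu")

# default_model_fn equivalent: load the scripted model on CPU,
# then attach it to the 0th Elastic Inference accelerator.
model = torch.jit.load("model.pt", map_location=device)
model = torcheia.jit.attach_eia(model, 0)

# default_predict_fn equivalent: keep tensors on CPU locally; the
# forward pass is executed remotely on the EI accelerator.
model.eval()
input_data = torch.rand(1, 3, 224, 224).to(device)  # example input (assumption)
with torch.jit.optimized_execution(True):
    output = model.forward(input_data)
```

Note the design change the diff makes in `default_predict_fn`: the `{"target_device": "eia:0"}` argument to `torch.jit.optimized_execution` is dropped because the target device is now fixed when the model is attached via `torcheia.jit.attach_eia` in `default_model_fn`.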