Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/sagemaker_pytorch_serving_container/torchserve.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

PYTHON_PATH_ENV = "PYTHONPATH"
REQUIREMENTS_PATH = os.path.join(code_dir, "requirements.txt")
LOG4J_OVERRIDE_PATH = os.path.join(code_dir, "log4j.properties")
TS_NAMESPACE = "org.pytorch.serve.ModelServer"


Expand Down Expand Up @@ -81,6 +82,11 @@ def start_torchserve(handler_service=DEFAULT_HANDLER_SERVICE):
if os.path.exists(REQUIREMENTS_PATH):
_install_requirements()

if os.path.exists(LOG4J_OVERRIDE_PATH):
log4j_path = LOG4J_OVERRIDE_PATH
else:
log4j_path = DEFAULT_TS_LOG_FILE

ts_torchserve_cmd = [
"torchserve",
"--start",
Expand All @@ -89,7 +95,7 @@ def start_torchserve(handler_service=DEFAULT_HANDLER_SERVICE):
"--ts-config",
TS_CONFIG_FILE,
"--log-config",
DEFAULT_TS_LOG_FILE,
log4j_path,
"--models",
"model.mar"
]
Expand Down
14 changes: 9 additions & 5 deletions test/unit/test_model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@

from sagemaker_inference import environment
from sagemaker_pytorch_serving_container import torchserve
from sagemaker_pytorch_serving_container.torchserve import TS_NAMESPACE, REQUIREMENTS_PATH
from sagemaker_pytorch_serving_container.torchserve import (
TS_NAMESPACE, REQUIREMENTS_PATH, LOG4J_OVERRIDE_PATH
)

PYTHON_PATH = "python_path"
DEFAULT_CONFIGURATION = "default_configuration"
Expand All @@ -32,7 +34,7 @@
@patch("sagemaker_pytorch_serving_container.torchserve._retrieve_ts_server_process")
@patch("sagemaker_pytorch_serving_container.torchserve._add_sigterm_handler")
@patch("sagemaker_pytorch_serving_container.torchserve._install_requirements")
@patch("os.path.exists", return_value=True)
@patch("os.path.exists", side_effect=[True, False])
@patch("sagemaker_pytorch_serving_container.torchserve._create_torchserve_config_file")
@patch("sagemaker_pytorch_serving_container.torchserve._adapt_to_ts_format")
def test_start_torchserve_default_service_handler(
Expand All @@ -49,7 +51,8 @@ def test_start_torchserve_default_service_handler(

adapt.assert_called_once_with(torchserve.DEFAULT_HANDLER_SERVICE)
create_config.assert_called_once_with()
exists.assert_called_once_with(REQUIREMENTS_PATH)
exists.assert_any_call(REQUIREMENTS_PATH)
exists.assert_any_call(LOG4J_OVERRIDE_PATH)
install_requirements.assert_called_once_with()

ts_model_server_cmd = [
Expand All @@ -74,7 +77,7 @@ def test_start_torchserve_default_service_handler(
@patch("sagemaker_pytorch_serving_container.torchserve._retrieve_ts_server_process")
@patch("sagemaker_pytorch_serving_container.torchserve._add_sigterm_handler")
@patch("sagemaker_pytorch_serving_container.torchserve._install_requirements")
@patch("os.path.exists", return_value=True)
@patch("os.path.exists", side_effect=[True, False])
@patch("sagemaker_pytorch_serving_container.torchserve._create_torchserve_config_file")
@patch("sagemaker_pytorch_serving_container.torchserve._adapt_to_ts_format")
def test_start_torchserve_default_service_handler_multi_model(
Expand All @@ -91,7 +94,8 @@ def test_start_torchserve_default_service_handler_multi_model(
torchserve.start_torchserve()
torchserve.ENABLE_MULTI_MODEL = False
create_config.assert_called_once_with()
exists.assert_called_once_with(REQUIREMENTS_PATH)
exists.assert_any_call(REQUIREMENTS_PATH)
exists.assert_any_call(LOG4J_OVERRIDE_PATH)
install_requirements.assert_called_once_with()

ts_model_server_cmd = [
Expand Down