Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/sagemaker_pytorch_serving_container/torchserve.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

PYTHON_PATH_ENV = "PYTHONPATH"
REQUIREMENTS_PATH = os.path.join(code_dir, "requirements.txt")
LOG4J_OVERRIDE_PATH = os.path.join(code_dir, "log4j.xml")
TS_NAMESPACE = "org.pytorch.serve.ModelServer"


Expand Down Expand Up @@ -81,6 +82,11 @@ def start_torchserve(handler_service=DEFAULT_HANDLER_SERVICE):
if os.path.exists(REQUIREMENTS_PATH):
_install_requirements()

if os.path.exists(LOG4J_OVERRIDE_PATH):
log4j_path = LOG4J_OVERRIDE_PATH
else:
log4j_path = DEFAULT_TS_LOG_FILE

ts_torchserve_cmd = [
"torchserve",
"--start",
Expand All @@ -89,7 +95,7 @@ def start_torchserve(handler_service=DEFAULT_HANDLER_SERVICE):
"--ts-config",
TS_CONFIG_FILE,
"--log-config",
DEFAULT_TS_LOG_FILE,
log4j_path,
"--models",
"model.mar"
]
Expand Down
14 changes: 9 additions & 5 deletions test/unit/test_model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@

from sagemaker_inference import environment
from sagemaker_pytorch_serving_container import torchserve
from sagemaker_pytorch_serving_container.torchserve import TS_NAMESPACE, REQUIREMENTS_PATH
from sagemaker_pytorch_serving_container.torchserve import (
TS_NAMESPACE, REQUIREMENTS_PATH, LOG4J_OVERRIDE_PATH
)

PYTHON_PATH = "python_path"
DEFAULT_CONFIGURATION = "default_configuration"
Expand All @@ -32,7 +34,7 @@
@patch("sagemaker_pytorch_serving_container.torchserve._retrieve_ts_server_process")
@patch("sagemaker_pytorch_serving_container.torchserve._add_sigterm_handler")
@patch("sagemaker_pytorch_serving_container.torchserve._install_requirements")
@patch("os.path.exists", return_value=True)
@patch("os.path.exists", side_effect=[True, False])
@patch("sagemaker_pytorch_serving_container.torchserve._create_torchserve_config_file")
@patch("sagemaker_pytorch_serving_container.torchserve._adapt_to_ts_format")
def test_start_torchserve_default_service_handler(
Expand All @@ -49,7 +51,8 @@ def test_start_torchserve_default_service_handler(

adapt.assert_called_once_with(torchserve.DEFAULT_HANDLER_SERVICE)
create_config.assert_called_once_with()
exists.assert_called_once_with(REQUIREMENTS_PATH)
exists.assert_any_call(REQUIREMENTS_PATH)
exists.assert_any_call(LOG4J_OVERRIDE_PATH)
install_requirements.assert_called_once_with()

ts_model_server_cmd = [
Expand All @@ -74,7 +77,7 @@ def test_start_torchserve_default_service_handler(
@patch("sagemaker_pytorch_serving_container.torchserve._retrieve_ts_server_process")
@patch("sagemaker_pytorch_serving_container.torchserve._add_sigterm_handler")
@patch("sagemaker_pytorch_serving_container.torchserve._install_requirements")
@patch("os.path.exists", return_value=True)
@patch("os.path.exists", side_effect=[True, False])
@patch("sagemaker_pytorch_serving_container.torchserve._create_torchserve_config_file")
@patch("sagemaker_pytorch_serving_container.torchserve._adapt_to_ts_format")
def test_start_torchserve_default_service_handler_multi_model(
Expand All @@ -91,7 +94,8 @@ def test_start_torchserve_default_service_handler_multi_model(
torchserve.start_torchserve()
torchserve.ENABLE_MULTI_MODEL = False
create_config.assert_called_once_with()
exists.assert_called_once_with(REQUIREMENTS_PATH)
exists.assert_any_call(REQUIREMENTS_PATH)
exists.assert_any_call(LOG4J_OVERRIDE_PATH)
install_requirements.assert_called_once_with()

ts_model_server_cmd = [
Expand Down