diff --git a/src/usr/lib/python3/dist-packages/ztp/Downloader.py b/src/usr/lib/python3/dist-packages/ztp/Downloader.py index 7f3348d..8f16c22 100644 --- a/src/usr/lib/python3/dist-packages/ztp/Downloader.py +++ b/src/usr/lib/python3/dist-packages/ztp/Downloader.py @@ -17,6 +17,7 @@ import sys import os +import shlex import stat import time @@ -204,22 +205,26 @@ def getUrl(self, url=None, dst_file=None, incl_http_headers=None, is_secure=True return (20, None) # Create curl command - cmd = '/usr/bin/curl -f -v -s -o ' + dst_file + cmd = ['/usr/bin/curl', '-f', '-v', '-s', '-o', dst_file] if self.__user_agent is not None: - cmd += ' -A "' + self.__user_agent + '"' # --user-agent + cmd += ['-A', self.__user_agent] # --user-agent if is_secure is False: - cmd += ' -k' # --insecure + cmd += ['-k'] # --insecure if timeout is not None and isinstance(timeout, int) is True: - cmd += ' --connect-timeout ' + str(timeout) + cmd += ['--connect-timeout', str(timeout)] if retry is not None and isinstance(retry, int) is True: - cmd += ' --retry ' + str(retry) + cmd += ['--retry', str(retry)] if incl_http_headers is not None: for h in self.__http_headers: - cmd += ' -H \"' + h + '"' # --header + cmd += ['-H', h] # --header if curl_args is not None: - cmd += ' ' + curl_args - cmd += ' ' + url + try: + cmd += shlex.split(curl_args) + except ValueError as e: + logger.error('Invalid curl_args value: %s' % str(e)) + return (1, None) + cmd += ['--', url] if verbose is True: logger.debug('%s' % (cmd)) diff --git a/src/usr/lib/ztp/dhcp/ztp b/src/usr/lib/ztp/dhcp/ztp index b2d452c..54f2893 100755 --- a/src/usr/lib/ztp/dhcp/ztp +++ b/src/usr/lib/ztp/dhcp/ztp @@ -103,29 +103,29 @@ fi if [ "$(ztp status -c)" != "0:DISABLED" ]; then if [ -n "$new_bootfile_name" ]; then - take_lock dhcp && echo $new_bootfile_name > $ZTP_JSON_URL_FILE + take_lock dhcp && echo "$new_bootfile_name" > $ZTP_JSON_URL_FILE if [ -n "$new_tftp_server_name" ]; then - take_lock dhcp && echo $new_tftp_server_name > $ZTP_TFTP_SERVER_FILE + take_lock dhcp && printf '%s\n' "$new_tftp_server_name" > $ZTP_TFTP_SERVER_FILE fi fi if [ -n "$new_dhcp6_boot_file_url" ]; then - take_lock dhcp6 && echo $new_dhcp6_boot_file_url > $ZTP_JSON_URL6_FILE + take_lock dhcp6 && printf '%s\n' "$new_dhcp6_boot_file_url" > $ZTP_JSON_URL6_FILE fi if [ -n "$new_provisioning_script_url" ]; then - take_lock dhcp && echo $new_provisioning_script_url > $PROVISIONING_SCRIPT_URL_FILE + take_lock dhcp && printf '%s\n' "$new_provisioning_script_url" > $PROVISIONING_SCRIPT_URL_FILE fi if [ -n "$new_dhcp6_provisioning_script_url" ]; then - take_lock dhcp6 && echo $new_dhcp6_provisioning_script_url > $PROVISIONING_SCRIPT_URL6_FILE + take_lock dhcp6 && printf '%s\n' "$new_dhcp6_provisioning_script_url" > $PROVISIONING_SCRIPT_URL6_FILE fi if [ -n "$new_minigraph_url" ]; then - take_lock dhcp && echo $new_minigraph_url > ${GRAPH_URL} + take_lock dhcp && printf '%s\n' "$new_minigraph_url" > ${GRAPH_URL} if [ -n "$new_acl_url" ]; then - take_lock dhcp && echo $new_acl_url > ${ACL_URL} + take_lock dhcp && printf '%s\n' "$new_acl_url" > ${ACL_URL} fi fi fi diff --git a/src/usr/lib/ztp/ztp-engine.py b/src/usr/lib/ztp/ztp-engine.py index f1e7bad..0dc9166 100755 --- a/src/usr/lib/ztp/ztp-engine.py +++ b/src/usr/lib/ztp/ztp-engine.py @@ -696,6 +696,10 @@ def __downloadURL(self, url_file, dst_file, url_prefix=None): url_str = f.readline().strip() f.close() + if ' ' in url_str or '\t' in url_str: + logger.error('Failed to download provided URL %s, URL contains whitespace.' % (url_str)) + return False + res = urlparse(url_str) if res is None or res.scheme == '': # Use passed url_prefix to construct final URL diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..ccfa8c8 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,52 @@ +""" +conftest.py — test environment setup for sonic-ztp unit tests. + +Creates temporary directories to replace SONiC-specific system paths +(/host/ztp, /etc/rsyslog.d) so tests can run on a bare host without +a SONiC container. +""" + +import os +import sys +import tempfile +import pytest + +# --------------------------------------------------------------------------- +# Bootstrap: patch system paths BEFORE any ztp module is imported. +# ZTPCfg.py and Logger.py run module-level code at import time that touches +# /host/ztp and /etc/rsyslog.d, so the patches must be in place first. +# --------------------------------------------------------------------------- + +# Create a single persistent temp dir for the whole test session. +_tmp_root = tempfile.mkdtemp(prefix="ztp_test_") +_fake_host_ztp = os.path.join(_tmp_root, "host", "ztp") +_fake_rsyslog_d = os.path.join(_tmp_root, "etc", "rsyslog.d") +_fake_sonic_dir = os.path.join(_tmp_root, "etc", "sonic") + +os.makedirs(_fake_host_ztp, exist_ok=True) +os.makedirs(_fake_rsyslog_d, exist_ok=True) +os.makedirs(_fake_sonic_dir, exist_ok=True) + +# Add ztp package to path so `from ztp.X import Y` works. +_ztp_pkg_dir = os.path.join( + os.path.dirname(__file__), + "..", "src", "usr", "lib", "python3", "dist-packages" +) +if _ztp_pkg_dir not in sys.path: + sys.path.insert(0, os.path.abspath(_ztp_pkg_dir)) + +# Patch defaults BEFORE importing any ztp submodule. +import ztp.defaults as _defaults + +_defaults.cfg_file = os.path.join(_fake_host_ztp, "ztp_cfg.json") +_defaults.defaultCfg["ztp-cfg-dir"] = _fake_host_ztp +_defaults.defaultCfg["ztp-json"] = os.path.join(_fake_host_ztp, "ztp_data.json") +_defaults.defaultCfg["ztp-json-shadow"] = os.path.join(_fake_host_ztp, "ztp_data_shadow.json") +_defaults.defaultCfg["ztp-json-local"] = os.path.join(_fake_host_ztp, "ztp_data_local.json") +_defaults.defaultCfg["provisioning-script"] = os.path.join(_fake_host_ztp, "provisioning-script") +_defaults.defaultCfg["rsyslog-ztp-log-file-conf"] = os.path.join(_fake_rsyslog_d, "10-ztp-log-file.conf") +_defaults.defaultCfg["rsyslog-ztp-consile-log-file-conf"] = os.path.join(_fake_rsyslog_d, "10-ztp-console-logging.conf") +_defaults.defaultCfg["log-file"] = os.path.join(_tmp_root, "ztp.log") +_defaults.defaultCfg["ztp-tmp"] = os.path.join(_tmp_root, "tmp") + +os.makedirs(_defaults.defaultCfg["ztp-tmp"], exist_ok=True) diff --git a/tests/test_Downloader_input_validation.py b/tests/test_Downloader_input_validation.py new file mode 100644 index 0000000..3ce0448 --- /dev/null +++ b/tests/test_Downloader_input_validation.py @@ -0,0 +1,186 @@ +''' +Input validation tests for curl command construction in Downloader.py + +Tests verify that url, dst_file, and curl_args are handled correctly +when DHCP-supplied values are used. runCommand is mocked so no real curl +or network is required. +''' + +import os +import sys +import pytest +from unittest.mock import patch, MagicMock + +from ztp.Downloader import Downloader + + +# --------------------------------------------------------------------------- +# Helper +# --------------------------------------------------------------------------- + +def _make_downloader(**kwargs): + """Return a Downloader with safe defaults for unit testing.""" + return Downloader( + is_secure=False, + timeout=30, + retry=0, + incl_http_headers=False, + **kwargs, + ) + + +def _capture_cmd(tmp_path, url, **dl_kwargs): + """ + Call Downloader.getUrl() with a mocked runCommand that records the argv + list it receives. Returns the captured cmd list. + """ + dst = str(tmp_path / "out.txt") + captured = {} + + def fake_run(cmd, **kwargs): + captured['cmd'] = list(cmd) + # Simulate a successful curl: create the output file. + open(dst, 'w').close() + return (0, [], []) + + dn = _make_downloader(**dl_kwargs) + with patch('ztp.Downloader.runCommand', side_effect=fake_run): + rc, fname = dn.getUrl(url, dst_file=dst) + + return captured.get('cmd', []) + + +# --------------------------------------------------------------------------- +# URL handling tests +# --------------------------------------------------------------------------- + +class TestUrlHandling: + """Verify that spaces/flags in the URL do not become extra curl arguments.""" + + def test_url_with_space_is_single_arg(self, tmp_path): + """A URL containing a space must not split into multiple curl tokens.""" + url = 'http://example.com/file --output /tmp/evil' + cmd = _capture_cmd(tmp_path, url) + # The injected flag must not appear as a standalone token. + assert '--output' not in cmd + # The full URL string must appear as one argument (after '--'). + assert url in cmd + + def test_url_with_flag_after_dashdash(self, tmp_path): + """'--' must appear before the URL so curl treats it as a positional arg.""" + url = 'http://example.com/file' + cmd = _capture_cmd(tmp_path, url) + assert '--' in cmd + dashdash_idx = cmd.index('--') + url_idx = cmd.index(url) + assert url_idx == dashdash_idx + 1, "'--' must immediately precede the URL" + + def test_url_with_leading_dash(self, tmp_path): + """A URL that looks like a flag (leading dash) must be safe. + + The URL string '-o /tmp/evil http://example.com/' must appear as a + single positional argument after '--', not be split so that '-o' + injects an extra output path. + """ + url = '-o /tmp/evil http://example.com/' + cmd = _capture_cmd(tmp_path, url) + # '--' must exist and the full URL string must appear after it. + assert '--' in cmd + dashdash_idx = cmd.index('--') + url_idx = cmd.index(url) + assert url_idx > dashdash_idx, \ + f"URL must appear after '--', but cmd={cmd}" + # '/tmp/evil' must NOT appear as a standalone token (injected path). + assert '/tmp/evil' not in cmd, \ + f"Injected path '/tmp/evil' appeared as a standalone token: {cmd}" + + def test_url_with_config_flag(self, tmp_path): + """'--config' in URL must not reach curl as a real flag.""" + url = 'http://example.com/ --config /tmp/evil.conf' + cmd = _capture_cmd(tmp_path, url) + assert '--config' not in cmd[:cmd.index('--')] + assert url in cmd + + +# --------------------------------------------------------------------------- +# curl_args handling tests +# --------------------------------------------------------------------------- + +class TestCurlArgsHandling: + """Verify that curl_args is split safely and merged into the argv list.""" + + def test_curl_args_legitimate(self, tmp_path): + """Legitimate curl_args like '--max-time 5' must be forwarded.""" + url = 'http://example.com/file' + cmd = _capture_cmd(tmp_path, url, curl_args='--max-time 5') + assert '--max-time' in cmd + assert '5' in cmd + + def test_curl_args_multiple_flags(self, tmp_path): + """Multiple legitimate curl_args must all appear.""" + url = 'http://example.com/file' + cmd = _capture_cmd(tmp_path, url, curl_args='--compressed --max-time 10') + assert '--compressed' in cmd + assert '--max-time' in cmd + assert '10' in cmd + + def test_curl_args_does_not_override_url(self, tmp_path): + """curl_args must be inserted before '--', not after.""" + url = 'http://example.com/file' + cmd = _capture_cmd(tmp_path, url, curl_args='--max-time 5') + dashdash_idx = cmd.index('--') + url_idx = cmd.index(url) + # '--' must still immediately precede url + assert url_idx == dashdash_idx + 1 + + def test_cmd_is_list_not_string(self, tmp_path): + """runCommand must receive a list, never a string.""" + url = 'http://example.com/file' + dst = str(tmp_path / "out.txt") + captured = {} + + def fake_run(cmd, **kwargs): + captured['type'] = type(cmd) + captured['cmd'] = cmd + open(dst, 'w').close() + return (0, [], []) + + dn = _make_downloader() + with patch('ztp.Downloader.runCommand', side_effect=fake_run): + dn.getUrl(url, dst_file=dst) + + assert captured['type'] is list, \ + f"runCommand must receive a list, got {captured['type']}" + + +# --------------------------------------------------------------------------- +# dst_file handling test +# --------------------------------------------------------------------------- + +class TestDstFileHandling: + """Verify that dst_file value is treated as a single path, not shell tokens.""" + + def test_dst_file_with_space_is_single_arg(self, tmp_path): + """dst_file with a space must be passed as one -o argument.""" + # Use a path with a space in the directory name + spacedir = tmp_path / "my dir" + spacedir.mkdir() + dst = str(spacedir / "out.txt") + url = 'http://example.com/file' + captured = {} + + def fake_run(cmd, **kwargs): + captured['cmd'] = list(cmd) + open(dst, 'w').close() + return (0, [], []) + + dn = _make_downloader() + with patch('ztp.Downloader.runCommand', side_effect=fake_run): + rc, fname = dn.getUrl(url, dst_file=dst) + + cmd = captured.get('cmd', []) + # '-o' must be followed by the full dst path as one token + assert '-o' in cmd + o_idx = cmd.index('-o') + assert cmd[o_idx + 1] == dst, \ + f"Expected dst as single token after -o, got: {cmd[o_idx+1:]}"