From b3996563cc8b62f865a1a8f53b63823a10a7bf45 Mon Sep 17 00:00:00 2001 From: aubin bikouo Date: Mon, 10 Mar 2025 18:14:41 +0100 Subject: [PATCH 1/5] connection/aws_ssm - use port forwarding session combined with `nc` for file transfer --- ...s_ssm-support-netcat-for-file-transfer.yml | 3 + plugins/connection/aws_ssm.py | 81 ++++- plugins/plugin_utils/ssm_file_transfer.py | 307 ++++++++++++++++++ .../connection_aws_ssm_fedora_netcat/aliases | 4 + .../aws_ssm_integration_test_setup.yml | 7 + .../aws_ssm_integration_test_teardown.yml | 5 + .../meta/main.yml | 2 + .../connection_aws_ssm_fedora_netcat/runme.sh | 27 ++ .../connection_aws_ssm_fedora_netcat/test.yml | 45 +++ .../defaults/main.yml | 4 +- .../setup_connection_aws_ssm/tasks/main.yml | 37 +-- .../tasks/ssm_document.yml | 13 +- .../templates/inventory-combined.aws_ssm.j2 | 5 + .../connection/aws_ssm/test_aws_ssm.py | 6 + 14 files changed, 502 insertions(+), 44 deletions(-) create mode 100644 changelogs/fragments/20250204-aws_ssm-support-netcat-for-file-transfer.yml create mode 100644 plugins/plugin_utils/ssm_file_transfer.py create mode 100644 tests/integration/targets/connection_aws_ssm_fedora_netcat/aliases create mode 100644 tests/integration/targets/connection_aws_ssm_fedora_netcat/aws_ssm_integration_test_setup.yml create mode 100644 tests/integration/targets/connection_aws_ssm_fedora_netcat/aws_ssm_integration_test_teardown.yml create mode 100644 tests/integration/targets/connection_aws_ssm_fedora_netcat/meta/main.yml create mode 100755 tests/integration/targets/connection_aws_ssm_fedora_netcat/runme.sh create mode 100644 tests/integration/targets/connection_aws_ssm_fedora_netcat/test.yml diff --git a/changelogs/fragments/20250204-aws_ssm-support-netcat-for-file-transfer.yml b/changelogs/fragments/20250204-aws_ssm-support-netcat-for-file-transfer.yml new file mode 100644 index 00000000000..cdcf3f470f2 --- /dev/null +++ b/changelogs/fragments/20250204-aws_ssm-support-netcat-for-file-transfer.yml @@ -0,0 
+1,3 @@ +--- +minor_changes: + - aws_ssm - Add for file transfer using SSM port forwarding session with netcat for Linux/MacOS EC2 managed nodes (https://github.com/ansible-collections/community.aws/pull/2265). diff --git a/plugins/connection/aws_ssm.py b/plugins/connection/aws_ssm.py index f39e480d3f0..5bf28340098 100644 --- a/plugins/connection/aws_ssm.py +++ b/plugins/connection/aws_ssm.py @@ -86,6 +86,29 @@ vars: - name: ansible_aws_ssm_bucket_endpoint_url version_added: 5.3.0 + host_port_number: + description: + - The Port number of the server on the instance when using Port Forwarding Using AWS System Manager Session Manager + to transfer files from/to local host to/from remote host. + - The port V(80) is used if not provided. + - The C(nc) command should be installed in the remote host to use this option. + - This is not supported for Windows hosts for now. + type: integer + default: 80 + vars: + - name: ansible_aws_ssm_host_port_number + version_added: 10.0.0 + local_port_number: + description: + - Port number on local machine to forward traffic to when using Port Forwarding Using AWS System Manager Session Manager + to transfer files from/to local host to/from remote host. + - An open port is chosen at run-time if not provided. + - The C(nc) command should be installed in the remote host to use this option. + - This is not supported for Windows hosts for now. + type: integer + vars: + - name: ansible_aws_ssm_local_port_number + version_added: 10.0.0 plugin: description: - This defines the location of the session-manager-plugin binary. 
@@ -360,6 +383,7 @@ from ansible.utils.display import Display from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3 +from ansible_collections.community.aws.plugins.plugin_utils.ssm_file_transfer import PortForwardingFileTransferManager from ansible_collections.community.aws.plugins.plugin_utils.s3clientmanager import S3ClientManager @@ -472,6 +496,7 @@ class Connection(ConnectionBase): _stdout = None _session_id = "" _timeout = False + _filetransfer_mgr = None MARK_LENGTH = 26 def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -516,19 +541,39 @@ def _init_clients(self) -> None: profile_name = self.get_option("profile") or "" region_name = self.get_option("region") - # Initialize S3ClientManager - self.s3_manager = S3ClientManager(self) - - # Initialize S3 client - s3_endpoint_url, s3_region_name = self.s3_manager.get_bucket_endpoint() - self.verbosity_display(4, f"SETUP BOTO3 CLIENTS: S3 {s3_endpoint_url}") - self.s3_manager.initialize_client( - region_name=s3_region_name, endpoint_url=s3_endpoint_url, profile_name=profile_name - ) - self._s3_client = self.s3_manager._s3_client - # Initialize SSM client self._initialize_ssm_client(region_name, profile_name) + if self._use_bucket(): + # Initialize S3ClientManager + self.s3_manager = S3ClientManager(self) + + # Initialize S3 client + s3_endpoint_url, s3_region_name = self.s3_manager.get_bucket_endpoint() + self.verbosity_display(4, f"SETUP BOTO3 CLIENTS: S3 {s3_endpoint_url}") + self.s3_manager.initialize_client( + region_name=s3_region_name, endpoint_url=s3_endpoint_url, profile_name=profile_name + ) + self._s3_client = self.s3_manager._s3_client + else: + self._initialize_file_transfer_manager() + + def _initialize_file_transfer_manager(self) -> None: + ssm_timeout = self.get_option("ssm_timeout") + region_name = self.get_option("region") + profile_name = self.get_option("profile") or "" + host_port = self.get_option("host_port_number") + local_port = 
self.get_option("local_port_number") + self._filetransfer_mgr = PortForwardingFileTransferManager( + self.host, + ssm_client=self._client, + instance_id=self.instance_id, + executable=self.get_executable(), + ssm_timeout=ssm_timeout, + region_name=region_name, + profile_name=profile_name, + host_port=host_port, + local_port=local_port, + ) def _initialize_ssm_client(self, region_name: Optional[str], profile_name: str) -> None: """ @@ -574,6 +619,10 @@ def reset(self) -> Any: self.close() return self.start_session() + def _use_bucket(self) -> bool: + """return true if the file transfer is performed using s3 bucket""" + return self.is_windows or self.get_option("bucket_name") + @property def instance_id(self) -> str: if not self._instance_id: @@ -1090,7 +1139,10 @@ def put_file(self, in_path: str, out_path: str) -> Tuple[int, str, str]: if not os.path.exists(to_bytes(in_path, errors="surrogate_or_strict")): raise AnsibleFileNotFound(f"file or module does not exist: {in_path}") - return self._file_transport_command(in_path, out_path, "put") + if self._use_bucket(): + return self._file_transport_command(in_path, out_path, "put") + else: + return self._filetransfer_mgr.put_file(in_path, out_path) def fetch_file(self, in_path: str, out_path: str) -> Tuple[int, str, str]: """fetch a file from remote to local""" @@ -1098,7 +1150,10 @@ def fetch_file(self, in_path: str, out_path: str) -> Tuple[int, str, str]: super().fetch_file(in_path, out_path) self.verbosity_display(3, f"FETCH {in_path} TO {out_path}") - return self._file_transport_command(in_path, out_path, "get") + if self._use_bucket(): + return self._file_transport_command(in_path, out_path, "get") + else: + return self._filetransfer_mgr.fetch_file(in_path, out_path) def close(self) -> None: """terminate the connection""" diff --git a/plugins/plugin_utils/ssm_file_transfer.py b/plugins/plugin_utils/ssm_file_transfer.py new file mode 100644 index 00000000000..0dc30eb16d7 --- /dev/null +++ 
b/plugins/plugin_utils/ssm_file_transfer.py @@ -0,0 +1,307 @@ +# -*- coding: utf-8 -*- + +# Copyright: Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +import json +import os +import pty +import random +import re +import select +import socket +import string +import subprocess +import time +from typing import Any +from typing import Dict +from typing import List +from typing import NoReturn +from typing import Optional +from typing import Union + +from ansible.errors import AnsibleConnectionFailure +from ansible.module_utils._text import to_text +from ansible.utils.display import Display + +display = Display() + + +class ConnectionPluginDisplay: + def __init__(self, host: Optional[str]) -> None: + self._host_args = {} + if host: + self._host_args = {"host": host} + + def _display(self, f: callable, message: str) -> None: + f(to_text(message), **self._host_args) + + def v(self, message): + self._display(display.v, message) + + def vv(self, message): + self._display(display.vv, message) + + def vvv(self, message): + self._display(display.vvv, message) + + def vvvv(self, message): + self._display(display.vvvv, message) + + +class StdoutPoller: + def __init__(self, session: Any, stdout: Any, poller: Any, timeout: int) -> None: + self._stdout = stdout + self._poller = poller + self._session = session + self._timeout = timeout + + def readline(self): + return self._stdout.readline() + + def has_data(self) -> bool: + return bool(self._poller.poll(self._timeout)) + + def read_stdout(self, length: int = 1024) -> str: + return self._stdout.read(length).decode("utf-8") + + def stdin_write(self, value: str) -> None: + self._session.stdin.write(value) + + def poll(self) -> NoReturn: + start = round(time.time()) + yield self.has_data() + while self._session.poll() is None: + remaining = start + self._timeout - round(time.time()) + if remaining < 0: + raise AnsibleConnectionFailure("StdoutPoller timeout...") + yield 
self.has_data() + + def match_expr(self, expr: Union[str, callable]) -> str: + time_start = time.time() + content = "" + while (int(time.time()) - time_start) < self._timeout: + if self.poll(): + content += self.read_stdout() + if callable(expr): + if expr(content): + return content + elif expr in content: + return content + raise TimeoutError(f"Unable to match expr '{expr}' from content") + + def flush_stderr(self) -> str: + """read and return stderr with minimal blocking""" + + poll_stderr = select.poll() + poll_stderr.register(self._session.stderr, select.POLLIN) + stderr = "" + while self._session.poll() is None: + if not poll_stderr.poll(1): + break + line = self._session.stderr.readline() + stderr = stderr + line + return stderr + + +class AnsibleAwsSSMSession: + def __init__( + self, + display: ConnectionPluginDisplay, + ssm_client: Any, + instance_id: str, + executable: str, + ssm_timeout: int, + region_name: Optional[str], + profile_name: str, + document_name: Optional[str] = None, + parameters: Optional[Dict[str, List[str]]] = None, + ) -> None: + self._client = ssm_client + self._session = None + self._session_id = None + self._local_port = None + self._display = display + + params = {"Target": instance_id} + if document_name: + params["DocumentName"] = document_name + if parameters: + params["Parameters"] = parameters + + try: + response = self._client.start_session(**params) + self._session_id = response["SessionId"] + self._display.vvvv(f"Start session - SessionId: {self._session_id}") + + cmd = [ + executable, + json.dumps(response), + region_name, + "StartSession", + profile_name, + json.dumps({"Target": instance_id}), + self._client.meta.endpoint_url, + ] + + stdout_r, stdout_w = pty.openpty() + self._session = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stdout=stdout_w, + stderr=subprocess.PIPE, + close_fds=True, + bufsize=0, + ) + + os.close(stdout_w) + stdout = os.fdopen(stdout_r, "rb", 0) + self._poller = StdoutPoller( + 
session=self._session, + stdout=stdout, + poller=select.poll().register(stdout, select.POLLIN), + timeout=ssm_timeout, + ) + except Exception as e: + raise AnsibleConnectionFailure(f"failed to start session: {e}") + + def __del__(self): + if self._session_id: + self._display.vvvv(f"Terminating AWS Session: {self._session_id}") + self._client.terminate_session(SessionId=self._session_id) + if self._session: + self._display.vvvv("Terminating subprocess.Popen session") + self._session.terminate() + + +class PortForwardingFileTransferManager: + DEFAULT_HOST_PORT = 80 + + def __init__( + self, + host: Optional[str], + ssm_client: Any, + instance_id: str, + executable: str, + ssm_timeout: int, + region_name: Optional[str], + profile_name: str, + host_port: Optional[int], + local_port: Optional[int], + ) -> None: + self._client = ssm_client + self._session = None + self._instance_id = instance_id + self._session_id = None + self._poller = None + self._local_port = local_port + self._host_port = host_port or self.DEFAULT_HOST_PORT + self._display = ConnectionPluginDisplay(host=host) + + # Create Port forwarding Session + parameters = {} + if local_port: + parameters["localPortNumber"] = [str(local_port)] + if host_port: + parameters["portNumber"] = [str(host_port)] + self._portforwarding_session = AnsibleAwsSSMSession( + ssm_client=ssm_client, + instance_id=instance_id, + executable=executable, + ssm_timeout=ssm_timeout, + region_name=region_name, + profile_name=profile_name, + document_name="AWS-StartPortForwardingSession", + parameters=parameters, + display=self._display, + ) + + match_expr = re.compile(r"Port ([0-9]+) opened for sessionId") + content = self._portforwarding_session._poller.match_expr(expr=match_expr.search) + match = match_expr.search(content) + self._local_port = int(match.group(1)) + self._display.vvvv(f"SSM PORT FORWARDING - Local port '{self._local_port}'") + + # Start shell session + self._shell_session = AnsibleAwsSSMSession( + 
ssm_client=ssm_client, + instance_id=instance_id, + executable=executable, + ssm_timeout=ssm_timeout, + region_name=region_name, + profile_name=profile_name, + display=self._display, + ) + + def _socket_connect(self, session: Any, port: int, host: str = "localhost", max_attempts=10) -> None: + """Connect to socket""" + for attempt in range(max_attempts): + try: + session.connect((host, port)) + break + except OSError: + if attempt == max_attempts - 1: + self._display.vvvv(f"SOCKET _CONNECT: Failed to intiate socket connection on '{host}:{port}'") + raise + time.sleep(0.05) + + def _socket_read(self, port: int, out_path: Optional[str] = None) -> None: + self._display.vvvv(f"Read content from socket on port '{port}'...") + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as session: + session.settimeout(1) + self._socket_connect(session=session, port=port) + try: + with open(out_path, "wb") as fhandler: + while 1: + data = session.recv(1024) + if not data: + break + fhandler.write(data) + time.sleep(0.05) + except TimeoutError: + pass + self._display.vvvv("Socket connection closed.") + + def _socket_write(self, port: int, in_path: str) -> None: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as session: + self._socket_connect(session=session, port=port) + with open(in_path, "rb") as f: + session.sendall(f.read()) + + def put_file(self, in_path: str, out_path: str) -> None: + # Start listener on Remote host + mark_end = "".join([random.choice(string.ascii_lowercase + string.digits) for i in range(12)]) + put_cmd = f"sudo nc -l -v -p {str(self._host_port)} > {out_path}; printf '{mark_end}'" + put_cmd = (put_cmd + "\n").encode() + self._display.vvvv(f"Write command '{put_cmd}'") + self._shell_session._poller.flush_stderr() + self._shell_session._poller.stdin_write(put_cmd) + # Ensure server is listening + self._shell_session._poller.match_expr(expr="Listening") + self._display.vvvv("Server is listening...") + + # Write data into socket + 
self._socket_write(port=self._local_port, in_path=in_path) + + # Ensure nc command has ended on remote host + self._shell_session._poller.match_expr(expr=mark_end) + + self._display.vvvv(f"End of polling, stderr = {self._shell_session._poller.flush_stderr()}") + + def fetch_file(self, in_path: str, out_path: str) -> None: + self._shell_session._poller.flush_stderr() + mark_end = "".join([random.choice(string.ascii_lowercase + string.digits) for i in range(12)]) + fetch_cmd = f"sudo nc -v -l {self._host_port} < {in_path}; printf '{mark_end}'" + fetch_cmd = (fetch_cmd + "\n").encode() + self._display.vvvv(f"Write command '{fetch_cmd}'") + self._shell_session._poller.stdin_write(fetch_cmd) + + # Ensure server is listening + self._shell_session._poller.match_expr(expr="Listening") + self._display.vvvv("Server is listening...") + # Read data from socket + self._socket_read(port=self._local_port, out_path=out_path) + # Ensure nc command has ended on remote host + self._shell_session._poller.match_expr(expr=mark_end) + + self._display.vvvv(f"End of polling, stderr = {self._shell_session._poller.flush_stderr()}") diff --git a/tests/integration/targets/connection_aws_ssm_fedora_netcat/aliases b/tests/integration/targets/connection_aws_ssm_fedora_netcat/aliases new file mode 100644 index 00000000000..f5af8799edb --- /dev/null +++ b/tests/integration/targets/connection_aws_ssm_fedora_netcat/aliases @@ -0,0 +1,4 @@ +time=8m + +cloud/aws +connection_aws_ssm diff --git a/tests/integration/targets/connection_aws_ssm_fedora_netcat/aws_ssm_integration_test_setup.yml b/tests/integration/targets/connection_aws_ssm_fedora_netcat/aws_ssm_integration_test_setup.yml new file mode 100644 index 00000000000..dac79d7ebf7 --- /dev/null +++ b/tests/integration/targets/connection_aws_ssm_fedora_netcat/aws_ssm_integration_test_setup.yml @@ -0,0 +1,7 @@ +- hosts: localhost + roles: + - role: ../setup_connection_aws_ssm + vars: + target_os: fedora + use_s3_bucket: false + host_port_number: 50547 
diff --git a/tests/integration/targets/connection_aws_ssm_fedora_netcat/aws_ssm_integration_test_teardown.yml b/tests/integration/targets/connection_aws_ssm_fedora_netcat/aws_ssm_integration_test_teardown.yml new file mode 100644 index 00000000000..3ab6f74cf64 --- /dev/null +++ b/tests/integration/targets/connection_aws_ssm_fedora_netcat/aws_ssm_integration_test_teardown.yml @@ -0,0 +1,5 @@ +- hosts: localhost + tasks: + - include_role: + name: ../setup_connection_aws_ssm + tasks_from: cleanup.yml diff --git a/tests/integration/targets/connection_aws_ssm_fedora_netcat/meta/main.yml b/tests/integration/targets/connection_aws_ssm_fedora_netcat/meta/main.yml new file mode 100644 index 00000000000..d79e9b272be --- /dev/null +++ b/tests/integration/targets/connection_aws_ssm_fedora_netcat/meta/main.yml @@ -0,0 +1,2 @@ +dependencies: + - setup_connection_aws_ssm diff --git a/tests/integration/targets/connection_aws_ssm_fedora_netcat/runme.sh b/tests/integration/targets/connection_aws_ssm_fedora_netcat/runme.sh new file mode 100755 index 00000000000..ea812282d8c --- /dev/null +++ b/tests/integration/targets/connection_aws_ssm_fedora_netcat/runme.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +PLAYBOOK_DIR=$(pwd) +set -eux + +CMD_ARGS=("$@") + +# Destroy Environment +cleanup() { + + cd "${PLAYBOOK_DIR}" + ansible-playbook aws_ssm_integration_test_teardown.yml "${CMD_ARGS[@]}" + +} + +trap "cleanup" EXIT + +# Setup Environment +ansible-playbook aws_ssm_integration_test_setup.yml "$@" + +# Export the AWS Keys +set +x +. 
./aws-env-vars.sh +set -x + +# Execute Integration tests +ansible-playbook test.yml -i "${PLAYBOOK_DIR}/ssm_inventory" "$@" diff --git a/tests/integration/targets/connection_aws_ssm_fedora_netcat/test.yml b/tests/integration/targets/connection_aws_ssm_fedora_netcat/test.yml new file mode 100644 index 00000000000..5786f1e737c --- /dev/null +++ b/tests/integration/targets/connection_aws_ssm_fedora_netcat/test.yml @@ -0,0 +1,45 @@ +- name: Deploy web server on ec2 instance using SSM + hosts: aws_ssm + gather_facts: true + + vars: + server_content: | + Enable SysAdmin Demo: + Ansible Profiling with Callback Plugin + Custom Web Page + + tasks: + - name: Install httpd package + ansible.builtin.dnf: + name: httpd + state: present + become: true + + - name: Start and enable httpd service + ansible.builtin.service: + name: httpd + enabled: true + state: started + become: true + + - name: Create a custom index.html file + ansible.builtin.copy: + dest: /var/www/html/index.html + content: "{{ server_content }}" + become: true + + - name: Ping Web server + ansible.builtin.get_url: + url: "http://localhost:80" + dest: /tmp/server.txt + + - name: Fetch file from remote host + ansible.builtin.fetch: + src: "/tmp/server.txt" + dest: "/tmp/ansible_ssm_server.txt" + flat: true + + - name: Validate server content + ansible.builtin.assert: + that: + - lookup('file', '/tmp/ansible_ssm_server.txt', rstrip=false) == server_content diff --git a/tests/integration/targets/setup_connection_aws_ssm/defaults/main.yml b/tests/integration/targets/setup_connection_aws_ssm/defaults/main.yml index 602738ce28e..30ea0a9860f 100644 --- a/tests/integration/targets/setup_connection_aws_ssm/defaults/main.yml +++ b/tests/integration/targets/setup_connection_aws_ssm/defaults/main.yml @@ -4,10 +4,10 @@ instance_type: t3.micro ami_details: fedora: owner: 125523088429 - name: 'Fedora-Cloud-Base-41-1.2.x86_64*' + name: 'Fedora-Cloud-Base-*.x86_64-41-*' user_data: | #!/bin/sh - sudo dnf install -y 
https://s3.amazonaws.com/ec2-downloads-windows/SSMAgent/latest/linux_amd64/amazon-ssm-agent.rpm + sudo dnf install -y https://s3.amazonaws.com/ec2-downloads-windows/SSMAgent/latest/linux_amd64/amazon-ssm-agent.rpm nc sudo systemctl start amazon-ssm-agent os_type: linux centos: diff --git a/tests/integration/targets/setup_connection_aws_ssm/tasks/main.yml b/tests/integration/targets/setup_connection_aws_ssm/tasks/main.yml index 7403f5ff960..81029973938 100644 --- a/tests/integration/targets/setup_connection_aws_ssm/tasks/main.yml +++ b/tests/integration/targets/setup_connection_aws_ssm/tasks/main.yml @@ -10,23 +10,18 @@ session_token: '{{ security_token | default(omit) }}' region: '{{ aws_region }}' block: - - - name: get ARN of calling user - amazon.aws.aws_caller_info: - register: aws_caller_info - - name: setup connection argments fact ansible.builtin.include_tasks: 'connection_args.yml' - name: Ensure IAM instance role exists amazon.aws.iam_role: - name: "ansible-test-{{tiny_prefix}}-aws-ssm-role" - assume_role_policy_document: "{{ lookup('file','ec2-trust-policy.json') }}" + name: "ansible-test-{{ tiny_prefix }}-aws-ssm-role" + assume_role_policy_document: "{{ lookup('file', 'ec2-trust-policy.json') }}" state: present - create_instance_profile: yes + create_instance_profile: true managed_policy: - - AmazonSSMManagedInstanceCore - wait: True + - AmazonSSMManagedInstanceCore + wait: true register: role_output - name: Lookup AMI configuration @@ -40,14 +35,15 @@ name: '{{ ami_configuration.name }}' register: ec2_amis when: - - ami_configuration.name | default(False) + - '"name" in ami_configuration' + - ami_configuration.name != "" - name: AMI Lookup (SSM Parameter) - when: - - ami_configuration.ssm_parameter | default(False) - block: - - ansible.builtin.set_fact: + ansible.builtin.set_fact: ssm_amis: "{{ lookup('amazon.aws.ssm_parameter', ami_configuration.ssm_parameter, **connection_args) }}" + when: + - '"ssm_parameter" in ami_configuration' + - 
ami_configuration.ssm_parameter != "" - name: Set facts with latest AMIs vars: @@ -112,6 +108,7 @@ name: "{{ s3_bucket_name }}" region: "{{ s3_bucket_region | default(omit)}}" register: s3_output + when: use_s3_bucket | default(True) | bool - name: setup encryption ansible.builtin.include_tasks: 'encryption.yml' @@ -141,6 +138,7 @@ src: ec2_instance_vars_to_delete.yml.j2 ignore_errors: true when: + - instance_output is defined - instance_output is successful - name: Create IAM Role vars_to_delete.yml @@ -148,19 +146,20 @@ dest: "{{ playbook_dir }}/iam_role_vars_to_delete.yml" src: iam_role_vars_to_delete.yml.j2 when: + - role_output is defined - role_output is successful - ignore_errors: true - name: Create S3.yml ansible.builtin.template: dest: "{{ playbook_dir }}/s3_vars_to_delete.yml" src: s3_vars_to_delete.yml.j2 when: - - s3_output is successful - ignore_errors: true + - s3_output is defined + - '"name" in s3_output' - name: Create SSM vars_to_delete.yml ansible.builtin.template: dest: "{{ playbook_dir }}/ssm_vars_to_delete.yml" src: ssm_vars_to_delete.yml.j2 - ignore_errors: true + when: + - use_ssm_document | default(False) diff --git a/tests/integration/targets/setup_connection_aws_ssm/tasks/ssm_document.yml b/tests/integration/targets/setup_connection_aws_ssm/tasks/ssm_document.yml index 4acc7f21858..5da1d2affff 100644 --- a/tests/integration/targets/setup_connection_aws_ssm/tasks/ssm_document.yml +++ b/tests/integration/targets/setup_connection_aws_ssm/tasks/ssm_document.yml @@ -1,11 +1,4 @@ --- -- block: - - name: Create custom SSM document - command: "aws ssm create-document --content file://{{ role_path }}/files/ssm-document.json --name {{ ssm_document_name }} --document-type Session" - environment: "{{ connection_env }}" - always: - - name: Create SSM vars_to_delete.yml - template: - dest: "{{ playbook_dir }}/ssm_vars_to_delete.yml" - src: ssm_vars_to_delete.yml.j2 - ignore_errors: true +- name: Create custom SSM document + command: "aws ssm 
create-document --content file://{{ role_path }}/files/ssm-document.json --name {{ ssm_document_name }} --document-type Session" + environment: "{{ connection_env }}" diff --git a/tests/integration/targets/setup_connection_aws_ssm/templates/inventory-combined.aws_ssm.j2 b/tests/integration/targets/setup_connection_aws_ssm/templates/inventory-combined.aws_ssm.j2 index d558c866589..ecb71866442 100644 --- a/tests/integration/targets/setup_connection_aws_ssm/templates/inventory-combined.aws_ssm.j2 +++ b/tests/integration/targets/setup_connection_aws_ssm/templates/inventory-combined.aws_ssm.j2 @@ -31,7 +31,12 @@ ansible_connection=community.aws.aws_ssm ansible_aws_ssm_plugin=/usr/local/sessionmanagerplugin/bin/session-manager-plugin ansible_python_interpreter={{ os_python_path | default('/usr/bin/python3') }} local_tmp=/tmp/ansible-local-{{ tiny_prefix }} +{% if use_s3_bucket | default(True) %} ansible_aws_ssm_bucket_name={{ s3_bucket_name }} +{% endif %} +{% if host_port_number | default(False) %} +ansible_aws_ssm_host_port_number={{ host_port_number }} +{% endif %} {% if s3_addressing_style | default(False) %} ansible_aws_ssm_s3_addressing_style={{ s3_addressing_style }} {% endif %} diff --git a/tests/unit/plugins/connection/aws_ssm/test_aws_ssm.py b/tests/unit/plugins/connection/aws_ssm/test_aws_ssm.py index 191dc6d56cb..07225ee494a 100644 --- a/tests/unit/plugins/connection/aws_ssm/test_aws_ssm.py +++ b/tests/unit/plugins/connection/aws_ssm/test_aws_ssm.py @@ -44,6 +44,8 @@ def mock_get_option(key): # Mock the _initialize_ssm_client and _initialize_s3_client methods conn._initialize_ssm_client = MagicMock() + conn._use_bucket = MagicMock() + conn._use_bucket.return_value = True conn._init_clients() @@ -197,6 +199,8 @@ def test_plugins_connection_aws_ssm_put_file(self, mock_ospe): conn._connect = MagicMock() conn._file_transport_command = MagicMock() conn._file_transport_command.return_value = (0, "stdout", "stderr") + conn._use_bucket = MagicMock() + 
conn._use_bucket.return_value = True conn.put_file("/in/file", "/out/file") def test_plugins_connection_aws_ssm_fetch_file(self): @@ -206,6 +210,8 @@ def test_plugins_connection_aws_ssm_fetch_file(self): conn._connect = MagicMock() conn._file_transport_command = MagicMock() conn._file_transport_command.return_value = (0, "stdout", "stderr") + conn._use_bucket = MagicMock() + conn._use_bucket.return_value = True conn.fetch_file("/in/file", "/out/file") @patch("subprocess.check_output") From 1295cd509bcb92f31edf94c5fc33fd8b76dab5e7 Mon Sep 17 00:00:00 2001 From: aubin bikouo Date: Thu, 3 Apr 2025 18:07:17 +0200 Subject: [PATCH 2/5] Some testing and rebase --- plugins/connection/aws_ssm.py | 34 ++- plugins/plugin_utils/ssm/common.py | 214 +++++++++++++++ plugins/plugin_utils/ssm/transport.py | 150 +++++++++++ plugins/plugin_utils/ssm_file_transfer.py | 307 ---------------------- 4 files changed, 379 insertions(+), 326 deletions(-) create mode 100644 plugins/plugin_utils/ssm/common.py create mode 100644 plugins/plugin_utils/ssm/transport.py delete mode 100644 plugins/plugin_utils/ssm_file_transfer.py diff --git a/plugins/connection/aws_ssm.py b/plugins/connection/aws_ssm.py index 5bf28340098..34e7c6aa438 100644 --- a/plugins/connection/aws_ssm.py +++ b/plugins/connection/aws_ssm.py @@ -383,7 +383,7 @@ from ansible.utils.display import Display from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3 -from ansible_collections.community.aws.plugins.plugin_utils.ssm_file_transfer import PortForwardingFileTransferManager +from ansible_collections.community.aws.plugins.plugin_utils.ssm.transport import PortForwardingFileTransportManager from ansible_collections.community.aws.plugins.plugin_utils.s3clientmanager import S3ClientManager @@ -555,25 +555,21 @@ def _init_clients(self) -> None: ) self._s3_client = self.s3_manager._s3_client else: - self._initialize_file_transfer_manager() + # Initialize file transport with port forwarding + 
self._filetransfer_mgr = PortForwardingFileTransportManager( + host_port_number=self.get_option("host_port_number"), + local_port_number=self.get_option("local_port_number"), + verbosity_display=self.verbosity_display, + ) - def _initialize_file_transfer_manager(self) -> None: - ssm_timeout = self.get_option("ssm_timeout") - region_name = self.get_option("region") - profile_name = self.get_option("profile") or "" - host_port = self.get_option("host_port_number") - local_port = self.get_option("local_port_number") - self._filetransfer_mgr = PortForwardingFileTransferManager( - self.host, - ssm_client=self._client, - instance_id=self.instance_id, - executable=self.get_executable(), - ssm_timeout=ssm_timeout, - region_name=region_name, - profile_name=profile_name, - host_port=host_port, - local_port=local_port, - ) + self._filetransfer_mgr.start_session( + client=self._client, + instance_id=self.instance_id, + executable=self.get_executable(), + region=self.get_option("region"), + profile=self.get_option("profile"), + ssm_timeout=self.get_option("ssm_timeout"), + ) def _initialize_ssm_client(self, region_name: Optional[str], profile_name: str) -> None: """ diff --git a/plugins/plugin_utils/ssm/common.py b/plugins/plugin_utils/ssm/common.py new file mode 100644 index 00000000000..95f421fcb77 --- /dev/null +++ b/plugins/plugin_utils/ssm/common.py @@ -0,0 +1,214 @@ +# -*- coding: utf-8 -*- + +# Copyright: Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +import json +import os +import pty +import select +import subprocess +import time +from functools import wraps +from typing import Any +from typing import Callable +from typing import Dict +from typing import NoReturn +from typing import Optional +from typing import TypedDict +from typing import Union + +from ansible.errors import AnsibleConnectionFailure +from ansible.errors import AnsibleError +from ansible.module_utils._text import to_text + + +def 
ssm_retry(func: Any) -> Any: + """ + Decorator to retry in the case of a connection failure + Will retry if: + * an exception is caught + Will not retry if + * remaining_tries is <2 + * retries limit reached + """ + + @wraps(func) + def wrapped(self, *args: Any, **kwargs: Any) -> Any: + for attr in ("reconnection_retries", "verbosity_display"): + if not hasattr(self, attr): + raise AnsibleError(f"Cannot decorate this function with 'ssm_retry', missing attribute '{attr}'") + remaining_tries = int(getattr(self, "reconnection_retries")) + 1 + cmd_summary = f"{args[0]}..." + for attempt in range(remaining_tries): + try: + return_tuple = func(self, *args, **kwargs) + self.verbosity_display(4, f"ssm_retry: (success) {to_text(return_tuple)}") + break + + except (AnsibleConnectionFailure, Exception) as e: + if attempt == remaining_tries - 1: + raise + pause = 2**attempt - 1 + pause = min(pause, 30) + + if isinstance(e, AnsibleConnectionFailure): + msg = f"ssm_retry: attempt: {attempt}, cmd ({cmd_summary}), pausing for {pause} seconds" + else: + msg = ( + f"ssm_retry: attempt: {attempt}, caught exception({e})" + f"from cmd ({cmd_summary}),pausing for {pause} seconds" + ) + + self.verbosity_display(2, msg) + + time.sleep(pause) + + # Do not attempt to reuse the existing session on retries + # This will cause the SSM session to be completely restarted, + # as well as reinitializing the boto3 clients + if hasattr(self, "close"): + getattr(self, "close")() + + continue + + return return_tuple + + return wrapped + + +class CommandResult(TypedDict): + """ + A dictionary that contains the executed command results. 
+ """ + + returncode: int + stdout_combined: str + stderr_combined: str + + +class SSMDisplay: + def __init__(self, verbosity_display: Callable[[int, str], None]): + self.verbosity_display = verbosity_display + + +class StdoutPoller: + def __init__(self, session: Any, stdout: Any, poller: Any, timeout: int) -> None: + self._stdout = stdout + self._poller = poller + self._session = session + self._timeout = timeout + + def readline(self): + return self._stdout.readline() + + def has_data(self) -> bool: + return bool(self._poller.poll(self._timeout)) + + def read_stdout(self, length: int = 1024) -> str: + return self._stdout.read(length).decode("utf-8") + + def stdin_write(self, value: str) -> None: + self._session.stdin.write(value) + + def poll(self) -> NoReturn: + start = round(time.time()) + yield self.has_data() + while self._session.poll() is None: + remaining = start + self._timeout - round(time.time()) + if remaining < 0: + raise AnsibleConnectionFailure("StdoutPoller timeout...") + yield self.has_data() + + def match_expr(self, expr: Union[str, callable]) -> str: + time_start = time.time() + content = "" + while (int(time.time()) - time_start) < self._timeout: + if self.poll(): + content += self.read_stdout() + if callable(expr): + if expr(content): + return content + elif expr in content: + return content + raise TimeoutError(f"Unable to match expr '{expr}' from content") + + def flush_stderr(self) -> str: + """read and return stderr with minimal blocking""" + + poll_stderr = select.poll() + poll_stderr.register(self._session.stderr, select.POLLIN) + stderr = "" + while self._session.poll() is None: + if not poll_stderr.poll(1): + break + line = self._session.stderr.readline() + stderr = stderr + line + return stderr + + +class SSMSessionManager(SSMDisplay): + def __init__( + self, + client: Any, + instance_id: str, + executable: str, + region: Optional[str], + profile: Optional[str], + ssm_timeout: int, + verbosity_display: Callable, + document_name: 
Optional[str] = None, + document_parameters: Optional[Dict] = None, + ): + super(SSMSessionManager, self).__init__(verbosity_display=verbosity_display) + + self._client = client + params = {"Target": instance_id} + if document_name: + params["DocumentName"] = document_name + if document_parameters: + params["Parameters"] = document_parameters + + try: + response = self._client.start_session(**params) + self._session_id = response["SessionId"] + self.verbosity_display(4, f"Start session - SessionId: {self._session_id}") + + cmd = [ + executable, + json.dumps(response), + region, + "StartSession", + profile, + json.dumps({"Target": instance_id}), + self._client.meta.endpoint_url, + ] + + stdout_r, stdout_w = pty.openpty() + self._session = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stdout=stdout_w, + stderr=subprocess.PIPE, + close_fds=True, + bufsize=0, + ) + + os.close(stdout_w) + stdout = os.fdopen(stdout_r, "rb", 0) + self._poller = StdoutPoller( + session=self._session, + stdout=stdout, + poller=select.poll().register(stdout, select.POLLIN), + timeout=ssm_timeout, + ) + except Exception as e: + raise AnsibleConnectionFailure(f"failed to start session: {e}") + + def __del__(self): + if self._session_id: + self._display.vvvv(f"Terminating AWS Session: {self._session_id}") + self._client.terminate_session(SessionId=self._session_id) + if self._session: + self._display.vvvv("Terminating subprocess.Popen session") + self._session.terminate() diff --git a/plugins/plugin_utils/ssm/transport.py b/plugins/plugin_utils/ssm/transport.py new file mode 100644 index 00000000000..d03945c4358 --- /dev/null +++ b/plugins/plugin_utils/ssm/transport.py @@ -0,0 +1,150 @@ +# -*- coding: utf-8 -*- + +# Copyright: Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +import random +import re +import socket +import string +import time +from typing import Any +from typing import Callable +from typing import Optional + 
+from .common import SSMDisplay +from .common import SSMSessionManager + + +class PortForwardingFileTransportManager(SSMDisplay): + DEFAULT_HOST_PORT = 80 + + def __init__( + self, host_port_number: Optional[int], local_port_number: Optional[str], verbosity_display: Callable + ) -> None: + super(PortForwardingFileTransportManager, self).__init__(verbosity_display=verbosity_display) + + self.host_port_number = host_port_number or self.DEFAULT_HOST_PORT + self.local_port_number = local_port_number + self._portforwarding_session = None + self._shell_session = None + + def start_session( + self, + client: Any, + instance_id: str, + executable: str, + region: Optional[str], + profile: Optional[str], + ssm_timeout: int, + ) -> None: + # Create Port forwarding Session + parameters = {} + if self.local_port_number: + parameters["localPortNumber"] = [str(self.local_port_number)] + if self.host_port_number: + parameters["portNumber"] = [str(self.host_port_number)] + + self._portforwarding_session = SSMSessionManager( + client=client, + document_name="AWS-StartPortForwardingSession", + document_parameters=parameters, + instance_id=instance_id, + executable=executable, + region=region, + profile=profile, + ssm_timeout=ssm_timeout, + verbosity_display=self.verbosity_display, + ) + + match_expr = re.compile(r"Port ([0-9]+) opened for sessionId") + content = self._portforwarding_session._poller.match_expr(expr=match_expr.search) + match = match_expr.search(content) + self.local_port_number = int(match.group(1)) + self.verbosity_display(4, f"SSM PORT FORWARDING - Local port '{self.local_port_number}'") + + # Start shell session + self._shell_session = SSMSessionManager( + client=client, + instance_id=instance_id, + executable=executable, + region=region, + profile=profile, + ssm_timeout=ssm_timeout, + verbosity_display=self.verbosity_display, + ) + + def _socket_connect(self, session: Any, port: int, host: str = "localhost", max_attempts=10) -> None: + """Connect to socket""" + 
for attempt in range(max_attempts): + try: + session.connect((host, port)) + break + except OSError: + if attempt == max_attempts - 1: + self.verbosity_display( + 4, f"SOCKET _CONNECT: Failed to intiate socket connection on '{host}:{port}'" + ) + raise + time.sleep(0.05) + + def _socket_read(self, port: int, out_path: Optional[str] = None) -> None: + self.verbosity_display(4, f"Read content from socket on port '{port}'...") + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as session: + session.settimeout(1) + self._socket_connect(session=session, port=port) + try: + with open(out_path, "wb") as fhandler: + while 1: + data = session.recv(1024) + if not data: + break + fhandler.write(data) + time.sleep(0.05) + except TimeoutError: + pass + self.verbosity_display(4, "Socket connection closed.") + + def _socket_write(self, port: int, in_path: str) -> None: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as session: + self._socket_connect(session=session, port=port) + with open(in_path, "rb") as f: + session.sendall(f.read()) + + def put_file(self, in_path: str, out_path: str) -> None: + # Start listener on Remote host + mark_end = "".join([random.choice(string.ascii_lowercase + string.digits) for i in range(12)]) + put_cmd = f"sudo nc -l -v -p {str(self.host_port_number)} > {out_path}; printf '{mark_end}'" + put_cmd = (put_cmd + "\n").encode() + self.verbosity_display(4, f"Write command '{put_cmd}'") + self._shell_session._poller.flush_stderr() + self._shell_session._poller.stdin_write(put_cmd) + # Ensure server is listening + self._shell_session._poller.match_expr(expr="Listening") + self.verbosity_display(4, "Server is listening...") + + # Write data into socket + self._socket_write(port=self.local_port_number, in_path=in_path) + + # Ensure nc command has ended on remote host + self._shell_session._poller.match_expr(expr=mark_end) + + self.verbosity_display(4, f"End of polling, stderr = {self._shell_session._poller.flush_stderr()}") + + def 
fetch_file(self, in_path: str, out_path: str) -> None: + self._shell_session._poller.flush_stderr() + mark_end = "".join([random.choice(string.ascii_lowercase + string.digits) for i in range(12)]) + fetch_cmd = f"sudo nc -v -l {self.host_port_number} < {in_path}; printf '{mark_end}'" + fetch_cmd = (fetch_cmd + "\n").encode() + self.verbosity_display(4, f"Write command '{fetch_cmd}'") + self._shell_session._poller.stdin_write(fetch_cmd) + + # Ensure server is listening + self._shell_session._poller.match_expr(expr="Listening") + self.verbosity_display(4, "Server is listening...") + # Read data from socket + self._socket_read(port=self.local_port_number, out_path=out_path) + # Ensure nc command has ended on remote host + self._shell_session._poller.match_expr(expr=mark_end) + + self.verbosity_display(4, f"End of polling, stderr = {self._shell_session._poller.flush_stderr()}") diff --git a/plugins/plugin_utils/ssm_file_transfer.py b/plugins/plugin_utils/ssm_file_transfer.py deleted file mode 100644 index 0dc30eb16d7..00000000000 --- a/plugins/plugin_utils/ssm_file_transfer.py +++ /dev/null @@ -1,307 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright: Ansible Project -# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) - -import json -import os -import pty -import random -import re -import select -import socket -import string -import subprocess -import time -from typing import Any -from typing import Dict -from typing import List -from typing import NoReturn -from typing import Optional -from typing import Union - -from ansible.errors import AnsibleConnectionFailure -from ansible.module_utils._text import to_text -from ansible.utils.display import Display - -display = Display() - - -class ConnectionPluginDisplay: - def __init__(self, host: Optional[str]) -> None: - self._host_args = {} - if host: - self._host_args = {"host": host} - - def _display(self, f: callable, message: str) -> None: - f(to_text(message), **self._host_args) - 
- def v(self, message): - self._display(display.v, message) - - def vv(self, message): - self._display(display.vv, message) - - def vvv(self, message): - self._display(display.vvv, message) - - def vvvv(self, message): - self._display(display.vvvv, message) - - -class StdoutPoller: - def __init__(self, session: Any, stdout: Any, poller: Any, timeout: int) -> None: - self._stdout = stdout - self._poller = poller - self._session = session - self._timeout = timeout - - def readline(self): - return self._stdout.readline() - - def has_data(self) -> bool: - return bool(self._poller.poll(self._timeout)) - - def read_stdout(self, length: int = 1024) -> str: - return self._stdout.read(length).decode("utf-8") - - def stdin_write(self, value: str) -> None: - self._session.stdin.write(value) - - def poll(self) -> NoReturn: - start = round(time.time()) - yield self.has_data() - while self._session.poll() is None: - remaining = start + self._timeout - round(time.time()) - if remaining < 0: - raise AnsibleConnectionFailure("StdoutPoller timeout...") - yield self.has_data() - - def match_expr(self, expr: Union[str, callable]) -> str: - time_start = time.time() - content = "" - while (int(time.time()) - time_start) < self._timeout: - if self.poll(): - content += self.read_stdout() - if callable(expr): - if expr(content): - return content - elif expr in content: - return content - raise TimeoutError(f"Unable to match expr '{expr}' from content") - - def flush_stderr(self) -> str: - """read and return stderr with minimal blocking""" - - poll_stderr = select.poll() - poll_stderr.register(self._session.stderr, select.POLLIN) - stderr = "" - while self._session.poll() is None: - if not poll_stderr.poll(1): - break - line = self._session.stderr.readline() - stderr = stderr + line - return stderr - - -class AnsibleAwsSSMSession: - def __init__( - self, - display: ConnectionPluginDisplay, - ssm_client: Any, - instance_id: str, - executable: str, - ssm_timeout: int, - region_name: 
Optional[str], - profile_name: str, - document_name: Optional[str] = None, - parameters: Optional[Dict[str, List[str]]] = None, - ) -> None: - self._client = ssm_client - self._session = None - self._session_id = None - self._local_port = None - self._display = display - - params = {"Target": instance_id} - if document_name: - params["DocumentName"] = document_name - if parameters: - params["Parameters"] = parameters - - try: - response = self._client.start_session(**params) - self._session_id = response["SessionId"] - self._display.vvvv(f"Start session - SessionId: {self._session_id}") - - cmd = [ - executable, - json.dumps(response), - region_name, - "StartSession", - profile_name, - json.dumps({"Target": instance_id}), - self._client.meta.endpoint_url, - ] - - stdout_r, stdout_w = pty.openpty() - self._session = subprocess.Popen( - cmd, - stdin=subprocess.PIPE, - stdout=stdout_w, - stderr=subprocess.PIPE, - close_fds=True, - bufsize=0, - ) - - os.close(stdout_w) - stdout = os.fdopen(stdout_r, "rb", 0) - self._poller = StdoutPoller( - session=self._session, - stdout=stdout, - poller=select.poll().register(stdout, select.POLLIN), - timeout=ssm_timeout, - ) - except Exception as e: - raise AnsibleConnectionFailure(f"failed to start session: {e}") - - def __del__(self): - if self._session_id: - self._display.vvvv(f"Terminating AWS Session: {self._session_id}") - self._client.terminate_session(SessionId=self._session_id) - if self._session: - self._display.vvvv("Terminating subprocess.Popen session") - self._session.terminate() - - -class PortForwardingFileTransferManager: - DEFAULT_HOST_PORT = 80 - - def __init__( - self, - host: Optional[str], - ssm_client: Any, - instance_id: str, - executable: str, - ssm_timeout: int, - region_name: Optional[str], - profile_name: str, - host_port: Optional[int], - local_port: Optional[int], - ) -> None: - self._client = ssm_client - self._session = None - self._instance_id = instance_id - self._session_id = None - self._poller = 
None - self._local_port = local_port - self._host_port = host_port or self.DEFAULT_HOST_PORT - self._display = ConnectionPluginDisplay(host=host) - - # Create Port forwarding Session - parameters = {} - if local_port: - parameters["localPortNumber"] = [str(local_port)] - if host_port: - parameters["portNumber"] = [str(host_port)] - self._portforwarding_session = AnsibleAwsSSMSession( - ssm_client=ssm_client, - instance_id=instance_id, - executable=executable, - ssm_timeout=ssm_timeout, - region_name=region_name, - profile_name=profile_name, - document_name="AWS-StartPortForwardingSession", - parameters=parameters, - display=self._display, - ) - - match_expr = re.compile(r"Port ([0-9]+) opened for sessionId") - content = self._portforwarding_session._poller.match_expr(expr=match_expr.search) - match = match_expr.search(content) - self._local_port = int(match.group(1)) - self._display.vvvv(f"SSM PORT FORWARDING - Local port '{self._local_port}'") - - # Start shell session - self._shell_session = AnsibleAwsSSMSession( - ssm_client=ssm_client, - instance_id=instance_id, - executable=executable, - ssm_timeout=ssm_timeout, - region_name=region_name, - profile_name=profile_name, - display=self._display, - ) - - def _socket_connect(self, session: Any, port: int, host: str = "localhost", max_attempts=10) -> None: - """Connect to socket""" - for attempt in range(max_attempts): - try: - session.connect((host, port)) - break - except OSError: - if attempt == max_attempts - 1: - self._display.vvvv(f"SOCKET _CONNECT: Failed to intiate socket connection on '{host}:{port}'") - raise - time.sleep(0.05) - - def _socket_read(self, port: int, out_path: Optional[str] = None) -> None: - self._display.vvvv(f"Read content from socket on port '{port}'...") - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as session: - session.settimeout(1) - self._socket_connect(session=session, port=port) - try: - with open(out_path, "wb") as fhandler: - while 1: - data = session.recv(1024) - if 
not data: - break - fhandler.write(data) - time.sleep(0.05) - except TimeoutError: - pass - self._display.vvvv("Socket connection closed.") - - def _socket_write(self, port: int, in_path: str) -> None: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as session: - self._socket_connect(session=session, port=port) - with open(in_path, "rb") as f: - session.sendall(f.read()) - - def put_file(self, in_path: str, out_path: str) -> None: - # Start listener on Remote host - mark_end = "".join([random.choice(string.ascii_lowercase + string.digits) for i in range(12)]) - put_cmd = f"sudo nc -l -v -p {str(self._host_port)} > {out_path}; printf '{mark_end}'" - put_cmd = (put_cmd + "\n").encode() - self._display.vvvv(f"Write command '{put_cmd}'") - self._shell_session._poller.flush_stderr() - self._shell_session._poller.stdin_write(put_cmd) - # Ensure server is listening - self._shell_session._poller.match_expr(expr="Listening") - self._display.vvvv("Server is listening...") - - # Write data into socket - self._socket_write(port=self._local_port, in_path=in_path) - - # Ensure nc command has ended on remote host - self._shell_session._poller.match_expr(expr=mark_end) - - self._display.vvvv(f"End of polling, stderr = {self._shell_session._poller.flush_stderr()}") - - def fetch_file(self, in_path: str, out_path: str) -> None: - self._shell_session._poller.flush_stderr() - mark_end = "".join([random.choice(string.ascii_lowercase + string.digits) for i in range(12)]) - fetch_cmd = f"sudo nc -v -l {self._host_port} < {in_path}; printf '{mark_end}'" - fetch_cmd = (fetch_cmd + "\n").encode() - self._display.vvvv(f"Write command '{fetch_cmd}'") - self._shell_session._poller.stdin_write(fetch_cmd) - - # Ensure server is listening - self._shell_session._poller.match_expr(expr="Listening") - self._display.vvvv("Server is listening...") - # Read data from socket - self._socket_read(port=self._local_port, out_path=out_path) - # Ensure nc command has ended on remote host - 
self._shell_session._poller.match_expr(expr=mark_end) - - self._display.vvvv(f"End of polling, stderr = {self._shell_session._poller.flush_stderr()}") From ebaacba52f314189deec94d4a6f083a5e70c5ab1 Mon Sep 17 00:00:00 2001 From: aubin bikouo Date: Mon, 7 Apr 2025 10:59:14 +0200 Subject: [PATCH 3/5] Working on Turbo --- plugins/connection/aws_ssm.py | 253 +++------------------------- plugins/plugin_utils/ssm/command.py | 192 +++++++++++++++++++++ plugins/plugin_utils/ssm/common.py | 65 +------ 3 files changed, 216 insertions(+), 294 deletions(-) create mode 100644 plugins/plugin_utils/ssm/command.py diff --git a/plugins/connection/aws_ssm.py b/plugins/connection/aws_ssm.py index 34e7c6aa438..34b3350664f 100644 --- a/plugins/connection/aws_ssm.py +++ b/plugins/connection/aws_ssm.py @@ -351,16 +351,13 @@ import pty import random import re -import select import string import subprocess import time from functools import wraps from typing import Any from typing import Dict -from typing import Iterator from typing import List -from typing import NoReturn from typing import Optional from typing import Tuple from typing import TypedDict @@ -379,11 +376,11 @@ from ansible.module_utils.basic import missing_required_lib from ansible.module_utils.common.process import get_bin_path from ansible.plugins.connection import ConnectionBase -from ansible.plugins.shell.powershell import _common_args from ansible.utils.display import Display from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3 from ansible_collections.community.aws.plugins.plugin_utils.ssm.transport import PortForwardingFileTransportManager +from ansible_collections.community.aws.plugins.plugin_utils.ssm.command import CommandManager from ansible_collections.community.aws.plugins.plugin_utils.s3clientmanager import S3ClientManager @@ -440,35 +437,6 @@ def wrapped(self, *args: Any, **kwargs: Any) -> Any: return wrapped -def chunks(lst: List, n: int) -> Iterator[List[Any]]: - """Yield 
successive n-sized chunks from lst.""" - for i in range(0, len(lst), n): - yield lst[i:i + n] # fmt: skip - - -def filter_ansi(line: str, is_windows: bool) -> str: - """Remove any ANSI terminal control codes. - - :param line: The input line. - :param is_windows: Whether the output is coming from a Windows host. - :returns: The result line. - """ - line = to_text(line) - - if is_windows: - osc_filter = re.compile(r"\x1b\][^\x07]*\x07") - line = osc_filter.sub("", line) - ansi_filter = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]") - line = ansi_filter.sub("", line) - - # Replace or strip sequence (at terminal width) - line = line.replace("\r\r\n", "\n") - if len(line) == 201: - line = line[:-1] - - return line - - class CommandResult(TypedDict): """ A dictionary that contains the executed command results. @@ -493,10 +461,11 @@ class Connection(ConnectionBase): _client = None _s3_client = None _session = None - _stdout = None _session_id = "" _timeout = False _filetransfer_mgr = None + _command_mgr = None + _instance_id = None MARK_LENGTH = 26 def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -506,9 +475,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: raise AnsibleError(missing_required_lib("boto3")) self.host = self._play_context.remote_addr - self._instance_id = None - self._polling_obj = None - self._has_timeout = False if getattr(self._shell, "SHELL_FAMILY", "") == "powershell": self.delegate = None @@ -645,7 +611,7 @@ def get_executable(self) -> str: raise AnsibleError(str(e)) return ssm_plugin_executable - def start_session(self): + def start_session(self) -> None: """start ssm session""" self.verbosity_display(3, f"ESTABLISH SSM CONNECTION TO: {self.instance_id}") @@ -661,6 +627,7 @@ def start_session(self): start_session_args["DocumentName"] = document_name response = self._client.start_session(**start_session_args) self._session_id = response["SessionId"] + self.verbosity_display(4, f"SSM CONNECTION ID: {self._session_id}") region_name = 
self.get_option("region") profile_name = self.get_option("profile") or "" @@ -687,129 +654,29 @@ def start_session(self): ) os.close(stdout_w) - self._stdout = os.fdopen(stdout_r, "rb", 0) + self._command_mgr = CommandManager( + shell=self._shell, + session=self._session, + stdout_r=stdout_r, + ssm_timeout=self.get_option("ssm_timeout"), + verbosity_display=self.verbosity_display, + ) # For non-windows Hosts: Ensure the session has started, and disable command echo and prompt. self._prepare_terminal() - self.verbosity_display(4, f"SSM CONNECTION ID: {self._session_id}") # pylint: disable=unreachable - - return self._session - - def poll_stdout(self, timeout: int = 1000) -> bool: - """Polls the stdout file descriptor. - - :param timeout: Specifies the length of time in milliseconds which the system will wait. - :returns: A boolean to specify the polling result - """ - if self._polling_obj is None: - self._polling_obj = select.poll() - self._polling_obj.register(self._stdout, select.POLLIN) - return bool(self._polling_obj.poll(timeout)) - - def poll(self, label: str, cmd: str) -> NoReturn: - """Poll session to retrieve content from stdout. - - :param label: A label for the display (EXEC, PRE...) - :param cmd: The command being executed - """ - start = round(time.time()) - yield self.poll_stdout() - timeout = self.get_option("ssm_timeout") - while self._session.poll() is None: - remaining = start + timeout - round(time.time()) - self.verbosity_display(4, f"{label} remaining: {remaining} second(s)") - if remaining < 0: - self._has_timeout = True - raise AnsibleConnectionFailure(f"{label} command '{cmd}' timeout on host: {self.instance_id}") - yield self.poll_stdout() - - def exec_communicate(self, cmd: str, mark_start: str, mark_begin: str, mark_end: str) -> Tuple[int, str, str]: - """Interact with session. - Read stdout between the markers until 'mark_end' is reached. - - :param cmd: The command being executed. - :param mark_start: The marker which starts the output. 
- :param mark_begin: The begin marker. - :param mark_end: The end marker. - :returns: A tuple with the return code, the stdout and the stderr content. - """ - # Read stdout between the markers - stdout = "" - win_line = "" - begin = False - returncode = None - for poll_result in self.poll("EXEC", cmd): - if not poll_result: - continue - - line = filter_ansi(self._stdout.readline(), self.is_windows) - self.verbosity_display(4, f"EXEC stdout line: \n{line}") - - if not begin and self.is_windows: - win_line = win_line + line - line = win_line - - if mark_start in line: - begin = True - if not line.startswith(mark_start): - stdout = "" - continue - if begin: - if mark_end in line: - self.verbosity_display(4, f"POST_PROCESS: \n{to_text(stdout)}") - returncode, stdout = self._post_process(stdout, mark_begin) - self.verbosity_display(4, f"POST_PROCESSED: \n{to_text(stdout)}") - break - stdout = stdout + line - - # see https://github.com/pylint-dev/pylint/issues/8909) - return (returncode, stdout, self._flush_stderr(self._session)) # pylint: disable=unreachable - - @staticmethod - def generate_mark() -> str: - """Generates a random string of characters to delimit SSM CLI commands""" - mark = "".join([random.choice(string.ascii_letters) for i in range(Connection.MARK_LENGTH)]) - return mark - @_ssm_retry def exec_command(self, cmd: str, in_data: bool = None, sudoable: bool = True) -> Tuple[int, str, str]: """When running a command on the SSM host, uses generate_mark to get delimiting strings""" super().exec_command(cmd, in_data=in_data, sudoable=sudoable) - - self.verbosity_display(3, f"EXEC: {to_text(cmd)}") - - mark_begin = self.generate_mark() - if self.is_windows: - mark_start = mark_begin + " $LASTEXITCODE" - else: - mark_start = mark_begin - mark_end = self.generate_mark() - - # Wrap command in markers accordingly for the shell used - cmd = self._wrap_command(cmd, mark_start, mark_end) - - self._flush_stderr(self._session) - - for chunk in chunks(cmd, 1024): - 
self._session.stdin.write(to_bytes(chunk, errors="surrogate_or_strict")) - - return self.exec_communicate(cmd, mark_start, mark_begin, mark_end) + return self._command_mgr.exec_command(cmd, instance_id=self.instance_id, region_name=self.get_option("region") or "us-east-1") def _ensure_ssm_session_has_started(self) -> None: """Ensure the SSM session has started on the host. We poll stdout until we match the following string 'Starting session with SessionId' """ - stdout = "" - for poll_result in self.poll("START SSM SESSION", "start_session"): - if poll_result: - stdout += to_text(self._stdout.read(1024)) - self.verbosity_display(4, f"START SSM SESSION stdout line: \n{to_bytes(stdout)}") - match = str(stdout).find("Starting session with SessionId") - if match != -1: - self.verbosity_display(4, "START SSM SESSION startup output received") - break + self._command_mgr.poller.match_expr(expr="Starting session with SessionId") def _disable_prompt_command(self) -> None: """Disable prompt command from the host""" @@ -822,15 +689,8 @@ def _disable_prompt_command(self) -> None: # Send command self.verbosity_display(4, f"DISABLE PROMPT Disabling Prompt: \n{disable_prompt_cmd}") - self._session.stdin.write(disable_prompt_cmd) - - stdout = "" - for poll_result in self.poll("DISABLE PROMPT", disable_prompt_cmd): - if poll_result: - stdout += to_text(self._stdout.read(1024)) - self.verbosity_display(4, f"DISABLE PROMPT stdout line: \n{to_bytes(stdout)}") - if disable_prompt_reply.search(stdout): - break + self._command_mgr.poller.stdin_write(disable_prompt_cmd) + self._command_mgr.poller.match_expr(expr=disable_prompt_reply.search) def _disable_echo_command(self) -> None: """Disable echo command from the host""" @@ -838,16 +698,8 @@ def _disable_echo_command(self) -> None: # Send command self.verbosity_display(4, f"DISABLE ECHO Disabling Prompt: \n{disable_echo_cmd}") - self._session.stdin.write(disable_echo_cmd) - - stdout = "" - for poll_result in self.poll("DISABLE ECHO", 
disable_echo_cmd): - if poll_result: - stdout += to_text(self._stdout.read(1024)) - self.verbosity_display(4, f"DISABLE ECHO stdout line: \n{to_bytes(stdout)}") - match = str(stdout).find("stty -echo") - if match != -1: - break + self._command_mgr.poller.stdin_write(disable_echo_cmd) + self._command_mgr.poller.match_expr(expr="stty -echo") def _prepare_terminal(self) -> None: """perform any one-time terminal settings""" @@ -866,73 +718,6 @@ def _prepare_terminal(self) -> None: self.verbosity_display(4, "PRE Terminal configured") # pylint: disable=unreachable - def _wrap_command(self, cmd: str, mark_start: str, mark_end: str) -> str: - """Wrap command so stdout and status can be extracted""" - - if self.is_windows: - if not cmd.startswith(" ".join(_common_args) + " -EncodedCommand"): - cmd = self._shell._encode_script(cmd, preserve_rc=True) - cmd = cmd + "; echo " + mark_start + "\necho " + mark_end + "\n" - else: - cmd = ( - f"printf '%s\\n' '{mark_start}';\n" - f"echo | {cmd};\n" - f"printf '\\n%s\\n%s\\n' \"$?\" '{mark_end}';\n" - ) # fmt: skip - - self.verbosity_display(4, f"_wrap_command: \n'{to_text(cmd)}'") - return cmd - - def _post_process(self, stdout: str, mark_begin: str) -> Tuple[str, str]: - """extract command status and strip unwanted lines""" - - if not self.is_windows: - # Get command return code - returncode = int(stdout.splitlines()[-2]) - - # Throw away final lines - for _x in range(0, 3): - stdout = stdout[:stdout.rfind('\n')] # fmt: skip - - return (returncode, stdout) - - # Windows is a little more complex - # Value of $LASTEXITCODE will be the line after the mark - trailer = stdout[stdout.rfind(mark_begin):] # fmt: skip - last_exit_code = trailer.splitlines()[1] - if last_exit_code.isdigit: - returncode = int(last_exit_code) - else: - returncode = -1 - # output to keep will be before the mark - stdout = stdout[:stdout.rfind(mark_begin)] # fmt: skip - - # If the return code contains #CLIXML (like a progress bar) remove it - clixml_filter = 
re.compile(r"#<\sCLIXML\s") - stdout = clixml_filter.sub("", stdout) - - # If it looks like JSON remove any newlines - if stdout.startswith("{"): - stdout = stdout.replace("\n", "") - - return (returncode, stdout) - - def _flush_stderr(self, session_process) -> str: - """read and return stderr with minimal blocking""" - - poll_stderr = select.poll() - poll_stderr.register(session_process.stderr, select.POLLIN) - stderr = "" - - while session_process.poll() is None: - if not poll_stderr.poll(1): - break - line = session_process.stderr.readline() - self.verbosity_display(4, f"stderr line: {to_text(line)}") - stderr = stderr + line - - return stderr - def _get_boto_client(self, service, region_name=None, profile_name=None, endpoint_url=None): """Gets a boto3 client based on the STS token""" @@ -1155,7 +940,7 @@ def close(self) -> None: """terminate the connection""" if self._session_id: self.verbosity_display(3, f"CLOSING SSM CONNECTION TO: {self.instance_id}") - if self._has_timeout: + if self._command_mgr.has_timeout: self._session.terminate() else: cmd = b"\nexit\n" diff --git a/plugins/plugin_utils/ssm/command.py b/plugins/plugin_utils/ssm/command.py new file mode 100644 index 00000000000..b50e62f14be --- /dev/null +++ b/plugins/plugin_utils/ssm/command.py @@ -0,0 +1,192 @@ +# -*- coding: utf-8 -*- + +# Copyright: Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +import os +import random +import re +import select +import string +from typing import Any +from typing import Callable +from typing import Iterator +from typing import List +from typing import Tuple + +from ansible.module_utils._text import to_bytes +from ansible.module_utils._text import to_text +from ansible.plugins.shell.powershell import _common_args + +from .common import SSMDisplay +from .common import StdoutPoller +from .turbo_client import turbo_exec_command + + +@staticmethod +def generate_mark() -> str: + """Generates a random string 
of characters to delimit SSM CLI commands""" + mark = "".join([random.choice(string.ascii_letters) for i in range(26)]) + return mark + + +def chunks(lst: List, n: int) -> Iterator[List[Any]]: + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i:i + n] # fmt: skip + + +def filter_ansi(line: str, is_windows: bool) -> str: + """Remove any ANSI terminal control codes. + + :param line: The input line. + :param is_windows: Whether the output is coming from a Windows host. + :returns: The result line. + """ + line = to_text(line) + + if is_windows: + osc_filter = re.compile(r"\x1b\][^\x07]*\x07") + line = osc_filter.sub("", line) + ansi_filter = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]") + line = ansi_filter.sub("", line) + + # Replace or strip sequence (at terminal width) + line = line.replace("\r\r\n", "\n") + if len(line) == 201: + line = line[:-1] + + return line + + +class CommandManager(SSMDisplay): + def __init__(self, shell: Any, session: Any, stdout_r: Any, ssm_timeout: int, verbosity_display: Callable) -> None: + super(CommandManager, self).__init__(verbosity_display=verbosity_display) + self._shell = shell + stdout = os.fdopen(stdout_r, "rb", 0) + poller = select.poll() + poller.register(stdout, select.POLLIN) + self._poller = StdoutPoller(session=session, stdout=stdout, poller=poller, timeout=ssm_timeout) + self.is_windows = bool(getattr(self._shell, "SHELL_FAMILY", "") == "powershell") + + @property + def poller(self) -> Any: + return self._poller + + @property + def has_timeout(self) -> bool: + return self._poller._has_timeout + + def _wrap_command(self, cmd: str, mark_start: str, mark_end: str) -> str: + """Wrap command so stdout and status can be extracted""" + if self.is_windows: + if not cmd.startswith(" ".join(_common_args) + " -EncodedCommand"): + cmd = self._shell._encode_script(cmd, preserve_rc=True) + cmd = cmd + "; echo " + mark_start + "\necho " + mark_end + "\n" + else: + cmd = ( + f"printf '%s\\n' 
'{mark_start}';\n" + f"echo | {cmd};\n" + f"printf '\\n%s\\n%s\\n' \"$?\" '{mark_end}';\n" + ) # fmt: skip + + self.verbosity_display(4, f"_wrap_command: \n'{to_text(cmd)}'") + return cmd + + def _post_process(self, stdout: str, mark_begin: str) -> Tuple[str, str]: + """extract command status and strip unwanted lines""" + + if not self.is_windows: + # Get command return code + returncode = int(stdout.splitlines()[-2]) + + # Throw away final lines + for _x in range(0, 3): + stdout = stdout[:stdout.rfind('\n')] # fmt: skip + + return (returncode, stdout) + + # Windows is a little more complex + # Value of $LASTEXITCODE will be the line after the mark + trailer = stdout[stdout.rfind(mark_begin):] # fmt: skip + last_exit_code = trailer.splitlines()[1] + if last_exit_code.isdigit: + returncode = int(last_exit_code) + else: + returncode = -1 + # output to keep will be before the mark + stdout = stdout[:stdout.rfind(mark_begin)] # fmt: skip + + # If the return code contains #CLIXML (like a progress bar) remove it + clixml_filter = re.compile(r"#<\sCLIXML\s") + stdout = clixml_filter.sub("", stdout) + + # If it looks like JSON remove any newlines + if stdout.startswith("{"): + stdout = stdout.replace("\n", "") + + return (returncode, stdout) + + def exec_communicate(self, mark_start: str, mark_begin: str, mark_end: str) -> Tuple[int, str, str]: + """Interact with session. + Read stdout between the markers until 'mark_end' is reached. + + :param cmd: The command being executed. + :param mark_start: The marker which starts the output. + :param mark_begin: The begin marker. + :param mark_end: The end marker. + :returns: A tuple with the return code, the stdout and the stderr content. 
+ """ + # Read stdout between the markers + stdout = "" + win_line = "" + begin = False + returncode = None + for poll_result in self._poller.poll(): + if not poll_result: + continue + + line = filter_ansi(self._poller.readline(), self.is_windows) + self.verbosity_display(4, f"EXEC stdout line: \n{line}") + + if not begin and self.is_windows: + win_line = win_line + line + line = win_line + + if mark_start in line: + begin = True + if not line.startswith(mark_start): + stdout = "" + continue + if begin: + if mark_end in line: + self.verbosity_display(4, f"POST_PROCESS: \n{to_text(stdout)}") + returncode, stdout = self._post_process(stdout, mark_begin) + self.verbosity_display(4, f"POST_PROCESSED: \n{to_text(stdout)}") + break + stdout = stdout + line + + # see https://github.com/pylint-dev/pylint/issues/8909) + return (returncode, stdout, self._poller.flush_stderr()) # pylint: disable=unreachable + + def exec_command(self, cmd: str, instance_id: str, region_name: str) -> Tuple[int, str, str]: + self.verbosity_display(3, f"EXEC: {to_text(cmd)}") + + turbo_result = turbo_exec_command(command=cmd, instance_id=instance_id, region_name=region_name, verbosity_display=self.verbosity_display) + self.verbosity_display(4, f"TURBO COMMAND RESULT: {turbo_result}") + + mark_begin = generate_mark() + if self.is_windows: + mark_start = mark_begin + " $LASTEXITCODE" + else: + mark_start = mark_begin + mark_end = generate_mark() + + # Wrap command in markers accordingly for the shell used + cmd = self._wrap_command(cmd, mark_start, mark_end) + + self._poller.flush_stderr() + for chunk in chunks(cmd, 1024): + self._poller.stdin_write(to_bytes(chunk, errors="surrogate_or_strict")) + + return self.exec_communicate(mark_start, mark_begin, mark_end) diff --git a/plugins/plugin_utils/ssm/common.py b/plugins/plugin_utils/ssm/common.py index 95f421fcb77..085dfc9647e 100644 --- a/plugins/plugin_utils/ssm/common.py +++ b/plugins/plugin_utils/ssm/common.py @@ -9,7 +9,6 @@ import select import 
subprocess import time -from functools import wraps from typing import Any from typing import Callable from typing import Dict @@ -19,62 +18,6 @@ from typing import Union from ansible.errors import AnsibleConnectionFailure -from ansible.errors import AnsibleError -from ansible.module_utils._text import to_text - - -def ssm_retry(func: Any) -> Any: - """ - Decorator to retry in the case of a connection failure - Will retry if: - * an exception is caught - Will not retry if - * remaining_tries is <2 - * retries limit reached - """ - - @wraps(func) - def wrapped(self, *args: Any, **kwargs: Any) -> Any: - for attr in ("reconnection_retries", "verbosity_display"): - if not hasattr(self, attr): - raise AnsibleError(f"Cannot decorate this function with 'ssm_retry', missing attribute '{attr}'") - remaining_tries = int(getattr(self, "reconnection_retries")) + 1 - cmd_summary = f"{args[0]}..." - for attempt in range(remaining_tries): - try: - return_tuple = func(self, *args, **kwargs) - self.verbosity_display(4, f"ssm_retry: (success) {to_text(return_tuple)}") - break - - except (AnsibleConnectionFailure, Exception) as e: - if attempt == remaining_tries - 1: - raise - pause = 2**attempt - 1 - pause = min(pause, 30) - - if isinstance(e, AnsibleConnectionFailure): - msg = f"ssm_retry: attempt: {attempt}, cmd ({cmd_summary}), pausing for {pause} seconds" - else: - msg = ( - f"ssm_retry: attempt: {attempt}, caught exception({e})" - f"from cmd ({cmd_summary}),pausing for {pause} seconds" - ) - - self.verbosity_display(2, msg) - - time.sleep(pause) - - # Do not attempt to reuse the existing session on retries - # This will cause the SSM session to be completely restarted, - # as well as reinitializing the boto3 clients - if hasattr(self, "close"): - getattr(self, "close")() - - continue - - return return_tuple - - return wrapped class CommandResult(TypedDict): @@ -98,17 +41,18 @@ def __init__(self, session: Any, stdout: Any, poller: Any, timeout: int) -> None self._poller = poller 
self._session = session self._timeout = timeout + self._has_timeout = False def readline(self): return self._stdout.readline() - def has_data(self) -> bool: - return bool(self._poller.poll(self._timeout)) + def has_data(self, timeout: int = 1000) -> bool: + return bool(self._poller.poll(timeout)) def read_stdout(self, length: int = 1024) -> str: return self._stdout.read(length).decode("utf-8") - def stdin_write(self, value: str) -> None: + def stdin_write(self, value: Union[str | bytes]) -> None: self._session.stdin.write(value) def poll(self) -> NoReturn: @@ -117,6 +61,7 @@ def poll(self) -> NoReturn: while self._session.poll() is None: remaining = start + self._timeout - round(time.time()) if remaining < 0: + self._has_timeout = True raise AnsibleConnectionFailure("StdoutPoller timeout...") yield self.has_data() From e879e830fb5196ac30d8e105f2050455cd600580 Mon Sep 17 00:00:00 2001 From: aubin bikouo Date: Mon, 7 Apr 2025 14:51:17 +0200 Subject: [PATCH 4/5] add turbo elements --- plugins/plugin_utils/ssm/turbo_client.py | 110 +++++++++++++++ plugins/plugin_utils/ssm/turbo_server.py | 168 +++++++++++++++++++++++ 2 files changed, 278 insertions(+) create mode 100644 plugins/plugin_utils/ssm/turbo_client.py create mode 100644 plugins/plugin_utils/ssm/turbo_server.py diff --git a/plugins/plugin_utils/ssm/turbo_client.py b/plugins/plugin_utils/ssm/turbo_client.py new file mode 100644 index 00000000000..644874690da --- /dev/null +++ b/plugins/plugin_utils/ssm/turbo_client.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- + +# Copyright: Contributors to the Ansible project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +import json +import os +import pickle +import socket +import subprocess +import sys +import time +from contextlib import contextmanager +from typing import Callable +from typing import Dict + +from ansible.errors import AnsibleRuntimeError + +from .common import SSMDisplay + + +def 
create_socket_path(instance_id: str, region_name: str) -> str: + return os.path.join( + os.environ["HOME"], ".ansible", "_".join(["connection_aws_ssm_turbo", instance_id, region_name]) + ) + + +class SSMTurboSocket(SSMDisplay): + def __init__(self, instance_id, region_name, ttl, verbosity_display): + super(SSMTurboSocket, self).__init__(verbosity_display) + self._socket_path = create_socket_path(instance_id, region_name) + self.verbosity_display(4, f">>> SSM TURBO SOCKET PATH = {self._socket_path}") + self._ttl = ttl + self._socket = None + + def bind(self): + running = False + self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + for attempt in range(100, -1, -1): + try: + self._socket.connect(self._socket_path) + return True + except (ConnectionRefusedError, FileNotFoundError): + if not running: + running = self.start_server() + if attempt == 0: + raise + time.sleep(0.01) + + def start_server(self): + env = os.environ + parameters = ["--fork", "--socket-path", self._socket_path, "--ttl", str(self._ttl)] + + command = [sys.executable] + ansiblez_path = sys.path[0] + env.update({"PYTHONPATH": ansiblez_path}) + command += [ + "-m", + "ansible_collections.community.aws.plugins.plugin_utils.ssm.turbo_server", + ] + # parent_dir = os.path.dirname(__file__) + # server_path = os.path.join(parent_dir, "server.py") + # command += [server_path] + self.verbosity_display(4, f">>> SSM TURBO SOCKET COMMAND = '{command + parameters}'") + p = subprocess.Popen( + command + parameters, + env=env, + close_fds=True, + ) + result = p.communicate() + self.verbosity_display(4, f">>> SSM TURBO SOCKET COMMAND Pid = '{p.pid}' (result = {result})") + return p.pid + + def communicate(self, command, wait_sleep=0.01): + encoded_data = pickle.dumps(command) + self._socket.sendall(encoded_data) + self._socket.shutdown(socket.SHUT_WR) + raw_answer = b"" + while True: + b = self._socket.recv((1024 * 1024)) + if not b: + break + raw_answer += b + time.sleep(wait_sleep) + try: + result 
= json.loads(raw_answer.decode()) + return result + except json.decoder.JSONDecodeError: + raise AnsibleRuntimeError(f"Cannot decode exec_command answer: {raw_answer}") + + def close(self): + if self._socket: + self._socket.close() + + +@contextmanager +def connect(instance_id, region_name, ttl, verbosity_display): + turbo_socket = SSMTurboSocket( + instance_id=instance_id, region_name=region_name, ttl=ttl, verbosity_display=verbosity_display + ) + try: + turbo_socket.bind() + yield turbo_socket + finally: + turbo_socket.close() + + +def turbo_exec_command(command: str, instance_id: str, region_name: str, verbosity_display: Callable, ttl=10) -> Dict: + with connect(instance_id, region_name, ttl=ttl, verbosity_display=verbosity_display) as turbo_socket: + return turbo_socket.communicate(command=command) diff --git a/plugins/plugin_utils/ssm/turbo_server.py b/plugins/plugin_utils/ssm/turbo_server.py new file mode 100644 index 00000000000..dcb91098817 --- /dev/null +++ b/plugins/plugin_utils/ssm/turbo_server.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- + +# Copyright: Contributors to the Ansible project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +import argparse +import asyncio +import json +import os +import pickle +import signal +import sys +import traceback +import uuid +from datetime import datetime + + +def fork_process(): + """ + This function performs the double fork process to detach from the + parent process and execute. 
+ """ + pid = os.fork() + + if pid == 0: + fd = os.open(os.devnull, os.O_RDWR) + + # clone stdin/out/err + for num in range(3): + if fd != num: + os.dup2(fd, num) + + if fd not in range(3): + os.close(fd) + + pid = os.fork() + if pid > 0: + os._exit(0) + + # get new process session and detach + sid = os.setsid() + if sid == -1: + raise Exception("Unable to detach session while daemonizing") + + # avoid possible problems with cwd being removed + os.chdir("/") + + pid = os.fork() + if pid > 0: + sys.exit(0) # pylint: disable=ansible-bad-function + else: + sys.exit(0) # pylint: disable=ansible-bad-function + return pid + + +def trace_value(msg): + with open("/Users/aubinho/work/aws/connection_ssm/turbo/trace.txt", "a+") as f: + f.write(f"{msg}\n") + + +class SSMTurboMode: + def __init__(self, socket_path, ttl): + self.socket_path = socket_path + self.ttl = ttl + self.jobs_ongoing = {} + + async def ghost_killer(self): + while True: + await asyncio.sleep(self.ttl) + running_jobs = { + job_id: start_date + for job_id, start_date in self.jobs_ongoing.items() + if (datetime.now() - start_date).total_seconds() < 3600 + } + if running_jobs: + continue + self.stop() + + async def handle(self, reader, writer): + result = None + self._watcher.cancel() + self._watcher = self.loop.create_task(self.ghost_killer()) + job_id = str(uuid.uuid4()) + self.jobs_ongoing[job_id] = datetime.now() + raw_data = await reader.read() + + if not raw_data: + return + + command = pickle.loads(raw_data) + + def _terminate(result): + writer.write(json.dumps(result).encode()) + writer.close() + + result = { + "returncode": 1, + "stdout": "some content received from server", + "stderr": "some flush error", + "command": command, + } + _terminate(result) + del self.jobs_ongoing[job_id] + + def handle_exception(self, loop, context): + e = context.get("exception") + traceback.print_exception(type(e), e, e.__traceback__) + self.stop() + + def start(self): + # for python versions >= Python3.11 + self.loop = 
asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self.loop.add_signal_handler(signal.SIGTERM, self.stop) + self.loop.set_exception_handler(self.handle_exception) + self._watcher = self.loop.create_task(self.ghost_killer()) + + self.loop.run_until_complete(asyncio.start_unix_server(self.handle, path=self.socket_path)) + self.loop.run_forever() + + # print(f"sys hex version: {sys.hexversion}") + # self.loop = asyncio.get_event_loop() + # self.loop.add_signal_handler(signal.SIGTERM, self.stop) + # self.loop.set_exception_handler(self.handle_exception) + # self._watcher = self.loop.create_task(self.ghost_killer()) + + # trace_value(f"Socket path: {self.socket_path}") + + # if os.path.exists(self.socket_path): + # trace_value("Socket path exist going to remove it") + # os.remove(self.socket_path) + # trace_value("loop.run_until_complete...") + # self.loop.run_until_complete(asyncio.start_unix_server(self.handle, path=self.socket_path)) + # trace_value("chmod socket...") + # os.chmod(self.socket_path, 0o666) + # self.loop.run_forever() + + # if sys.hexversion >= 0x30A00B1: + # # py3.10 drops the loop argument of create_task. 
+ # self.loop.create_task( + # asyncio.start_unix_server(self.handle, path=self.socket_path) + # ) + # else: + # self.loop.create_task( + # asyncio.start_unix_server( + # self.handle, path=self.socket_path, loop=self.loop + # ) + # ) + # self.loop.run_forever() + + def stop(self): + os.unlink(self.socket_path) + self.loop.stop() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Start a background daemon.") + parser.add_argument("--socket-path") + parser.add_argument("--ttl", default=15, type=int) + parser.add_argument("--fork", action="store_true") + + args = parser.parse_args() + if args.fork: + trace_value("forking process") + fork_process() + + server = SSMTurboMode(socket_path=args.socket_path, ttl=args.ttl) + server.start() From f1454c107c1b19cc069205f22f680fd86ac64b70 Mon Sep 17 00:00:00 2001 From: aubin bikouo Date: Tue, 8 Apr 2025 11:07:44 +0200 Subject: [PATCH 5/5] update before tests --- plugins/connection/aws_ssm.py | 15 +- plugins/plugin_utils/ssm/command.py | 45 ++++-- plugins/plugin_utils/ssm/turbo_client.py | 68 ++++++-- plugins/plugin_utils/ssm/turbo_server.py | 190 +++++++++++++++++++++-- 4 files changed, 268 insertions(+), 50 deletions(-) diff --git a/plugins/connection/aws_ssm.py b/plugins/connection/aws_ssm.py index 34b3350664f..490123f59d3 100644 --- a/plugins/connection/aws_ssm.py +++ b/plugins/connection/aws_ssm.py @@ -380,7 +380,8 @@ from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3 from ansible_collections.community.aws.plugins.plugin_utils.ssm.transport import PortForwardingFileTransportManager -from ansible_collections.community.aws.plugins.plugin_utils.ssm.command import CommandManager +from ansible_collections.community.aws.plugins.plugin_utils.ssm.command import CommandManager, encode_script +from ansible_collections.community.aws.plugins.plugin_utils.ssm.turbo_client import turbo_exec_command from ansible_collections.community.aws.plugins.plugin_utils.s3clientmanager 
import S3ClientManager @@ -493,8 +494,8 @@ def _connect(self) -> Any: """connect to the host via ssm""" self._play_context.remote_user = getpass.getuser() - if not self._session_id: - self.start_session() + # if not self._session_id: + # self.start_session() return self def _init_clients(self) -> None: @@ -655,7 +656,7 @@ def start_session(self) -> None: os.close(stdout_w) self._command_mgr = CommandManager( - shell=self._shell, + is_windows=self.is_windows, session=self._session, stdout_r=stdout_r, ssm_timeout=self.get_option("ssm_timeout"), @@ -670,7 +671,11 @@ def exec_command(self, cmd: str, in_data: bool = None, sudoable: bool = True) -> """When running a command on the SSM host, uses generate_mark to get delimiting strings""" super().exec_command(cmd, in_data=in_data, sudoable=sudoable) - return self._command_mgr.exec_command(cmd, instance_id=self.instance_id, region_name=self.get_option("region") or "us-east-1") + encoded_cmd = encode_script(self._shell, cmd) + if self._command_mgr: + return self._command_mgr.exec_command(encoded_cmd) + else: + return turbo_exec_command(self, encoded_cmd) def _ensure_ssm_session_has_started(self) -> None: """Ensure the SSM session has started on the host. 
We poll stdout diff --git a/plugins/plugin_utils/ssm/command.py b/plugins/plugin_utils/ssm/command.py index b50e62f14be..7515917783f 100644 --- a/plugins/plugin_utils/ssm/command.py +++ b/plugins/plugin_utils/ssm/command.py @@ -3,24 +3,40 @@ # Copyright: Ansible Project # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +import argparse +import asyncio +import json import os +import pickle +import pty import random import re import select +import signal import string +import subprocess +import sys +import traceback +import uuid +from datetime import datetime +from functools import wraps from typing import Any from typing import Callable from typing import Iterator from typing import List from typing import Tuple +try: + import boto3 +except ImportError: + pass + from ansible.module_utils._text import to_bytes from ansible.module_utils._text import to_text from ansible.plugins.shell.powershell import _common_args from .common import SSMDisplay from .common import StdoutPoller -from .turbo_client import turbo_exec_command @staticmethod @@ -59,15 +75,25 @@ def filter_ansi(line: str, is_windows: bool) -> str: return line +def encode_script(shell: Any, cmd: str) -> str: + result = cmd + if getattr(shell, "SHELL_FAMILY", "") == "powershell" and not cmd.startswith( + " ".join(_common_args) + " -EncodedCommand" + ): + result = shell._encode_script(cmd, preserve_rc=True) + return result + + class CommandManager(SSMDisplay): - def __init__(self, shell: Any, session: Any, stdout_r: Any, ssm_timeout: int, verbosity_display: Callable) -> None: + def __init__( + self, is_windows: bool, session: Any, stdout_r: Any, ssm_timeout: int, verbosity_display: Callable + ) -> None: super(CommandManager, self).__init__(verbosity_display=verbosity_display) - self._shell = shell stdout = os.fdopen(stdout_r, "rb", 0) poller = select.poll() poller.register(stdout, select.POLLIN) self._poller = StdoutPoller(session=session, stdout=stdout, poller=poller, 
timeout=ssm_timeout) - self.is_windows = bool(getattr(self._shell, "SHELL_FAMILY", "") == "powershell") + self.is_windows = is_windows @property def poller(self) -> Any: @@ -80,8 +106,6 @@ def has_timeout(self) -> bool: def _wrap_command(self, cmd: str, mark_start: str, mark_end: str) -> str: """Wrap command so stdout and status can be extracted""" if self.is_windows: - if not cmd.startswith(" ".join(_common_args) + " -EncodedCommand"): - cmd = self._shell._encode_script(cmd, preserve_rc=True) cmd = cmd + "; echo " + mark_start + "\necho " + mark_end + "\n" else: cmd = ( @@ -127,7 +151,7 @@ def _post_process(self, stdout: str, mark_begin: str) -> Tuple[str, str]: return (returncode, stdout) - def exec_communicate(self, mark_start: str, mark_begin: str, mark_end: str) -> Tuple[int, str, str]: + def _exec_communicate(self, mark_start: str, mark_begin: str, mark_end: str) -> Tuple[int, str, str]: """Interact with session. Read stdout between the markers until 'mark_end' is reached. @@ -169,12 +193,9 @@ def exec_communicate(self, mark_start: str, mark_begin: str, mark_end: str) -> T # see https://github.com/pylint-dev/pylint/issues/8909) return (returncode, stdout, self._poller.flush_stderr()) # pylint: disable=unreachable - def exec_command(self, cmd: str, instance_id: str, region_name: str) -> Tuple[int, str, str]: + def exec_command(self, cmd: str) -> Tuple[int, str, str]: self.verbosity_display(3, f"EXEC: {to_text(cmd)}") - turbo_result = turbo_exec_command(command=cmd, instance_id=instance_id, region_name=region_name, verbosity_display=self.verbosity_display) - self.verbosity_display(4, f"TURBO COMMAND RESULT: {turbo_result}") - mark_begin = generate_mark() if self.is_windows: mark_start = mark_begin + " $LASTEXITCODE" @@ -189,4 +210,4 @@ def exec_command(self, cmd: str, instance_id: str, region_name: str) -> Tuple[in for chunk in chunks(cmd, 1024): self._poller.stdin_write(to_bytes(chunk, errors="surrogate_or_strict")) - return self.exec_communicate(mark_start, 
mark_begin, mark_end) + return self._exec_communicate(mark_start, mark_begin, mark_end) diff --git a/plugins/plugin_utils/ssm/turbo_client.py b/plugins/plugin_utils/ssm/turbo_client.py index 644874690da..8c3a90a9b32 100644 --- a/plugins/plugin_utils/ssm/turbo_client.py +++ b/plugins/plugin_utils/ssm/turbo_client.py @@ -11,7 +11,7 @@ import sys import time from contextlib import contextmanager -from typing import Callable +from typing import Any from typing import Dict from ansible.errors import AnsibleRuntimeError @@ -26,11 +26,12 @@ def create_socket_path(instance_id: str, region_name: str) -> str: class SSMTurboSocket(SSMDisplay): - def __init__(self, instance_id, region_name, ttl, verbosity_display): - super(SSMTurboSocket, self).__init__(verbosity_display) - self._socket_path = create_socket_path(instance_id, region_name) + def __init__(self, conn_plugin: Any): + super(SSMTurboSocket, self).__init__(conn_plugin.verbosity_display) + self._region = conn_plugin.get_option("region") or "us-east-1" + self._socket_path = create_socket_path(conn_plugin.instance_id, self._region) self.verbosity_display(4, f">>> SSM TURBO SOCKET PATH = {self._socket_path}") - self._ttl = ttl + self.conn_plugin = conn_plugin self._socket = None def bind(self): @@ -47,9 +48,44 @@ def bind(self): raise time.sleep(0.01) + def _mask_command(self, command: str) -> str: + if self.conn_plugin.get_option("access_key_id"): + command = command.replace(self.conn_plugin.get_option("access_key_id"), "*****") + if self.conn_plugin.get_option("secret_access_key"): + command = command.replace(self.conn_plugin.get_option("secret_access_key"), "*****") + if self.conn_plugin.get_option("session_token"): + command = command.replace(self.conn_plugin.get_option("session_token"), "*****") + return command + def start_server(self): env = os.environ - parameters = ["--fork", "--socket-path", self._socket_path, "--ttl", str(self._ttl)] + parameters = [ + "--fork", + "--socket-path", + self._socket_path, + 
"--region", + self._region, + "--executable", + self.conn_plugin.get_executable(), + ] + + pairing_options = { + "--instance-id": "instance_id", + "--ssm-timeout": "ssm_timeout", + "--reconnection-retries": "reconnection_retries", + "--access-key-id": "access_key_id", + "--secret-access-key": "secret_access_key", + "--session-token": "session_token", + "--profile": "profile", + "--ssm-document": "ssm_document", + "--is-windows": "is_windows", + } + for opt, attr in pairing_options.items(): + if hasattr(self.conn_plugin, attr): + if opt_value := getattr(self.conn_plugin, attr): + parameters.extend([opt, str(opt_value)]) + elif opt_value := self.conn_plugin.get_option(attr): + parameters.extend([opt, str(opt_value)]) command = [sys.executable] ansiblez_path = sys.path[0] @@ -61,14 +97,15 @@ def start_server(self): # parent_dir = os.path.dirname(__file__) # server_path = os.path.join(parent_dir, "server.py") # command += [server_path] - self.verbosity_display(4, f">>> SSM TURBO SOCKET COMMAND = '{command + parameters}'") + displayed_command = self._mask_command(" ".join(command + parameters)) + self.verbosity_display(4, f">>> SSM TURBO SOCKET COMMAND = '{displayed_command}'") p = subprocess.Popen( command + parameters, env=env, close_fds=True, ) - result = p.communicate() - self.verbosity_display(4, f">>> SSM TURBO SOCKET COMMAND Pid = '{p.pid}' (result = {result})") + p.communicate() + self.verbosity_display(4, f">>> SSM TURBO SOCKET COMMAND Pid = '{p.pid}'") return p.pid def communicate(self, command, wait_sleep=0.01): @@ -94,10 +131,8 @@ def close(self): @contextmanager -def connect(instance_id, region_name, ttl, verbosity_display): - turbo_socket = SSMTurboSocket( - instance_id=instance_id, region_name=region_name, ttl=ttl, verbosity_display=verbosity_display - ) +def connect(conn_plugin: Any): + turbo_socket = SSMTurboSocket(conn_plugin) try: turbo_socket.bind() yield turbo_socket @@ -105,6 +140,7 @@ def connect(instance_id, region_name, ttl, verbosity_display): 
turbo_socket.close() -def turbo_exec_command(command: str, instance_id: str, region_name: str, verbosity_display: Callable, ttl=10) -> Dict: - with connect(instance_id, region_name, ttl=ttl, verbosity_display=verbosity_display) as turbo_socket: - return turbo_socket.communicate(command=command) +def turbo_exec_command(conn_plugin: Any, encoded_cmd: str) -> Dict: + with connect(conn_plugin) as turbo_socket: + result = turbo_socket.communicate(command=encoded_cmd) + return result.get("returncode"), result.get("stdout"), result.get("stderr") diff --git a/plugins/plugin_utils/ssm/turbo_server.py b/plugins/plugin_utils/ssm/turbo_server.py index dcb91098817..652aa6268ec 100644 --- a/plugins/plugin_utils/ssm/turbo_server.py +++ b/plugins/plugin_utils/ssm/turbo_server.py @@ -8,11 +8,28 @@ import json import os import pickle +import pty +import random +import re import signal +import string +import subprocess import sys import traceback import uuid from datetime import datetime +from functools import wraps +from typing import Any +from typing import Tuple + +try: + import boto3 +except ImportError: + pass + +from ansible.module_utils._text import to_bytes + +from .command import CommandManager def fork_process(): @@ -53,15 +70,148 @@ def fork_process(): return pid -def trace_value(msg): - with open("/Users/aubinho/work/aws/connection_ssm/turbo/trace.txt", "a+") as f: - f.write(f"{msg}\n") +def _ensure_connect(func: Any, name: str) -> Any: + @wraps(func) + def wrapped(self, *args: Any, **kwargs: Any) -> Any: + if getattr(self, name) is name: + getattr(self, f"_init_{name}")() + return func(self, *args, **kwargs) + + return wrapped + + +class CommandHandler: + def __init__(self, args: Any) -> None: + for attr in ( + "instance_id", + "ssm_timeout", + "reconnection_retries", + "access_key_id", + "secret_access_key", + "session_token", + "profile", + "region", + "ssm_document", + "executable", + "socket_path", + "is_windows", + ): + setattr(self, attr, getattr(args, attr)) + + 
self.client = None + self.session_id = None + self.session = None + self.file_handler = open(f"{self.socket_path}.trace", "a") + self.command_mgr = None + self.trace_level = 0 + if trace_level := getattr(args, "trace_level"): + self.trace_level = trace_level + + def __del__(self) -> None: + if self.session_id and self.client: + if self.command_mgr.has_timeout: + self.session.terminate() + else: + cmd = b"\nexit\n" + self.session.communicate(cmd) + self.client.terminate_session(SessionId=self.session_id) + if self.file_handler: + self.file_handler.close() + + def _init_client(self) -> Any: + if not self.client: + session_args = { + "aws_access_key_id": getattr(self, "access_key_id"), + "aws_secret_access_key": getattr(self, "secret_access_key"), + "aws_session_token": getattr(self, "session_token"), + "region_name": getattr(self, "region"), + } + + if (profile := getattr(self, "profile")) is not None: + session_args["profile_name"] = profile + session = boto3.session.Session(**session_args) + self.client = session.client("ssm") + + def _display(self, level: int, message: str) -> None: + if level >= self.trace_level: + message = f"[{''.join(['V' for i in range(level)])}] {message}\n" + self.file_handler(message) + self.file_handler.flush() + + @_ensure_connect(name="client") + def _init_session(self) -> None: + if not self.session: + ssm_session_args = {"Target": self.instance_id, "Parameters": {}} + if (document_name := getattr(self, "ssm_document")) is not None: + ssm_session_args["DocumentName"] = document_name + response = self._client.start_session(**ssm_session_args) + self._session_id = response["SessionId"] + self.verbosity_display(4, f"SSM CONNECTION ID: {self._session_id}") + + region_name = getattr(self, "region") + profile_name = getattr(self, "profile") or "" + cmd = [ + self.executable, + json.dumps(response), + region_name, + "StartSession", + profile_name, + json.dumps({"Target": self.instance_id}), + self._client.meta.endpoint_url, + ] + + 
self._display(4, f"SSM COMMAND: {(cmd)}") + + stdout_r, stdout_w = pty.openpty() + self.session = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stdout=stdout_w, + stderr=subprocess.PIPE, + close_fds=True, + bufsize=0, + ) + + os.close(stdout_w) + self.command_mgr = CommandManager( + shell=self._shell, + session=self.session, + stdout_r=stdout_r, + ssm_timeout=self.ssm_timeout, + verbosity_display=self._display, + ) + + # For non-windows Hosts: Ensure the session has started, and disable command echo and prompt. + if not self.is_windows: + self.command_mgr.poller.match_expr(expr="Starting session with SessionId") + + # Disable echo command from the host + disable_echo_cmd = to_bytes("stty -echo\n", errors="surrogate_or_strict") + self._display(4, f"DISABLE ECHO Disabling Prompt: \n{disable_echo_cmd}") + self.command_mgr.poller.stdin_write(disable_echo_cmd) + self.command_mgr.poller.match_expr(expr="stty -echo") + + # Disable prompt command from the host + end_mark = "".join([random.choice(string.ascii_letters) for i in range(self.MARK_LENGTH)]) + disable_prompt_cmd = to_bytes( + "PS1='' ; bind 'set enable-bracketed-paste off'; printf '\\n%s\\n' '" + end_mark + "'\n", + errors="surrogate_or_strict", + ) + disable_prompt_reply = re.compile(r"\r\r\n" + re.escape(end_mark) + r"\r\r\n", re.MULTILINE) + self._display(4, f"DISABLE PROMPT Disabling Prompt: \n{disable_prompt_cmd}") + self.command_mgr.poller.stdin_write(disable_prompt_cmd) + self.command_mgr.poller.match_expr(expr=disable_prompt_reply.search) + + @_ensure_connect(name="session") + def exec_command(self, command: str) -> Tuple[int, str, str]: + return self.command_mgr.exec_command(command) class SSMTurboMode: - def __init__(self, socket_path, ttl): - self.socket_path = socket_path - self.ttl = ttl + def __init__(self, args: Any): + self.socket_path = args.socket_path + self.ttl = args.ttl + self.command_handler = CommandHandler(args) self.jobs_ongoing = {} async def ghost_killer(self): @@ -88,15 +238,16 
@@ async def handle(self, reader, writer): return command = pickle.loads(raw_data) + returncode, stdout, stderr = self.command_handler.exec_command(command=command) def _terminate(result): writer.write(json.dumps(result).encode()) writer.close() result = { - "returncode": 1, - "stdout": "some content received from server", - "stderr": "some flush error", + "returncode": returncode, + "stdout": stdout, + "stderr": stderr, "command": command, } _terminate(result) @@ -124,14 +275,9 @@ def start(self): # self.loop.set_exception_handler(self.handle_exception) # self._watcher = self.loop.create_task(self.ghost_killer()) - # trace_value(f"Socket path: {self.socket_path}") - # if os.path.exists(self.socket_path): - # trace_value("Socket path exist going to remove it") # os.remove(self.socket_path) - # trace_value("loop.run_until_complete...") # self.loop.run_until_complete(asyncio.start_unix_server(self.handle, path=self.socket_path)) - # trace_value("chmod socket...") # os.chmod(self.socket_path, 0o666) # self.loop.run_forever() @@ -155,14 +301,25 @@ def stop(self): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Start a background daemon.") - parser.add_argument("--socket-path") + parser.add_argument("--socket-path", required=True) parser.add_argument("--ttl", default=15, type=int) parser.add_argument("--fork", action="store_true") + parser.add_argument("--instance-id", required=True) + parser.add_argument("--ssm-timeout", type=int, required=True) + parser.add_argument("--reconnection-retries", type=int, required=True) + parser.add_argument("--access-key-id", required=False) + parser.add_argument("--secret-access-key", required=False) + parser.add_argument("--session-token", required=False) + parser.add_argument("--profile", required=False) + parser.add_argument("--region", required=False) + parser.add_argument("--ssm-document", required=False) + parser.add_argument("--executable", required=True) + # argparse hands the raw string to type, and bool("False") is True, so parse the text explicitly. + parser.add_argument("--is-windows", type=lambda value: value.strip().lower() in ("1", "true", "yes", "on"),
default=False) args = parser.parse_args() if args.fork: - trace_value("forking process") fork_process() - server = SSMTurboMode(socket_path=args.socket_path, ttl=args.ttl) + server = SSMTurboMode(args=args) server.start()