332332import string
333333import subprocess
334334import time
335+ from functools import wraps
336+ from typing import Any
335337from typing import Dict
338+ from typing import Iterator
336339from typing import List
337340from typing import NoReturn
338341from typing import Optional
345348except ImportError :
346349 pass
347350
348- from functools import wraps
349-
350351from ansible .errors import AnsibleConnectionFailure
351352from ansible .errors import AnsibleError
352353from ansible .errors import AnsibleFileNotFound
360361
361362from ansible_collections .amazon .aws .plugins .module_utils .botocore import HAS_BOTO3
362363
364+ from ansible_collections .community .aws .plugins .plugin_utils .s3clientmanager import S3ClientManager
365+
363366display = Display ()
364367
365368
366- def _ssm_retry (func ) :
369+ def _ssm_retry (func : Any ) -> Any :
367370 """
368371 Decorator to retry in the case of a connection failure
369372 Will retry if:
@@ -374,7 +377,7 @@ def _ssm_retry(func):
374377 """
375378
376379 @wraps (func )
377- def wrapped (self , * args , ** kwargs ) :
380+ def wrapped (self , * args : Any , ** kwargs : Any ) -> Any :
378381 remaining_tries = int (self .get_option ("reconnection_retries" )) + 1
379382 cmd_summary = f"{ args [0 ]} ..."
380383 for attempt in range (remaining_tries ):
@@ -413,7 +416,7 @@ def wrapped(self, *args, **kwargs):
413416 return wrapped
414417
415418
416- def chunks (lst , n ) :
419+ def chunks (lst : List , n : int ) -> Iterator [ List [ Any ]] :
417420 """Yield successive n-sized chunks from lst."""
418421 for i in range (0 , len (lst ), n ):
419422 yield lst [i :i + n ] # fmt: skip
@@ -471,7 +474,7 @@ class Connection(ConnectionBase):
471474 _timeout = False
472475 MARK_LENGTH = 26
473476
474- def __init__ (self , * args , ** kwargs ) :
477+ def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
475478 super ().__init__ (* args , ** kwargs )
476479
477480 if not HAS_BOTO3 :
@@ -492,12 +495,11 @@ def __init__(self, *args, **kwargs):
492495 self ._shell_type = "powershell"
493496 self .is_windows = True
494497
495- def __del__ (self ):
498+ def __del__ (self ) -> None :
496499 self .close ()
497500
498- def _connect (self ):
501+ def _connect (self ) -> Any :
499502 """connect to the host via ssm"""
500-
501503 self ._play_context .remote_user = getpass .getuser ()
502504
503505 if not self ._session_id :
@@ -509,16 +511,23 @@ def _init_clients(self) -> None:
509511 Initializes required AWS clients (SSM and S3).
510512 Delegates client initialization to specialized methods.
511513 """
512-
513514 self ._vvvv ("INITIALIZE BOTO3 CLIENTS" )
514515 profile_name = self .get_option ("profile" ) or ""
515516 region_name = self .get_option ("region" )
516517
517- # Initialize SSM client
518- self ._initialize_ssm_client ( region_name , profile_name )
518+ # Initialize S3ClientManager
519+ self .s3_manager = S3ClientManager ( self )
519520
520521 # Initialize S3 client
521- self ._initialize_s3_client (profile_name )
522+ s3_endpoint_url , s3_region_name = self .s3_manager .get_bucket_endpoint ()
523+ self ._vvvv (f"SETUP BOTO3 CLIENTS: S3 { s3_endpoint_url } " )
524+ self .s3_manager .initialize_client (
525+ region_name = s3_region_name , endpoint_url = s3_endpoint_url , profile_name = profile_name
526+ )
527+ self ._s3_client = self .s3_manager ._s3_client
528+
529+ # Initialize SSM client
530+ self ._initialize_ssm_client (region_name , profile_name )
522531
523532 def _initialize_ssm_client (self , region_name : Optional [str ], profile_name : str ) -> None :
524533 """
@@ -538,84 +547,26 @@ def _initialize_ssm_client(self, region_name: Optional[str], profile_name: str)
538547 profile_name = profile_name ,
539548 )
540549
541- def _initialize_s3_client (self , profile_name : str ) -> None :
542- """
543- Initializes the S3 client used for accessing S3 buckets.
544-
545- Args:
546- profile_name (str): AWS profile name for authentication.
547-
548- Returns:
549- None
550- """
551-
552- s3_endpoint_url , s3_region_name = self ._get_bucket_endpoint ()
553- self ._vvvv (f"SETUP BOTO3 CLIENTS: S3 { s3_endpoint_url } " )
554- self ._s3_client = self ._get_boto_client (
555- "s3" ,
556- region_name = s3_region_name ,
557- endpoint_url = s3_endpoint_url ,
558- profile_name = profile_name ,
559- )
560-
561- def _display (self , f , message ):
550+ def _display (self , f : Any , message : str ) -> None :
562551 if self .host :
563552 host_args = {"host" : self .host }
564553 else :
565554 host_args = {}
566555 f (to_text (message ), ** host_args )
567556
568- def _v (self , message ) :
557+ def _v (self , message : str ) -> None :
569558 self ._display (display .v , message )
570559
571- def _vv (self , message ) :
560+ def _vv (self , message : str ) -> None :
572561 self ._display (display .vv , message )
573562
574- def _vvv (self , message ) :
563+ def _vvv (self , message : str ) -> None :
575564 self ._display (display .vvv , message )
576565
577- def _vvvv (self , message ) :
566+ def _vvvv (self , message : str ) -> None :
578567 self ._display (display .vvvv , message )
579568
580- def _get_bucket_endpoint (self ):
581- """
582- Fetches the correct S3 endpoint and region for use with our bucket.
583- If we don't explicitly set the endpoint then some commands will use the global
584- endpoint and fail
585- (new AWS regions and new buckets in a region other than the one we're running in)
586- """
587-
588- region_name = self .get_option ("region" ) or "us-east-1"
589- profile_name = self .get_option ("profile" ) or ""
590- self ._vvvv ("_get_bucket_endpoint: S3 (global)" )
591- tmp_s3_client = self ._get_boto_client (
592- "s3" ,
593- region_name = region_name ,
594- profile_name = profile_name ,
595- )
596- # Fetch the location of the bucket so we can open a client against the 'right' endpoint
597- # This /should/ always work
598- head_bucket = tmp_s3_client .head_bucket (
599- Bucket = (self .get_option ("bucket_name" )),
600- )
601- bucket_region = head_bucket .get ("ResponseMetadata" , {}).get ("HTTPHeaders" , {}).get ("x-amz-bucket-region" , None )
602- if bucket_region is None :
603- bucket_region = "us-east-1"
604-
605- if self .get_option ("bucket_endpoint_url" ):
606- return self .get_option ("bucket_endpoint_url" ), bucket_region
607-
608- # Create another client for the region the bucket lives in, so we can nab the endpoint URL
609- self ._vvvv (f"_get_bucket_endpoint: S3 (bucket region) - { bucket_region } " )
610- s3_bucket_client = self ._get_boto_client (
611- "s3" ,
612- region_name = bucket_region ,
613- profile_name = profile_name ,
614- )
615-
616- return s3_bucket_client .meta .endpoint_url , s3_bucket_client .meta .region_name
617-
618- def reset (self ):
569+ def reset (self ) -> Any :
619570 """start a fresh ssm session"""
620571 self ._vvvv ("reset called on ssm connection" )
621572 self .close ()
@@ -885,7 +836,7 @@ def _wrap_command(self, cmd: str, mark_start: str, mark_end: str) -> str:
885836 self ._vvvv (f"_wrap_command: \n '{ to_text (cmd )} '" )
886837 return cmd
887838
888- def _post_process (self , stdout , mark_begin ) :
839+ def _post_process (self , stdout : str , mark_begin : str ) -> Tuple [ str , str ] :
889840 """extract command status and strip unwanted lines"""
890841
891842 if not self .is_windows :
@@ -919,7 +870,7 @@ def _post_process(self, stdout, mark_begin):
919870
920871 return (returncode , stdout )
921872
922- def _flush_stderr (self , session_process ):
873+ def _flush_stderr (self , session_process ) -> str :
923874 """read and return stderr with minimal blocking"""
924875
925876 poll_stderr = select .poll ()
@@ -935,15 +886,6 @@ def _flush_stderr(self, session_process):
935886
936887 return stderr
937888
938- def _get_url (self , client_method , bucket_name , out_path , http_method , extra_args = None ):
939- """Generate URL for get_object / put_object"""
940-
941- client = self ._s3_client
942- params = {"Bucket" : bucket_name , "Key" : out_path }
943- if extra_args is not None :
944- params .update (extra_args )
945- return client .generate_presigned_url (client_method , Params = params , ExpiresIn = 3600 , HttpMethod = http_method )
946-
947889 def _get_boto_client (self , service , region_name = None , profile_name = None , endpoint_url = None ):
948890 """Gets a boto3 client based on the STS token"""
949891
@@ -971,22 +913,9 @@ def _get_boto_client(self, service, region_name=None, profile_name=None, endpoin
971913 )
972914 return client
973915
974- def _escape_path (self , path ) :
916+ def _escape_path (self , path : str ) -> str :
975917 return path .replace ("\\ " , "/" )
976918
977- def _generate_encryption_settings (self ):
978- put_args = {}
979- put_headers = {}
980- if not self .get_option ("bucket_sse_mode" ):
981- return put_args , put_headers
982-
983- put_args ["ServerSideEncryption" ] = self .get_option ("bucket_sse_mode" )
984- put_headers ["x-amz-server-side-encryption" ] = self .get_option ("bucket_sse_mode" )
985- if self .get_option ("bucket_sse_mode" ) == "aws:kms" and self .get_option ("bucket_sse_kms_key_id" ):
986- put_args ["SSEKMSKeyId" ] = self .get_option ("bucket_sse_kms_key_id" )
987- put_headers ["x-amz-server-side-encryption-aws-kms-key-id" ] = self .get_option ("bucket_sse_kms_key_id" )
988- return put_args , put_headers
989-
990919 def _generate_commands (
991920 self ,
992921 bucket_name : str ,
@@ -1006,11 +935,11 @@ def _generate_commands(
1006935 :returns: A tuple containing a list of command dictionaries along with any ``put_args`` dictionaries.
1007936 """
1008937
1009- put_args , put_headers = self ._generate_encryption_settings ()
938+ put_args , put_headers = self .s3_manager . generate_encryption_settings ()
1010939 commands = []
1011940
1012- put_url = self ._get_url ("put_object" , bucket_name , s3_path , "PUT" , extra_args = put_args )
1013- get_url = self ._get_url ("get_object" , bucket_name , s3_path , "GET" )
941+ put_url = self .s3_manager . get_url ("put_object" , bucket_name , s3_path , "PUT" , extra_args = put_args )
942+ get_url = self .s3_manager . get_url ("get_object" , bucket_name , s3_path , "GET" )
1014943
1015944 if self .is_windows :
1016945 put_command_headers = "; " .join ([f"'{ h } ' = '{ v } '" for h , v in put_headers .items ()])
@@ -1150,7 +1079,7 @@ def _file_transport_command(
11501079 # Remove the files from the bucket after they've been transferred
11511080 client .delete_object (Bucket = bucket_name , Key = s3_path )
11521081
1153- def put_file (self , in_path , out_path ) :
1082+ def put_file (self , in_path : str , out_path : str ) -> Tuple [ int , str , str ] :
11541083 """transfer a file from local to remote"""
11551084
11561085 super ().put_file (in_path , out_path )
@@ -1161,15 +1090,15 @@ def put_file(self, in_path, out_path):
11611090
11621091 return self ._file_transport_command (in_path , out_path , "put" )
11631092
1164- def fetch_file (self , in_path , out_path ) :
1093+ def fetch_file (self , in_path : str , out_path : str ) -> Tuple [ int , str , str ] :
11651094 """fetch a file from remote to local"""
11661095
11671096 super ().fetch_file (in_path , out_path )
11681097
11691098 self ._vvv (f"FETCH { in_path } TO { out_path } " )
11701099 return self ._file_transport_command (in_path , out_path , "get" )
11711100
1172- def close (self ):
1101+ def close (self ) -> None :
11731102 """terminate the connection"""
11741103 if self ._session_id :
11751104 self ._vvv (f"CLOSING SSM CONNECTION TO: { self .instance_id } " )
0 commit comments