diff --git a/gems/aws-sdk-s3/CHANGELOG.md b/gems/aws-sdk-s3/CHANGELOG.md index 7fd43ebaea8..14a390ae6c8 100644 --- a/gems/aws-sdk-s3/CHANGELOG.md +++ b/gems/aws-sdk-s3/CHANGELOG.md @@ -1,6 +1,10 @@ Unreleased Changes ------------------ +* Feature - Add lightweight thread pool executor for multipart `download_file`, `upload_file` and `upload_stream`. + +* Feature - Add custom executor support for `Aws::S3::TransferManager`. + 1.199.1 (2025-09-25) ------------------ diff --git a/gems/aws-sdk-s3/lib/aws-sdk-s3/customizations.rb b/gems/aws-sdk-s3/lib/aws-sdk-s3/customizations.rb index c0dba64b79c..1ffb0c06892 100644 --- a/gems/aws-sdk-s3/lib/aws-sdk-s3/customizations.rb +++ b/gems/aws-sdk-s3/lib/aws-sdk-s3/customizations.rb @@ -7,6 +7,7 @@ module S3 autoload :Encryption, 'aws-sdk-s3/encryption' autoload :EncryptionV2, 'aws-sdk-s3/encryption_v2' autoload :FilePart, 'aws-sdk-s3/file_part' + autoload :DefaultExecutor, 'aws-sdk-s3/default_executor' autoload :FileUploader, 'aws-sdk-s3/file_uploader' autoload :FileDownloader, 'aws-sdk-s3/file_downloader' autoload :LegacySigner, 'aws-sdk-s3/legacy_signer' diff --git a/gems/aws-sdk-s3/lib/aws-sdk-s3/customizations/object.rb b/gems/aws-sdk-s3/lib/aws-sdk-s3/customizations/object.rb index 0a9a9b9d3ca..338eedb87cf 100644 --- a/gems/aws-sdk-s3/lib/aws-sdk-s3/customizations/object.rb +++ b/gems/aws-sdk-s3/lib/aws-sdk-s3/customizations/object.rb @@ -358,8 +358,8 @@ def public_url(options = {}) # {Client#complete_multipart_upload}, # and {Client#upload_part} can be provided. # - # @option options [Integer] :thread_count (10) The number of parallel - # multipart uploads + # @option options [Integer] :thread_count (10) The number of parallel multipart uploads. + # An additional thread is used internally for task coordination. 
# # @option options [Boolean] :tempfile (false) Normally read data is stored # in memory when building the parts in order to complete the underlying @@ -383,19 +383,18 @@ def public_url(options = {}) # @see Client#complete_multipart_upload # @see Client#upload_part def upload_stream(options = {}, &block) - uploading_options = options.dup + upload_opts = options.merge(bucket: bucket_name, key: key) + executor = DefaultExecutor.new(max_threads: upload_opts.delete(:thread_count)) uploader = MultipartStreamUploader.new( client: client, - thread_count: uploading_options.delete(:thread_count), - tempfile: uploading_options.delete(:tempfile), - part_size: uploading_options.delete(:part_size) + executor: executor, + tempfile: upload_opts.delete(:tempfile), + part_size: upload_opts.delete(:part_size) ) Aws::Plugins::UserAgent.metric('RESOURCE_MODEL') do - uploader.upload( - uploading_options.merge(bucket: bucket_name, key: key), - &block - ) + uploader.upload(upload_opts, &block) end + executor.shutdown true end deprecated(:upload_stream, use: 'Aws::S3::TransferManager#upload_stream', version: 'next major version') @@ -458,12 +457,18 @@ def upload_stream(options = {}, &block) # @see Client#complete_multipart_upload # @see Client#upload_part def upload_file(source, options = {}) - uploading_options = options.dup - uploader = FileUploader.new(multipart_threshold: uploading_options.delete(:multipart_threshold), client: client) + upload_opts = options.merge(bucket: bucket_name, key: key) + executor = DefaultExecutor.new(max_threads: upload_opts.delete(:thread_count)) + uploader = FileUploader.new( + client: client, + executor: executor, + multipart_threshold: upload_opts.delete(:multipart_threshold) + ) response = Aws::Plugins::UserAgent.metric('RESOURCE_MODEL') do - uploader.upload(source, uploading_options.merge(bucket: bucket_name, key: key)) + uploader.upload(source, upload_opts) end yield response if block_given? + executor.shutdown true end deprecated(:upload_file, use: 'Aws::S3::TransferManager#upload_file', version: 'next major version') @@ -512,10 +517,6 @@ def upload_file(source, options = {}) # # @option options [Integer] :thread_count (10) Customize threads used in the multipart download. # - # @option options [String] :version_id The object version id used to retrieve the object. - # - # @see https://docs.aws.amazon.com/AmazonS3/latest/dev/ObjectVersioning.html ObjectVersioning - # # @option options [String] :checksum_mode ("ENABLED") # When `"ENABLED"` and the object has a stored checksum, it will be used to validate the download and will # raise an `Aws::Errors::ChecksumError` if checksum validation fails. 
You may provide an `on_checksum_validated` @@ -539,10 +540,13 @@ def upload_file(source, options = {}) # @see Client#get_object # @see Client#head_object def download_file(destination, options = {}) - downloader = FileDownloader.new(client: client) + download_opts = options.merge(bucket: bucket_name, key: key) + executor = DefaultExecutor.new(max_threads: download_opts.delete(:thread_count)) + downloader = FileDownloader.new(client: client, executor: executor) Aws::Plugins::UserAgent.metric('RESOURCE_MODEL') do - downloader.download(destination, options.merge(bucket: bucket_name, key: key)) + downloader.download(destination, download_opts) end + executor.shutdown true end deprecated(:download_file, use: 'Aws::S3::TransferManager#download_file', version: 'next major version') diff --git a/gems/aws-sdk-s3/lib/aws-sdk-s3/default_executor.rb b/gems/aws-sdk-s3/lib/aws-sdk-s3/default_executor.rb new file mode 100644 index 00000000000..13a719f4397 --- /dev/null +++ b/gems/aws-sdk-s3/lib/aws-sdk-s3/default_executor.rb @@ -0,0 +1,103 @@ +# frozen_string_literal: true + +module Aws + module S3 + # @api private + class DefaultExecutor + DEFAULT_MAX_THREADS = 10 + RUNNING = :running + SHUTTING_DOWN = :shutting_down + SHUTDOWN = :shutdown + + def initialize(options = {}) + @max_threads = options[:max_threads] || DEFAULT_MAX_THREADS + @state = RUNNING + @queue = Queue.new + @pool = [] + @mutex = Mutex.new + end + + # Submits a task for execution. + # @param [Object] args Variable number of arguments to pass to the block + # @param [Proc] block The block to be executed + # @return [Boolean] Returns true if the task was submitted successfully + def post(*args, &block) + @mutex.synchronize do + raise 'Executor has been shut down and is no longer accepting tasks' unless @state == RUNNING + + @queue << [args, block] + ensure_worker_available + end + true + end + + # Immediately terminates all worker threads and clears pending tasks. + # This is a forceful shutdown that doesn't wait for running tasks to complete. + # + # @return [Boolean] true when termination is complete + def kill + @mutex.synchronize do + @state = SHUTDOWN + @pool.each(&:kill) + @pool.clear + @queue.clear + end + true + end + + # Gracefully shuts down the executor, optionally with a timeout. + # Stops accepting new tasks and waits for running tasks to complete. + # + # @param timeout [Numeric, nil] Maximum time in seconds to wait for shutdown. + # If nil, waits indefinitely. If timeout expires, remaining threads are killed. + # @return [Boolean] true when shutdown is complete + def shutdown(timeout = nil) + @mutex.synchronize do + return true if @state == SHUTDOWN + + @state = SHUTTING_DOWN + @pool.size.times { @queue << :shutdown } + end + + if timeout + deadline = Time.now + timeout + @pool.each do |thread| + remaining = deadline - Time.now + break if remaining <= 0 + + thread.join([remaining, 0].max) + end + @pool.select(&:alive?).each(&:kill) + else + @pool.each(&:join) + end + + @mutex.synchronize do + @pool.clear + @state = SHUTDOWN + end + true + end + + private + + def ensure_worker_available + return unless @state == RUNNING + + @pool.select!(&:alive?)
+ @pool << spawn_worker if @pool.size < @max_threads + end + + def spawn_worker + Thread.new do + while (job = @queue.shift) + break if job == :shutdown + + args, block = job + block.call(*args) + end + end + end + end + end +end diff --git a/gems/aws-sdk-s3/lib/aws-sdk-s3/file_downloader.rb b/gems/aws-sdk-s3/lib/aws-sdk-s3/file_downloader.rb index 517a06dabea..7e8f545cb01 100644 --- a/gems/aws-sdk-s3/lib/aws-sdk-s3/file_downloader.rb +++ b/gems/aws-sdk-s3/lib/aws-sdk-s3/file_downloader.rb @@ -8,184 +8,245 @@ module Aws module S3 # @api private class FileDownloader MIN_CHUNK_SIZE = 5 * 1024 * 1024 MAX_PARTS = 10_000 + HEAD_OPTIONS = Set.new(Client.api.operation(:head_object).input.shape.member_names) + GET_OPTIONS = Set.new(Client.api.operation(:get_object).input.shape.member_names) def initialize(options = {}) @client = options[:client] || Client.new + @executor = options[:executor] end # @return [Client] attr_reader :client def download(destination, options = {}) - valid_types = [String, Pathname, File, Tempfile] - unless valid_types.include?(destination.class) - raise ArgumentError, "Invalid destination, expected #{valid_types.join(', ')} but got: #{destination.class}" - end - - @destination = destination - @mode = options.delete(:mode) || 'auto' - @thread_count = options.delete(:thread_count) || 10 - @chunk_size = options.delete(:chunk_size) - @on_checksum_validated = options.delete(:on_checksum_validated) - @progress_callback = options.delete(:progress_callback) - @params = options - validate! + validate_destination!(destination) + opts = build_download_opts(destination, options) + validate_opts!(opts) Aws::Plugins::UserAgent.metric('S3_TRANSFER') do - case @mode - when 'auto' then multipart_download - when 'single_request' then single_request - when 'get_range' - raise ArgumentError, 'In get_range mode, :chunk_size must be provided' unless @chunk_size - - resp = @client.head_object(@params) - multithreaded_get_by_ranges(resp.content_length, resp.etag) - else - raise ArgumentError, "Invalid mode #{@mode} provided, :mode should be single_request, get_range or auto" + case opts[:mode] + when 'auto' then multipart_download(opts) + when 'single_request' then single_request(opts) + when 'get_range' then range_request(opts) end end - File.rename(@temp_path, @destination) if @temp_path + File.rename(opts[:temp_path], destination) if opts[:temp_path] ensure - File.delete(@temp_path) if @temp_path && File.exist?(@temp_path) + cleanup_temp_file(opts) end private - def validate!
- return unless @on_checksum_validated && !@on_checksum_validated.respond_to?(:call) + def build_download_opts(destination, opts) + { + destination: destination, + mode: opts.delete(:mode) || 'auto', + chunk_size: opts.delete(:chunk_size), + on_checksum_validated: opts.delete(:on_checksum_validated), + progress_callback: opts.delete(:progress_callback), + params: opts, + temp_path: nil + } + end + + def cleanup_temp_file(opts) + return unless opts + + temp_file = opts[:temp_path] + File.delete(temp_file) if temp_file && File.exist?(temp_file) + end + + def download_with_executor(part_list, total_size, opts) + download_attempts = 0 + completion_queue = Queue.new + abort_download = false + error = nil + progress = MultipartProgress.new(part_list, total_size, opts[:progress_callback]) + + while (part = part_list.shift) + break if abort_download + + download_attempts += 1 + @executor.post(part) do |p| + update_progress(progress, p) + resp = @client.get_object(p.params) + range = extract_range(resp.content_range) + validate_range(range, p.params[:range]) if p.params[:range] + write(resp.body, range, opts) + + execute_checksum_callback(resp, opts) + rescue StandardError => e + abort_download = true + error = e + ensure + completion_queue << :done + end + end + + download_attempts.times { completion_queue.pop } + raise error unless error.nil? + end + + def get_opts(opts) + GET_OPTIONS.each_with_object({}) { |k, h| h[k] = opts[k] if opts.key?(k) } + end + + def head_opts(opts) + HEAD_OPTIONS.each_with_object({}) { |k, h| h[k] = opts[k] if opts.key?(k) } + end + + def compute_chunk(chunk_size, file_size) + raise ArgumentError, ":chunk_size shouldn't exceed total file size." if chunk_size && chunk_size > file_size - raise ArgumentError, ':on_checksum_validated must be callable' + chunk_size || [(file_size.to_f / MAX_PARTS).ceil, MIN_CHUNK_SIZE].max.to_i end - def multipart_download - resp = @client.head_object(@params.merge(part_number: 1)) + def compute_mode(file_size, total_parts, etag, opts) + chunk_size = compute_chunk(opts[:chunk_size], file_size) + part_size = (file_size.to_f / total_parts).ceil + + resolve_temp_path(opts) + if chunk_size < part_size + multithreaded_get_by_ranges(file_size, etag, opts) + else + multithreaded_get_by_parts(total_parts, file_size, etag, opts) + end + end + + def extract_range(value) + value.match(%r{bytes (?<range>\d+-\d+)/\d+})[:range] + end + + def multipart_download(opts) + resp = @client.head_object(head_opts(opts[:params].merge(part_number: 1))) count = resp.parts_count if count.nil?
|| count <= 1 if resp.content_length <= MIN_CHUNK_SIZE - single_request + single_request(opts) else - multithreaded_get_by_ranges(resp.content_length, resp.etag) + resolve_temp_path(opts) + multithreaded_get_by_ranges(resp.content_length, resp.etag, opts) end else # covers cases when given object is not uploaded via UploadPart API - resp = @client.head_object(@params) # partNumber is an option + resp = @client.head_object(head_opts(opts[:params])) # partNumber is an option if resp.content_length <= MIN_CHUNK_SIZE - single_request + single_request(opts) else - compute_mode(resp.content_length, count, resp.etag) + compute_mode(resp.content_length, count, resp.etag, opts) end end end - def compute_mode(file_size, count, etag) - chunk_size = compute_chunk(file_size) - part_size = (file_size.to_f / count).ceil - if chunk_size < part_size - multithreaded_get_by_ranges(file_size, etag) - else - multithreaded_get_by_parts(count, file_size, etag) + def multithreaded_get_by_parts(total_parts, file_size, etag, opts) + parts = (1..total_parts).map do |part| + params = get_opts(opts[:params].merge(part_number: part, if_match: etag)) + Part.new(part_number: part, params: params) end + download_with_executor(PartList.new(parts), file_size, opts) end - def compute_chunk(file_size) - raise ArgumentError, ":chunk_size shouldn't exceed total file size." if @chunk_size && @chunk_size > file_size - - @chunk_size || [(file_size.to_f / MAX_PARTS).ceil, MIN_CHUNK_SIZE].max.to_i - end - - def multithreaded_get_by_ranges(file_size, etag) + def multithreaded_get_by_ranges(file_size, etag, opts) offset = 0 - default_chunk_size = compute_chunk(file_size) + default_chunk_size = compute_chunk(opts[:chunk_size], file_size) chunks = [] part_number = 1 # parts start at 1 while offset < file_size progress = offset + default_chunk_size progress = file_size if progress > file_size - params = @params.merge(range: "bytes=#{offset}-#{progress - 1}", if_match: etag) + params = get_opts(opts[:params].merge(range: "bytes=#{offset}-#{progress - 1}", if_match: etag)) chunks << Part.new(part_number: part_number, size: (progress - offset), params: params) part_number += 1 offset = progress end - download_in_threads(PartList.new(chunks), file_size) - end - - def multithreaded_get_by_parts(n_parts, total_size, etag) - parts = (1..n_parts).map do |part| - Part.new(part_number: part, params: @params.merge(part_number: part, if_match: etag)) - end - download_in_threads(PartList.new(parts), total_size) - end - - def download_in_threads(pending, total_size) - threads = [] - progress = MultipartProgress.new(pending, total_size, @progress_callback) if @progress_callback - unless [File, Tempfile].include?(@destination.class) - @temp_path = "#{@destination}.s3tmp.#{SecureRandom.alphanumeric(8)}" - end - @thread_count.times do - thread = Thread.new do - begin - while (part = pending.shift) - if progress - part.params[:on_chunk_received] = - proc do |_chunk, bytes, total| - progress.call(part.part_number, bytes, total) - end - end - resp = @client.get_object(part.params) - range = extract_range(resp.content_range) - validate_range(range, part.params[:range]) if part.params[:range] - write(resp.body, range) - if @on_checksum_validated && resp.checksum_validated - @on_checksum_validated.call(resp.checksum_validated, resp) - end - end - nil - rescue StandardError => e - pending.clear! 
# keep other threads from downloading other parts - raise e - end - end - threads << thread - end - threads.map(&:value).compact + download_with_executor(PartList.new(chunks), file_size, opts) end - def extract_range(value) - value.match(%r{bytes (?<range>\d+-\d+)/\d+})[:range] + def range_request(opts) + resp = @client.head_object(head_opts(opts[:params])) + resolve_temp_path(opts) + multithreaded_get_by_ranges(resp.content_length, resp.etag, opts) end - def validate_range(actual, expected) - return if actual == expected.match(/bytes=(?<range>\d+-\d+)/)[:range] - - raise MultipartDownloadError, "multipart download failed: expected range of #{expected} but got #{actual}" - end + def resolve_temp_path(opts) + return if [File, Tempfile].include?(opts[:destination].class) - def write(body, range) - path = @temp_path || @destination - File.write(path, body.read, range.split('-').first.to_i) + opts[:temp_path] ||= "#{opts[:destination]}.s3tmp.#{SecureRandom.alphanumeric(8)}" end - def single_request - params = @params.merge(response_target: @destination) - params[:on_chunk_received] = single_part_progress if @progress_callback + def single_request(opts) + params = get_opts(opts[:params]).merge(response_target: opts[:destination]) + params[:on_chunk_received] = single_part_progress(opts) if opts[:progress_callback] resp = @client.get_object(params) - return resp unless @on_checksum_validated + return resp unless opts[:on_checksum_validated] - @on_checksum_validated.call(resp.checksum_validated, resp) if resp.checksum_validated + opts[:on_checksum_validated].call(resp.checksum_validated, resp) if resp.checksum_validated resp end - def single_part_progress + def single_part_progress(opts) proc do |_chunk, bytes_read, total_size| - @progress_callback.call([bytes_read], [total_size], total_size) + opts[:progress_callback].call([bytes_read], [total_size], total_size) end end + def update_progress(progress, part) + return unless progress.progress_callback + + part.params[:on_chunk_received] = + proc do |_chunk, bytes, total| + progress.call(part.part_number, bytes, total) + end + end + + def execute_checksum_callback(resp, opts) + return unless opts[:on_checksum_validated] && resp.checksum_validated + + opts[:on_checksum_validated].call(resp.checksum_validated, resp) + end + + def validate_destination!(destination) + valid_types = [String, Pathname, File, Tempfile] + return if valid_types.include?(destination.class) + + raise ArgumentError, "Invalid destination, expected #{valid_types.join(', ')} but got: #{destination.class}" + end + + def validate_opts!(opts) + if opts[:on_checksum_validated] && !opts[:on_checksum_validated].respond_to?(:call) + raise ArgumentError, ':on_checksum_validated must be callable' + end + + valid_modes = %w[auto get_range single_request] + unless valid_modes.include?(opts[:mode]) + msg = "Invalid mode #{opts[:mode]} provided, :mode should be single_request, get_range or auto" + raise ArgumentError, msg + end + + if opts[:mode] == 'get_range' && opts[:chunk_size].nil?
+ raise ArgumentError, 'In get_range mode, :chunk_size must be provided' + end + + if opts[:chunk_size] && opts[:chunk_size] <= 0 + raise ArgumentError, ':chunk_size must be positive' + end + end + + def validate_range(actual, expected) + return if actual == expected.match(/bytes=(?<range>\d+-\d+)/)[:range] + + raise MultipartDownloadError, "multipart download failed: expected range of #{expected} but got #{actual}" + end + + def write(body, range, opts) + path = opts[:temp_path] || opts[:destination] + File.write(path, body.read, range.split('-').first.to_i) + end + # @api private class Part < Struct.new(:part_number, :size, :params) include Aws::Structure @@ -225,6 +286,8 @@ def initialize(parts, total_size, progress_callback) @progress_callback = progress_callback end + attr_reader :progress_callback + def call(part_number, bytes_received, total) # part numbers start at 1 @bytes_received[part_number - 1] = bytes_received diff --git a/gems/aws-sdk-s3/lib/aws-sdk-s3/file_uploader.rb b/gems/aws-sdk-s3/lib/aws-sdk-s3/file_uploader.rb index 587066551ea..62dbb07f8c3 100644 --- a/gems/aws-sdk-s3/lib/aws-sdk-s3/file_uploader.rb +++ b/gems/aws-sdk-s3/lib/aws-sdk-s3/file_uploader.rb @@ -13,8 +13,8 @@ class FileUploader # @option options [Client] :client # @option options [Integer] :multipart_threshold (104857600) def initialize(options = {}) - @options = options @client = options[:client] || Client.new + @executor = options[:executor] @multipart_threshold = options[:multipart_threshold] || DEFAULT_MULTIPART_THRESHOLD end @@ -36,11 +36,9 @@ def initialize(options = {}) # @return [void] def upload(source, options = {}) Aws::Plugins::UserAgent.metric('S3_TRANSFER') do - if File.size(source) >= multipart_threshold - MultipartFileUploader.new(@options).upload(source, options) + if File.size(source) >= @multipart_threshold + MultipartFileUploader.new(client: @client, executor: @executor).upload(source, options) else - # remove multipart parameters not supported by put_object - options.delete(:thread_count) put_object(source, options) end end @@ -48,9 +46,9 @@ def upload(source, options = {}) private - def open_file(source) - if String === source || Pathname === source - File.open(source, 'rb') { |file| yield(file) } + def open_file(source, &block) + if source.is_a?(String) || source.is_a?(Pathname) + File.open(source, 'rb', &block) else yield(source) end diff --git a/gems/aws-sdk-s3/lib/aws-sdk-s3/multipart_file_uploader.rb b/gems/aws-sdk-s3/lib/aws-sdk-s3/multipart_file_uploader.rb index bcc05f1fc9e..e1f05d6ee16 100644 --- a/gems/aws-sdk-s3/lib/aws-sdk-s3/multipart_file_uploader.rb +++ b/gems/aws-sdk-s3/lib/aws-sdk-s3/multipart_file_uploader.rb @@ -7,10 +7,8 @@ module Aws module S3 # @api private class MultipartFileUploader MIN_PART_SIZE = 5 * 1024 * 1024 # 5MB MAX_PARTS = 10_000 - DEFAULT_THREAD_COUNT = 10 CREATE_OPTIONS = Set.new(Client.api.operation(:create_multipart_upload).input.shape.member_names) COMPLETE_OPTIONS = Set.new(Client.api.operation(:complete_multipart_upload).input.shape.member_names) UPLOAD_PART_OPTIONS = Set.new(Client.api.operation(:upload_part).input.shape.member_names) @@ -21,10 +19,9 @@ class MultipartFileUploader ) # @option options [Client] :client - # @option options [Integer] :thread_count (DEFAULT_THREAD_COUNT) def initialize(options = {}) @client = options[:client] || Client.new - @thread_count = options[:thread_count] || DEFAULT_THREAD_COUNT + @executor = options[:executor] end # @return [Client] @@ -38,11 +35,12 @@ def initialize(options = {}) # It will be invoked with
[bytes_read], [total_sizes] # @return [Seahorse::Client::Response] - the CompleteMultipartUploadResponse def upload(source, options = {}) - raise ArgumentError, 'unable to multipart upload files smaller than 5MB' if File.size(source) < MIN_PART_SIZE + file_size = File.size(source) + raise ArgumentError, 'unable to multipart upload files smaller than 5MB' if file_size < MIN_PART_SIZE upload_id = initiate_upload(options) - parts = upload_parts(upload_id, source, options) - complete_upload(upload_id, parts, source, options) + parts = upload_parts(upload_id, source, file_size, options) + complete_upload(upload_id, parts, file_size, options) end private @@ -51,22 +49,22 @@ def initiate_upload(options) @client.create_multipart_upload(create_opts(options)).upload_id end - def complete_upload(upload_id, parts, source, options) + def complete_upload(upload_id, parts, file_size, options) @client.complete_multipart_upload( **complete_opts(options).merge( upload_id: upload_id, multipart_upload: { parts: parts }, - mpu_object_size: File.size(source) + mpu_object_size: file_size ) ) rescue StandardError => e abort_upload(upload_id, options, [e]) end - def upload_parts(upload_id, source, options) + def upload_parts(upload_id, source, file_size, options) completed = PartList.new - pending = PartList.new(compute_parts(upload_id, source, options)) - errors = upload_in_threads(pending, completed, options) + pending = PartList.new(compute_parts(upload_id, source, file_size, options)) + errors = upload_with_executor(pending, completed, options) if errors.empty? completed.to_a.sort_by { |part| part[:part_number] } else @@ -86,17 +84,20 @@ def abort_upload(upload_id, options, errors) raise MultipartUploadError.new(msg, errors + [e]) end - def compute_parts(upload_id, source, options) - size = File.size(source) - default_part_size = compute_default_part_size(size) + def compute_parts(upload_id, source, file_size, options) + default_part_size = compute_default_part_size(file_size) offset = 0 part_number = 1 parts = [] - while offset < size + while offset < file_size parts << upload_part_opts(options).merge( upload_id: upload_id, part_number: part_number, - body: FilePart.new(source: source, offset: offset, size: part_size(size, default_part_size, offset)) + body: FilePart.new( + source: source, + offset: offset, + size: part_size(file_size, default_part_size, offset) + ) ) part_number += 1 offset += default_part_size @@ -115,17 +116,13 @@ def has_checksum_key?(keys) def create_opts(options) opts = { checksum_algorithm: Aws::Plugins::ChecksumAlgorithm::DEFAULT_CHECKSUM } opts[:checksum_type] = 'FULL_OBJECT' if has_checksum_key?(options.keys) - CREATE_OPTIONS.each_with_object(opts) do |key, hash| - hash[key] = options[key] if options.key?(key) - end + CREATE_OPTIONS.each_with_object(opts) { |k, h| h[k] = options[k] if options.key?(k) } end def complete_opts(options) opts = {} opts[:checksum_type] = 'FULL_OBJECT' if has_checksum_key?(options.keys) - COMPLETE_OPTIONS.each_with_object(opts) do |key, hash| - hash[key] = options[key] if options.key?(key) - end + COMPLETE_OPTIONS.each_with_object(opts) { |k, h| h[k] = options[k] if options.key?(k) } end def upload_part_opts(options) @@ -135,43 +132,40 @@ def upload_part_opts(options) end end - def upload_in_threads(pending, completed, options) - threads = [] - if (callback = options[:progress_callback]) - progress = MultipartProgress.new(pending, callback) - end - options.fetch(:thread_count, @thread_count).times do - thread = Thread.new do - begin - while (part = 
pending.shift) - if progress - part[:on_chunk_sent] = - proc do |_chunk, bytes, _total| - progress.call(part[:part_number], bytes) - end - end - resp = @client.upload_part(part) - part[:body].close - completed_part = { etag: resp.etag, part_number: part[:part_number] } - algorithm = resp.context.params[:checksum_algorithm] - k = "checksum_#{algorithm.downcase}".to_sym - completed_part[k] = resp.send(k) - completed.push(completed_part) - end - nil - rescue StandardError => e - # keep other threads from uploading other parts - pending.clear! - e - end + def upload_with_executor(pending, completed, options) + upload_attempts = 0 + completion_queue = Queue.new + abort_upload = false + errors = [] + progress = MultipartProgress.new(pending, options[:progress_callback]) + + while (part = pending.shift) + break if abort_upload + + upload_attempts += 1 + @executor.post(part) do |p| + update_progress(progress, p) + resp = @client.upload_part(p) + p[:body].close + completed_part = { etag: resp.etag, part_number: p[:part_number] } + algorithm = resp.context.params[:checksum_algorithm].downcase + k = "checksum_#{algorithm}".to_sym + completed_part[k] = resp.send(k) + completed.push(completed_part) + rescue StandardError => e + abort_upload = true + errors << e + ensure + completion_queue << :done end - threads << thread end - threads.map(&:value).compact + + upload_attempts.times { completion_queue.pop } + errors end - def compute_default_part_size(source_size) - [(source_size.to_f / MAX_PARTS).ceil, MIN_PART_SIZE].max.to_i + def compute_default_part_size(file_size) + [(file_size.to_f / MAX_PARTS).ceil, MIN_PART_SIZE].max.to_i end def part_size(total_size, part_size, offset) @@ -182,6 +176,15 @@ def part_size(total_size, part_size, offset) end end + def update_progress(progress, part) + return unless progress.progress_callback + + part[:on_chunk_sent] = + proc do |_chunk, bytes, _total| + progress.call(part[:part_number], bytes) + end + end + # @api private class PartList def initialize(parts = []) @@ -222,6 +225,8 @@ def initialize(parts, progress_callback) @progress_callback = progress_callback end + attr_reader :progress_callback + def call(part_number, bytes_read) # part numbers start at 1 @bytes_sent[part_number - 1] = bytes_read diff --git a/gems/aws-sdk-s3/lib/aws-sdk-s3/multipart_stream_uploader.rb b/gems/aws-sdk-s3/lib/aws-sdk-s3/multipart_stream_uploader.rb index 60a298c720b..ae6f75a47a1 100644 --- a/gems/aws-sdk-s3/lib/aws-sdk-s3/multipart_stream_uploader.rb +++ b/gems/aws-sdk-s3/lib/aws-sdk-s3/multipart_stream_uploader.rb @@ -11,7 +11,6 @@ module S3 class MultipartStreamUploader DEFAULT_PART_SIZE = 5 * 1024 * 1024 # 5MB - DEFAULT_THREAD_COUNT = 10 CREATE_OPTIONS = Set.new(Client.api.operation(:create_multipart_upload).input.shape.member_names) UPLOAD_PART_OPTIONS = Set.new(Client.api.operation(:upload_part).input.shape.member_names) COMPLETE_UPLOAD_OPTIONS = Set.new(Client.api.operation(:complete_multipart_upload).input.shape.member_names) @@ -19,9 +18,9 @@ class MultipartStreamUploader # @option options [Client] :client def initialize(options = {}) @client = options[:client] || Client.new + @executor = options[:executor] @tempfile = options[:tempfile] @part_size = options[:part_size] || DEFAULT_PART_SIZE - @thread_count = options[:thread_count] || DEFAULT_THREAD_COUNT end # @return [Client] @@ -29,7 +28,6 @@ def initialize(options = {}) # @option options [required,String] :bucket # @option options [required,String] :key - # @option options [Integer] :thread_count (DEFAULT_THREAD_COUNT) # 
@return [Seahorse::Client::Response] - the CompleteMultipartUploadResponse def upload(options = {}, &block) Aws::Plugins::UserAgent.metric('S3_TRANSFER') do @@ -54,28 +52,30 @@ def complete_upload(upload_id, parts, options) end def upload_parts(upload_id, options, &block) - completed = Queue.new - thread_errors = [] - errors = begin + completed_parts = Queue.new + errors = [] + + begin IO.pipe do |read_pipe, write_pipe| - threads = upload_in_threads( - read_pipe, - completed, - upload_part_opts(options).merge(upload_id: upload_id), - thread_errors - ) - begin - block.call(write_pipe) - ensure - # Ensure the pipe is closed to avoid https://github.com/jruby/jruby/issues/6111 - write_pipe.close + upload_thread = Thread.new do + upload_with_executor( + read_pipe, + completed_parts, + errors, + upload_part_opts(options).merge(upload_id: upload_id) + ) end - threads.map(&:value).compact + + block.call(write_pipe) + ensure + # Ensure the pipe is closed to avoid https://github.com/jruby/jruby/issues/6111 + write_pipe.close + upload_thread.join end rescue StandardError => e - thread_errors + [e] + errors << e end - return ordered_parts(completed) if errors.empty? + return ordered_parts(completed_parts) if errors.empty? abort_upload(upload_id, options, errors) end @@ -128,37 +128,34 @@ def read_to_part_body(read_pipe) end end - def upload_in_threads(read_pipe, completed, options, thread_errors) - mutex = Mutex.new + def upload_with_executor(read_pipe, completed, errors, options) + completion_queue = Queue.new + queued_parts = 0 part_number = 0 - options.fetch(:thread_count, @thread_count).times.map do - thread = Thread.new do - loop do - body, thread_part_number = mutex.synchronize do - [read_to_part_body(read_pipe), part_number += 1] - end - break unless body || thread_part_number == 1 - - begin - part = options.merge(body: body, part_number: thread_part_number) - resp = @client.upload_part(part) - completed_part = create_completed_part(resp, part) - completed.push(completed_part) - ensure - clear_body(body) - end - end - nil + mutex = Mutex.new + loop do + part_body, current_part_num = mutex.synchronize do + [read_to_part_body(read_pipe), part_number += 1] + end + break unless part_body || current_part_num == 1 + + queued_parts += 1 + @executor.post(part_body, current_part_num, options) do |body, num, opts| + part = opts.merge(body: body, part_number: num) + resp = @client.upload_part(part) + completed_part = create_completed_part(resp, part) + completed.push(completed_part) rescue StandardError => e - # keep other threads from uploading other parts mutex.synchronize do - thread_errors.push(e) + errors.push(e) read_pipe.close_read unless read_pipe.closed? end - e + ensure + clear_body(body) + completion_queue << :done end - thread end + queued_parts.times { completion_queue.pop } end def create_completed_part(resp, part) diff --git a/gems/aws-sdk-s3/lib/aws-sdk-s3/transfer_manager.rb b/gems/aws-sdk-s3/lib/aws-sdk-s3/transfer_manager.rb index 03bd249d5e2..9035e9699ac 100644 --- a/gems/aws-sdk-s3/lib/aws-sdk-s3/transfer_manager.rb +++ b/gems/aws-sdk-s3/lib/aws-sdk-s3/transfer_manager.rb @@ -2,27 +2,74 @@ module Aws module S3 - # A high-level S3 transfer utility that provides enhanced upload and download - # capabilities with automatic multipart handling, progress tracking, and - # handling of large files. 
The following features are supported: + # A high-level S3 transfer utility that provides enhanced upload and download capabilities with automatic + # multipart handling, progress tracking, and handling of large files. The following features are supported: # # * upload a file with multipart upload # * upload a stream with multipart upload # * download an S3 object with multipart download # * track transfer progress by using progress listener # + # ## Executor Management + # TransferManager uses executors to handle concurrent operations during multipart transfers. You can control + # concurrency behavior by providing a custom executor or relying on the default executor management. + # + # ### Default Behavior + # When no `:executor` is provided, TransferManager creates a new DefaultExecutor for each individual + # operation (`download_file`, `upload_file`, etc.) and automatically shuts it down when that operation completes. + # Each operation gets its own isolated thread pool with the specified `:thread_count` (default 10 threads). + # + # ### Custom Executor + # You can provide your own executor (e.g., `Concurrent::ThreadPoolExecutor`) for fine-grained control over thread + # pools and resource management. When using a custom executor, you are responsible for shutting it down + # when finished. The executor may be reused across multiple TransferManager operations. + # + # Custom executors must implement the same interface as DefaultExecutor. + # + # **Required methods:** + # + # * `post(*args, &block)` - Execute a task with given arguments and block + # * `kill` - Immediately terminate all running tasks + # + # **Optional methods:** + # + # * `shutdown(timeout = nil)` - Gracefully shut down the executor with an optional timeout + # + # @example Using default executor (automatic creation and shutdown) + # tm = TransferManager.new # No executor provided + # # DefaultExecutor created, used, and shut down automatically + # tm.download_file('/path/to/file', bucket: 'bucket', key: 'key') + # + # @example Using custom executor (manual shutdown required) + # require 'concurrent-ruby' + # + # executor = Concurrent::ThreadPoolExecutor.new(max_threads: 5) + # tm = TransferManager.new(executor: executor) + # tm.download_file('/path/to/file1', bucket: 'bucket', key: 'key1') + # executor.shutdown # You must shut down custom executors + # class TransferManager + # @param [Hash] options # @option options [S3::Client] :client (S3::Client.new) # The S3 client to use for {TransferManager} operations. If not provided, a new default client # will be created automatically. + # @option options [Object] :executor + # The executor to use for multipart operations. Must implement the same interface as {DefaultExecutor}. + # If not provided, a new {DefaultExecutor} will be created automatically for each operation and + # shut down after completion. When a custom executor is provided, it will be reused across operations, and + # you are responsible for shutting it down when finished. def initialize(options = {}) - @client = options.delete(:client) || Client.new + @client = options[:client] || Client.new + @executor = options[:executor] end # @return [S3::Client] attr_reader :client + # @return [Object] + attr_reader :executor + # Downloads a file in S3 to a path on disk. # # # small files (< 5MB) are downloaded in a single API call # # @@ -74,10 +121,7 @@ def initialize(options = {}) # @option options [Integer] :chunk_size required in `"get_range"` mode.
# # @option options [Integer] :thread_count (10) Customize threads used in the multipart download. - # - # @option options [String] :version_id The object version id used to retrieve the object. - # - # @see https://docs.aws.amazon.com/AmazonS3/latest/dev/ObjectVersioning.html ObjectVersioning + # Only used when no custom executor is provided (creates {DefaultExecutor} with the given thread count). # # @option options [String] :checksum_mode ("ENABLED") # When `"ENABLED"` and the object has a stored checksum, it will be used to validate the download and will # raise an `Aws::Errors::ChecksumError` if checksum validation fails. @@ -102,8 +146,11 @@ def initialize(options = {}) # @see Client#get_object # @see Client#head_object def download_file(destination, bucket:, key:, **options) - downloader = FileDownloader.new(client: @client) - downloader.download(destination, options.merge(bucket: bucket, key: key)) + download_opts = options.merge(bucket: bucket, key: key) + executor = @executor || DefaultExecutor.new(max_threads: download_opts.delete(:thread_count)) + downloader = FileDownloader.new(client: @client, executor: executor) + downloader.download(destination, download_opts) + executor.shutdown unless @executor true end @@ -139,7 +186,7 @@ def download_file(destination, bucket:, key:, **options) # A file on the local file system that will be uploaded. This can either be a `String` or `Pathname` to the # file, an open `File` object, or an open `Tempfile` object. If you pass an open `File` or `Tempfile` object, # then you are responsible for closing it after the upload completes. When using an open Tempfile, rewind it - # before uploading or else the object will be empty. + # before uploading or else the object will be empty. # # @param [String] bucket # The name of the S3 bucket to upload to. @@ -156,15 +203,14 @@ def download_file(destination, bucket:, key:, **options) # Files larger than or equal to `:multipart_threshold` are uploaded using the S3 multipart upload APIs. # Default threshold is `100MB`. # - # @option options [Integer] :thread_count (10) - # The number of parallel multipart uploads. This option is not used if the file is smaller than - # `:multipart_threshold`. + # @option options [Integer] :thread_count (10) Customize threads used in the multipart upload. + # Only used when no custom executor is provided (creates {DefaultExecutor} with the given thread count). # # @option options [Proc] :progress_callback (nil) # A Proc that will be called when each chunk of the upload is sent. # It will be invoked with `[bytes_read]` and `[total_sizes]`. # - # @raise [MultipartUploadError] If an file is being uploaded in parts, and the upload can not be completed, + # @raise [MultipartUploadError] If a file is being uploaded in parts, and the upload can not be completed, # then the upload is aborted and this error is raised. The raised error has a `#errors` method that # returns the failures that caused the upload to be aborted.
# @@ -175,13 +221,16 @@ def download_file(destination, bucket:, key:, **options) # @see Client#complete_multipart_upload # @see Client#upload_part def upload_file(source, bucket:, key:, **options) - uploading_options = options.dup + upload_opts = options.merge(bucket: bucket, key: key) + executor = @executor || DefaultExecutor.new(max_threads: upload_opts.delete(:thread_count)) uploader = FileUploader.new( - multipart_threshold: uploading_options.delete(:multipart_threshold), - client: @client + multipart_threshold: upload_opts.delete(:multipart_threshold), + client: @client, + executor: executor ) - response = uploader.upload(source, uploading_options.merge(bucket: bucket, key: key)) + response = uploader.upload(source, upload_opts) yield response if block_given? + executor.shutdown unless @executor true end @@ -217,7 +266,8 @@ def upload_file(source, bucket:, key:, **options) # {Client#upload_part} can be provided. # # @option options [Integer] :thread_count (10) - # The number of parallel multipart uploads. + # The number of parallel multipart uploads. Only used when no custom executor is provided (creates + # {DefaultExecutor} with the given thread count). An additional thread is used internally for task coordination. # # @option options [Boolean] :tempfile (false) # Normally read data is stored in memory when building the parts in order to complete the underlying @@ -237,14 +287,16 @@ def upload_file(source, bucket:, key:, **options) # @see Client#complete_multipart_upload # @see Client#upload_part def upload_stream(bucket:, key:, **options, &block) - uploading_options = options.dup + upload_opts = options.merge(bucket: bucket, key: key) + executor = @executor || DefaultExecutor.new(max_threads: upload_opts.delete(:thread_count)) uploader = MultipartStreamUploader.new( client: @client, - thread_count: uploading_options.delete(:thread_count), - tempfile: uploading_options.delete(:tempfile), - part_size: uploading_options.delete(:part_size) + executor: executor, + tempfile: upload_opts.delete(:tempfile), + part_size: upload_opts.delete(:part_size) ) - uploader.upload(uploading_options.merge(bucket: bucket, key: key), &block) + uploader.upload(upload_opts, &block) + executor.shutdown unless @executor true end end diff --git a/gems/aws-sdk-s3/spec/default_executor_spec.rb b/gems/aws-sdk-s3/spec/default_executor_spec.rb new file mode 100644 index 00000000000..ffb6974d096 --- /dev/null +++ b/gems/aws-sdk-s3/spec/default_executor_spec.rb @@ -0,0 +1,68 @@ +# frozen_string_literal: true + +require_relative 'spec_helper' + +module Aws + module S3 + describe DefaultExecutor do + let(:subject) { DefaultExecutor.new } + + describe '#post' do + it 'executes a block with arguments' do + queue = Queue.new + subject.post('hello') { |arg| queue << arg } + expect(queue.pop).to eq('hello') + end + + it 'returns true when a task is submitted' do + expect(subject.post('hello') { |_arg| }).to be(true) + end + + it 'raises when executor is shutdown' do + subject.shutdown + expect { subject.post }.to raise_error(RuntimeError) + end + end + + describe '#shutdown' do + it 'waits for running tasks to be complete' do + result = nil + subject.post { result = true } + expect(subject.shutdown).to be(true) + expect(result).to be(true) + end + + it 'kills threads after timeout' do + started = Queue.new + counter = 0 + subject.post do + counter += 1 + started << 'work started' + sleep 1 + counter += 1 + end + started.pop + expect(subject.shutdown(0.01)).to be(true) + expect(counter).to eq(1) + end + end + + describe 
'#kill' do + it 'stops all threads immediately and returns true' do + started = Queue.new + counter = 0 + subject.post do + counter += 1 + started << 'work started' + sleep 1 + counter += 1 + end + started.pop + result = subject.kill + expect(result).to be(true) + expect(counter).to eq(1) + end + end + end + end +end diff --git a/gems/aws-sdk-s3/spec/file_downloader_spec.rb b/gems/aws-sdk-s3/spec/file_downloader_spec.rb index 0f3dd8fb2cf..a596c1fa6ba 100644 --- a/gems/aws-sdk-s3/spec/file_downloader_spec.rb +++ b/gems/aws-sdk-s3/spec/file_downloader_spec.rb @@ -7,7 +7,7 @@ module Aws module S3 describe FileDownloader do let(:client) { S3::Client.new(stub_responses: true) } - let(:subject) { FileDownloader.new(client: client) } + let(:subject) { FileDownloader.new(client: client, executor: DefaultExecutor.new) } let(:tmpdir) { Dir.tmpdir } describe '#initialize' do @@ -198,7 +198,6 @@ module S3 it 'raises when checksum validation fails on multipart object' do client.stub_responses(:get_object, { body: 'body', checksum_sha1: 'invalid' }) - expect(Thread).to receive(:new).and_yield.and_return(double(value: nil)) expect { subject.download(path, parts_params) }.to raise_error(Aws::Errors::ChecksumError) end @@ -208,7 +207,6 @@ module S3 expect(ctx.params[:if_match]).to eq('test-etag') 'PreconditionFailed' }) - expect(Thread).to receive(:new).and_yield.and_return(double(value: nil)) expect { subject.download(path, range_params.merge(chunk_size: one_meg, mode: 'get_range')) } .to raise_error(Aws::S3::Errors::PreconditionFailed) end @@ -219,8 +217,6 @@ module S3 expect(ctx.params[:if_match]).to eq('test-etag') 'PreconditionFailed' }) - - expect(Thread).to receive(:new).and_yield.and_return(double(value: nil)) expect { subject.download(path, parts_params) }.to raise_error(Aws::S3::Errors::PreconditionFailed) end @@ -246,7 +242,6 @@ module S3 it 'raises when range validation fails' do client.stub_responses(:get_object, { body: 'body', content_range: 'bytes 0-3/4' }) - expect(Thread).to receive(:new).and_yield.and_return(double(value: nil)) expect { subject.download(path, range_params.merge(mode: 'get_range', chunk_size: one_meg)) } .to raise_error(Aws::S3::MultipartDownloadError) end @@ -263,7 +258,6 @@ module S3 responses[context.params[:range]] }) - expect(Thread).to receive(:new).and_yield.and_return(double(value: nil)) expect { subject.download(path, range_params.merge(chunk_size: 5 * one_meg, mode: 'get_range')) } .to raise_error(Aws::S3::MultipartDownloadError) expect(File.exist?(path)).to be(true) diff --git a/gems/aws-sdk-s3/spec/file_uploader_spec.rb b/gems/aws-sdk-s3/spec/file_uploader_spec.rb index 64dbaf7cdd2..dc3ab94b052 100644 --- a/gems/aws-sdk-s3/spec/file_uploader_spec.rb +++ b/gems/aws-sdk-s3/spec/file_uploader_spec.rb @@ -81,12 +81,6 @@ module S3 subject.upload(ten_meg_file.path, params) end - - it 'does not fail when given :thread_count' do - expect(client).to receive(:put_object).with(params.merge(body: ten_meg_file)) - - subject.upload(ten_meg_file, params.merge(thread_count: 1)) - end end end end diff --git a/gems/aws-sdk-s3/spec/multipart_file_uploader_spec.rb b/gems/aws-sdk-s3/spec/multipart_file_uploader_spec.rb index 82e147986e3..0d2cd2fae4f 100644 --- a/gems/aws-sdk-s3/spec/multipart_file_uploader_spec.rb +++ b/gems/aws-sdk-s3/spec/multipart_file_uploader_spec.rb @@ -7,7 +7,7 @@ module Aws module S3 describe MultipartFileUploader do let(:client) { S3::Client.new(stub_responses: true) } - let(:subject) { MultipartFileUploader.new(client: client) } + let(:subject) { 
MultipartFileUploader.new(client: client, executor: DefaultExecutor.new) } let(:params) { { bucket: 'bucket', key: 'key' } } describe '#initialize' do @@ -85,7 +85,6 @@ module S3 end it 'reports progress for multipart uploads' do - allow(Thread).to receive(:new).and_yield.and_return(double(value: nil)) client.stub_responses(:create_multipart_upload, upload_id: 'id') client.stub_responses(:complete_multipart_upload) expect(client).to receive(:upload_part).exactly(24).times do |args| @@ -127,10 +126,6 @@ module S3 end it 'reports when it is unable to abort a failed multipart upload' do - allow(Thread).to receive(:new) do |_, &block| - double(value: block.call) - end - client.stub_responses( :upload_part, [ diff --git a/gems/aws-sdk-s3/spec/multipart_stream_uploader_spec.rb b/gems/aws-sdk-s3/spec/multipart_stream_uploader_spec.rb index e7c2860f5a7..6262a42b560 100644 --- a/gems/aws-sdk-s3/spec/multipart_stream_uploader_spec.rb +++ b/gems/aws-sdk-s3/spec/multipart_stream_uploader_spec.rb @@ -7,7 +7,7 @@ module Aws module S3 describe MultipartStreamUploader do let(:client) { S3::Client.new(stub_responses: true) } - let(:subject) { MultipartStreamUploader.new(client: client) } + let(:subject) { MultipartStreamUploader.new(client: client, executor: DefaultExecutor.new) } let(:params) { { bucket: 'bucket', key: 'key' } } let(:one_mb) { '.' * 1024 * 1024 } let(:seventeen_mb) { one_mb * 17 } @@ -50,7 +50,6 @@ module S3 } ) expect(client).to receive(:complete_multipart_upload).with(expected_params).once - subject.upload(params.merge(content_type: 'text/plain')) { |write_stream| write_stream << seventeen_mb } end @@ -155,7 +154,7 @@ module S3 end context 'when tempfile is true' do - let(:subject) { MultipartStreamUploader.new(client: client, tempfile: true) } + let(:subject) { MultipartStreamUploader.new(client: client, tempfile: true, executor: DefaultExecutor.new) } it 'uses multipart APIs' do client.stub_responses(:create_multipart_upload, upload_id: 'id') diff --git a/gems/aws-sdk-s3/spec/object/upload_stream_spec.rb b/gems/aws-sdk-s3/spec/object/upload_stream_spec.rb index 269b91746eb..d75b58a7515 100644 --- a/gems/aws-sdk-s3/spec/object/upload_stream_spec.rb +++ b/gems/aws-sdk-s3/spec/object/upload_stream_spec.rb @@ -27,9 +27,9 @@ module S3 it 'respects the thread_count option' do custom_thread_count = 20 - expect(Thread).to receive(:new).exactly(custom_thread_count).times.and_return(double(value: nil)) client.stub_responses(:create_multipart_upload, upload_id: 'id') client.stub_responses(:complete_multipart_upload) + expect(DefaultExecutor).to receive(:new).with(max_threads: custom_thread_count).and_call_original subject.upload_stream(thread_count: custom_thread_count) { |_write_stream| } end
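
Reviewer note: for anyone trying out the `:executor` contract documented on `TransferManager` above (`post` and `kill` required, `shutdown` optional), the sketch below is a minimal custom executor. The `InlineExecutor` class name, its run-on-the-calling-thread strategy, and the bucket/key/file values are illustrative assumptions, not part of this change.

```ruby
require 'aws-sdk-s3'

# Minimal custom executor satisfying the documented contract:
# post(*args, &block) and kill are required, shutdown is optional.
# This one simply runs each task on the calling thread.
class InlineExecutor
  def post(*args, &block)
    block.call(*args) # no pool; execute the task immediately
    true
  end

  def kill
    true # nothing to terminate; tasks never run in the background
  end

  def shutdown(_timeout = nil)
    true # nothing to wait for
  end
end

tm = Aws::S3::TransferManager.new(executor: InlineExecutor.new)
tm.upload_file('/tmp/report.csv', bucket: 'amzn-s3-demo-bucket', key: 'report.csv')
# Custom executors are not shut down by TransferManager; with a real
# thread pool you would call executor.shutdown yourself when finished.
```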
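Since `validate_opts!` now centralizes the download-mode rules (`auto`, `single_request`, or `get_range`, with `get_range` requiring a positive `:chunk_size`), a usage sketch may help; the bucket, key, paths, and the 8 MiB chunk size are assumed values:

```ruby
require 'aws-sdk-s3'

tm = Aws::S3::TransferManager.new

# 'auto' (the default) picks a single GET, ranged GETs, or part-based
# GETs depending on the object's size and how it was uploaded.
tm.download_file('/tmp/large.bin', bucket: 'amzn-s3-demo-bucket', key: 'large.bin')

# 'get_range' must be given a positive :chunk_size, or an
# ArgumentError is raised before any request is made.
tm.download_file(
  '/tmp/large.bin',
  bucket: 'amzn-s3-demo-bucket',
  key: 'large.bin',
  mode: 'get_range',
  chunk_size: 8 * 1024 * 1024,
  on_checksum_validated: ->(algorithm, _resp) { puts "checksum validated: #{algorithm}" }
)
```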
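All three refactored transfer paths (`download_with_executor` and the two `upload_with_executor` methods) share the same fan-out/fan-in idiom: count each task as it is posted, have every task push a token onto a completion queue from an `ensure` block, then pop exactly that many tokens before surfacing any recorded error. Below is a stripped-down sketch of that idiom; it uses `DefaultExecutor` (marked `@api private`) purely for illustration, and the task bodies and the failure on task 7 are placeholders.

```ruby
require 'aws-sdk-s3'

executor = Aws::S3::DefaultExecutor.new(max_threads: 4)
completion_queue = Queue.new
submitted = 0
error = nil

10.times do |i|
  break if error # stop scheduling once any task has failed

  submitted += 1
  executor.post(i) do |n|
    raise "boom on part #{n}" if n == 7 # placeholder failure

    # ... per-part work (e.g. a ranged GET) would happen here ...
  rescue StandardError => e
    error = e
  ensure
    completion_queue << :done # always signal, success or failure
  end
end

submitted.times { completion_queue.pop } # wait for every posted task
executor.shutdown
raise error if error
```

The `ensure`-side signal is what makes the final `pop` loop safe: the waiter counts submissions, not successes, so a failing task can never leave it blocked.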