From 828f0788009f66ca584836fcd1702b1de22e1f75 Mon Sep 17 00:00:00 2001 From: Gleb Novikov Date: Mon, 17 Mar 2025 19:15:53 +0000 Subject: [PATCH 1/4] fast_import: put job status to s3 --- compute_tools/src/bin/fast_import.rs | 45 ++++++++--- test_runner/regress/test_import_pgdata.py | 95 ++++++++++++++++++++--- 2 files changed, 121 insertions(+), 19 deletions(-) diff --git a/compute_tools/src/bin/fast_import.rs b/compute_tools/src/bin/fast_import.rs index 47558be7a028..7b275a2b2070 100644 --- a/compute_tools/src/bin/fast_import.rs +++ b/compute_tools/src/bin/fast_import.rs @@ -31,6 +31,7 @@ use camino::{Utf8Path, Utf8PathBuf}; use clap::{Parser, Subcommand}; use compute_tools::extension_server::{PostgresMajorVersion, get_pg_version}; use nix::unistd::Pid; +use std::ops::Not; use tracing::{Instrument, error, info, info_span, warn}; use utils::fs_ext::is_directory_empty; @@ -437,7 +438,7 @@ async fn run_dump_restore( #[allow(clippy::too_many_arguments)] async fn cmd_pgdata( - s3_client: Option, + s3_client: Option<&aws_sdk_s3::Client>, kms_client: Option, maybe_s3_prefix: Option, maybe_spec: Option, @@ -506,14 +507,14 @@ async fn cmd_pgdata( if let Some(s3_prefix) = maybe_s3_prefix { info!("upload pgdata"); aws_s3_sync::upload_dir_recursive( - s3_client.as_ref().unwrap(), + s3_client.unwrap(), Utf8Path::new(&pgdata_dir), &s3_prefix.append("/pgdata/"), ) .await .context("sync dump directory to destination")?; - info!("write status"); + info!("write pgdata status to s3"); { let status_dir = workdir.join("status"); std::fs::create_dir(&status_dir).context("create status directory")?; @@ -644,7 +645,7 @@ pub(crate) async fn main() -> anyhow::Result<()> { Err(e) => return Err(anyhow::Error::new(e).context("create working directory")), } - match args.command { + let res = match args.command { Command::Pgdata { source_connection_string, interactive, @@ -653,20 +654,20 @@ pub(crate) async fn main() -> anyhow::Result<()> { memory_mb, } => { cmd_pgdata( - s3_client, + s3_client.as_ref(), kms_client, - args.s3_prefix, + args.s3_prefix.clone(), spec, source_connection_string, interactive, pg_port, - args.working_directory, + args.working_directory.clone(), args.pg_bin_dir, args.pg_lib_dir, num_cpus, memory_mb, ) - .await?; + .await } Command::DumpRestore { source_connection_string, @@ -677,11 +678,35 @@ pub(crate) async fn main() -> anyhow::Result<()> { spec, source_connection_string, destination_connection_string, - args.working_directory, + args.working_directory.clone(), args.pg_bin_dir, args.pg_lib_dir, ) - .await?; + .await + } + }; + + if let Some(s3_prefix) = args.s3_prefix { + info!("write job status to s3"); + { + let status_dir = args.working_directory.join("status"); + if std::fs::exists(&status_dir)?.not() { + std::fs::create_dir(&status_dir).context("create status directory")?; + } + let status_file = status_dir.join("fast_import"); + let res_obj = if res.is_ok() { + serde_json::json!({"done": true}) + } else { + serde_json::json!({"done": false, "error": res.unwrap_err().to_string()}) + }; + std::fs::write(&status_file, res_obj.to_string()).context("write status file")?; + aws_s3_sync::upload_dir_recursive( + s3_client.as_ref().unwrap(), + &status_dir, + &s3_prefix.append("/status/"), + ) + .await + .context("sync status directory to destination")?; } } diff --git a/test_runner/regress/test_import_pgdata.py b/test_runner/regress/test_import_pgdata.py index 71e0d16eddb2..a413ec8f86fa 100644 --- a/test_runner/regress/test_import_pgdata.py +++ b/test_runner/regress/test_import_pgdata.py @@ 
-449,6 +449,17 @@ def handler(request: Request) -> Response: fast_import.extra_env["RUST_LOG"] = "aws_config=debug,aws_sdk_kms=debug" pg_port = port_distributor.get_port() fast_import.run_pgdata(pg_port=pg_port, s3prefix=f"s3://{bucket}/{key_prefix}") + + pgdata_status_obj = mock_s3_client.get_object(Bucket=bucket, Key=f"{key_prefix}/status/pgdata") + pgdata_status = pgdata_status_obj["Body"].read().decode("utf-8") + assert json.loads(pgdata_status) == {"done": True}, f"got status: {pgdata_status}" + + job_status_obj = mock_s3_client.get_object( + Bucket=bucket, Key=f"{key_prefix}/status/fast_import" + ) + job_status = job_status_obj["Body"].read().decode("utf-8") + assert json.loads(job_status) == {"done": True}, f"got status: {job_status}" + vanilla_pg.stop() def validate_vanilla_equivalence(ep): @@ -674,9 +685,11 @@ def encrypt(x: str) -> EncryptResponseTypeDef: ).decode("utf-8"), } - mock_s3_client.create_bucket(Bucket="test-bucket") + bucket = "test-bucket" + key_prefix = "test-prefix" + mock_s3_client.create_bucket(Bucket=bucket) mock_s3_client.put_object( - Bucket="test-bucket", Key="test-prefix/spec.json", Body=json.dumps(spec) + Bucket=bucket, Key=f"{key_prefix}/spec.json", Body=json.dumps(spec) ) # Run fast_import @@ -688,7 +701,13 @@ def encrypt(x: str) -> EncryptResponseTypeDef: fast_import.extra_env["AWS_REGION"] = mock_s3_server.region() fast_import.extra_env["AWS_ENDPOINT_URL"] = mock_s3_server.endpoint() fast_import.extra_env["RUST_LOG"] = "aws_config=debug,aws_sdk_kms=debug" - fast_import.run_dump_restore(s3prefix="s3://test-bucket/test-prefix") + fast_import.run_dump_restore(s3prefix=f"s3://{bucket}/{key_prefix}") + + job_status_obj = mock_s3_client.get_object( + Bucket=bucket, Key=f"{key_prefix}/status/fast_import" + ) + job_status = job_status_obj["Body"].read().decode("utf-8") + assert json.loads(job_status) == {"done": True}, f"got status: {job_status}" vanilla_pg.stop() res = destination_vanilla_pg.safe_psql("SELECT count(*) FROM foo;") @@ -696,9 +715,67 @@ def encrypt(x: str) -> EncryptResponseTypeDef: assert res[0][0] == 10 -# TODO: Maybe test with pageserver? -# 1. run whole neon env -# 2. create timeline with some s3 path??? -# 3. run fast_import with s3 prefix -# 4. ??? mock http where pageserver will report progress -# 5. 
run compute on this timeline and check if data is there +def test_fast_import_restore_to_connstring_error_to_s3( + test_output_dir, + vanilla_pg: VanillaPostgres, + port_distributor: PortDistributor, + fast_import: FastImport, + pg_distrib_dir: Path, + pg_version: PgVersion, + mock_s3_server: MockS3Server, + mock_kms: KMSClient, + mock_s3_client: S3Client, +): + # Prepare KMS and S3 + key_response = mock_kms.create_key( + Description="Test key", + KeyUsage="ENCRYPT_DECRYPT", + Origin="AWS_KMS", + ) + key_id = key_response["KeyMetadata"]["KeyId"] + + def encrypt(x: str) -> EncryptResponseTypeDef: + return mock_kms.encrypt(KeyId=key_id, Plaintext=x) + + # Start source postgres and ingest data + vanilla_pg.start() + vanilla_pg.safe_psql("CREATE TABLE foo (a int); INSERT INTO foo SELECT generate_series(1, 10);") + + # Encrypt connstrings and put spec into S3 + source_connstring_encrypted = encrypt(vanilla_pg.connstr()) + destination_connstring_encrypted = encrypt("postgres://random:connection@string:5432/neondb") + spec = { + "encryption_secret": {"KMS": {"key_id": key_id}}, + "source_connstring_ciphertext_base64": base64.b64encode( + source_connstring_encrypted["CiphertextBlob"] + ).decode("utf-8"), + "destination_connstring_ciphertext_base64": base64.b64encode( + destination_connstring_encrypted["CiphertextBlob"] + ).decode("utf-8"), + } + + bucket = "test-bucket" + key_prefix = "test-prefix" + mock_s3_client.create_bucket(Bucket=bucket) + mock_s3_client.put_object(Bucket=bucket, Key=f"{key_prefix}/spec.json", Body=json.dumps(spec)) + + # Run fast_import + if fast_import.extra_env is None: + fast_import.extra_env = {} + fast_import.extra_env["AWS_ACCESS_KEY_ID"] = mock_s3_server.access_key() + fast_import.extra_env["AWS_SECRET_ACCESS_KEY"] = mock_s3_server.secret_key() + fast_import.extra_env["AWS_SESSION_TOKEN"] = mock_s3_server.session_token() + fast_import.extra_env["AWS_REGION"] = mock_s3_server.region() + fast_import.extra_env["AWS_ENDPOINT_URL"] = mock_s3_server.endpoint() + fast_import.extra_env["RUST_LOG"] = "aws_config=debug,aws_sdk_kms=debug" + fast_import.run_dump_restore(s3prefix=f"s3://{bucket}/{key_prefix}") + + job_status_obj = mock_s3_client.get_object( + Bucket=bucket, Key=f"{key_prefix}/status/fast_import" + ) + job_status = job_status_obj["Body"].read().decode("utf-8") + assert json.loads(job_status) == { + "done": False, + "error": "pg_restore failed", + }, f"got status: {job_status}" + vanilla_pg.stop() From 73a1231871688b76a263754be82acf1ea34b5afd Mon Sep 17 00:00:00 2001 From: Gleb Novikov Date: Tue, 18 Mar 2025 19:48:31 +0000 Subject: [PATCH 2/4] Capture error earlier --- compute_tools/src/bin/fast_import.rs | 136 +++++++++++----------- test_runner/fixtures/fast_import.py | 13 +++ test_runner/regress/test_import_pgdata.py | 79 +++++++++---- 3 files changed, 138 insertions(+), 90 deletions(-) diff --git a/compute_tools/src/bin/fast_import.rs b/compute_tools/src/bin/fast_import.rs index 7b275a2b2070..5844bcb0e577 100644 --- a/compute_tools/src/bin/fast_import.rs +++ b/compute_tools/src/bin/fast_import.rs @@ -551,13 +551,15 @@ async fn cmd_dumprestore( &key_id, spec.source_connstring_ciphertext_base64, ) - .await?; + .await + .context("decrypt source connection string")?; let dest = if let Some(dest_ciphertext) = spec.destination_connstring_ciphertext_base64 { decode_connstring(kms_client.as_ref().unwrap(), &key_id, dest_ciphertext) - .await? + .await + .context("decrypt destination connection string")? 
} else { bail!( "destination connection string must be provided in spec for dump_restore command" @@ -610,81 +612,85 @@ pub(crate) async fn main() -> anyhow::Result<()> { (None, None) }; - let spec: Option = if let Some(s3_prefix) = &args.s3_prefix { - let spec_key = s3_prefix.append("/spec.json"); - let object = s3_client - .as_ref() - .unwrap() - .get_object() - .bucket(&spec_key.bucket) - .key(spec_key.key) - .send() - .await - .context("get spec from s3")? - .body - .collect() - .await - .context("download spec body")?; - serde_json::from_slice(&object.into_bytes()).context("parse spec as json")? - } else { - None - }; - - match tokio::fs::create_dir(&args.working_directory).await { - Ok(()) => {} - Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => { - if !is_directory_empty(&args.working_directory) + // Capture everything from spec assignment onwards to handle errors + let res = async { + let spec: Option = if let Some(s3_prefix) = &args.s3_prefix { + let spec_key = s3_prefix.append("/spec.json"); + let object = s3_client + .as_ref() + .unwrap() + .get_object() + .bucket(&spec_key.bucket) + .key(spec_key.key) + .send() .await - .context("check if working directory is empty")? - { - bail!("working directory is not empty"); - } else { - // ok + .context("get spec from s3")? + .body + .collect() + .await + .context("download spec body")?; + serde_json::from_slice(&object.into_bytes()).context("parse spec as json")? + } else { + None + }; + + match tokio::fs::create_dir(&args.working_directory).await { + Ok(()) => {} + Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => { + if !is_directory_empty(&args.working_directory) + .await + .context("check if working directory is empty")? + { + bail!("working directory is not empty"); + } else { + // ok + } } + Err(e) => return Err(anyhow::Error::new(e).context("create working directory")), } - Err(e) => return Err(anyhow::Error::new(e).context("create working directory")), - } - let res = match args.command { - Command::Pgdata { - source_connection_string, - interactive, - pg_port, - num_cpus, - memory_mb, - } => { - cmd_pgdata( - s3_client.as_ref(), - kms_client, - args.s3_prefix.clone(), - spec, + match args.command { + Command::Pgdata { source_connection_string, interactive, pg_port, - args.working_directory.clone(), - args.pg_bin_dir, - args.pg_lib_dir, num_cpus, memory_mb, - ) - .await - } - Command::DumpRestore { - source_connection_string, - destination_connection_string, - } => { - cmd_dumprestore( - kms_client, - spec, + } => { + cmd_pgdata( + s3_client.as_ref(), + kms_client, + args.s3_prefix.clone(), + spec, + source_connection_string, + interactive, + pg_port, + args.working_directory.clone(), + args.pg_bin_dir, + args.pg_lib_dir, + num_cpus, + memory_mb, + ) + .await + } + Command::DumpRestore { source_connection_string, destination_connection_string, - args.working_directory.clone(), - args.pg_bin_dir, - args.pg_lib_dir, - ) - .await + } => { + cmd_dumprestore( + kms_client, + spec, + source_connection_string, + destination_connection_string, + args.working_directory.clone(), + args.pg_bin_dir, + args.pg_lib_dir, + ) + .await + } } - }; + } + .await; if let Some(s3_prefix) = args.s3_prefix { info!("write job status to s3"); diff --git a/test_runner/fixtures/fast_import.py b/test_runner/fixtures/fast_import.py index d674be99dec5..d8fb189345b6 100644 --- a/test_runner/fixtures/fast_import.py +++ b/test_runner/fixtures/fast_import.py @@ -12,6 +12,7 @@ from fixtures.log_helper import log from fixtures.neon_cli import 
AbstractNeonCli from fixtures.pg_version import PgVersion +from fixtures.remote_storage import MockS3Server class FastImport(AbstractNeonCli): @@ -111,6 +112,18 @@ def run( self.cmd = self.raw_cli(args) return self.cmd + def set_aws_creds(self, mock_s3_server: MockS3Server, extra_env: dict[str, str] | None = None): + if self.extra_env is None: + self.extra_env = {} + self.extra_env["AWS_ACCESS_KEY_ID"] = mock_s3_server.access_key() + self.extra_env["AWS_SECRET_ACCESS_KEY"] = mock_s3_server.secret_key() + self.extra_env["AWS_SESSION_TOKEN"] = mock_s3_server.session_token() + self.extra_env["AWS_REGION"] = mock_s3_server.region() + self.extra_env["AWS_ENDPOINT_URL"] = mock_s3_server.endpoint() + + if extra_env is not None: + self.extra_env.update(extra_env) + def __enter__(self): return self diff --git a/test_runner/regress/test_import_pgdata.py b/test_runner/regress/test_import_pgdata.py index a413ec8f86fa..9a58ba538600 100644 --- a/test_runner/regress/test_import_pgdata.py +++ b/test_runner/regress/test_import_pgdata.py @@ -439,14 +439,7 @@ def handler(request: Request) -> Response: env.neon_cli.mappings_map_branch(import_branch_name, tenant_id, timeline_id) # Run fast_import - if fast_import.extra_env is None: - fast_import.extra_env = {} - fast_import.extra_env["AWS_ACCESS_KEY_ID"] = mock_s3_server.access_key() - fast_import.extra_env["AWS_SECRET_ACCESS_KEY"] = mock_s3_server.secret_key() - fast_import.extra_env["AWS_SESSION_TOKEN"] = mock_s3_server.session_token() - fast_import.extra_env["AWS_REGION"] = mock_s3_server.region() - fast_import.extra_env["AWS_ENDPOINT_URL"] = mock_s3_server.endpoint() - fast_import.extra_env["RUST_LOG"] = "aws_config=debug,aws_sdk_kms=debug" + fast_import.set_aws_creds(mock_s3_server, {"RUST_LOG": "aws_config=debug,aws_sdk_kms=debug"}) pg_port = port_distributor.get_port() fast_import.run_pgdata(pg_port=pg_port, s3prefix=f"s3://{bucket}/{key_prefix}") @@ -693,14 +686,9 @@ def encrypt(x: str) -> EncryptResponseTypeDef: ) # Run fast_import - if fast_import.extra_env is None: - fast_import.extra_env = {} - fast_import.extra_env["AWS_ACCESS_KEY_ID"] = mock_s3_server.access_key() - fast_import.extra_env["AWS_SECRET_ACCESS_KEY"] = mock_s3_server.secret_key() - fast_import.extra_env["AWS_SESSION_TOKEN"] = mock_s3_server.session_token() - fast_import.extra_env["AWS_REGION"] = mock_s3_server.region() - fast_import.extra_env["AWS_ENDPOINT_URL"] = mock_s3_server.endpoint() - fast_import.extra_env["RUST_LOG"] = "aws_config=debug,aws_sdk_kms=debug" + fast_import.set_aws_creds( + mock_s3_server, {"RUST_LOG": "aws_config=debug,aws_sdk_kms=debug"} + ) fast_import.run_dump_restore(s3prefix=f"s3://{bucket}/{key_prefix}") job_status_obj = mock_s3_client.get_object( @@ -715,7 +703,7 @@ def encrypt(x: str) -> EncryptResponseTypeDef: assert res[0][0] == 10 -def test_fast_import_restore_to_connstring_error_to_s3( +def test_fast_import_restore_to_connstring_error_to_s3_bad_destination( test_output_dir, vanilla_pg: VanillaPostgres, port_distributor: PortDistributor, @@ -760,14 +748,7 @@ def encrypt(x: str) -> EncryptResponseTypeDef: mock_s3_client.put_object(Bucket=bucket, Key=f"{key_prefix}/spec.json", Body=json.dumps(spec)) # Run fast_import - if fast_import.extra_env is None: - fast_import.extra_env = {} - fast_import.extra_env["AWS_ACCESS_KEY_ID"] = mock_s3_server.access_key() - fast_import.extra_env["AWS_SECRET_ACCESS_KEY"] = mock_s3_server.secret_key() - fast_import.extra_env["AWS_SESSION_TOKEN"] = mock_s3_server.session_token() - fast_import.extra_env["AWS_REGION"] = 
mock_s3_server.region() - fast_import.extra_env["AWS_ENDPOINT_URL"] = mock_s3_server.endpoint() - fast_import.extra_env["RUST_LOG"] = "aws_config=debug,aws_sdk_kms=debug" + fast_import.set_aws_creds(mock_s3_server, {"RUST_LOG": "aws_config=debug,aws_sdk_kms=debug"}) fast_import.run_dump_restore(s3prefix=f"s3://{bucket}/{key_prefix}") job_status_obj = mock_s3_client.get_object( @@ -779,3 +760,51 @@ def encrypt(x: str) -> EncryptResponseTypeDef: "error": "pg_restore failed", }, f"got status: {job_status}" vanilla_pg.stop() + + +def test_fast_import_restore_to_connstring_error_to_s3_kms_error( + test_output_dir, + port_distributor: PortDistributor, + fast_import: FastImport, + pg_distrib_dir: Path, + pg_version: PgVersion, + mock_s3_server: MockS3Server, + mock_kms: KMSClient, + mock_s3_client: S3Client, +): + # Prepare KMS and S3 + key_response = mock_kms.create_key( + Description="Test key", + KeyUsage="ENCRYPT_DECRYPT", + Origin="AWS_KMS", + ) + key_id = key_response["KeyMetadata"]["KeyId"] + + def encrypt(x: str) -> EncryptResponseTypeDef: + return mock_kms.encrypt(KeyId=key_id, Plaintext=x) + + # Encrypt connstrings and put spec into S3 + spec = { + "encryption_secret": {"KMS": {"key_id": key_id}}, + "source_connstring_ciphertext_base64": base64.b64encode(b"invalid encrypted string").decode( + "utf-8" + ), + } + + bucket = "test-bucket" + key_prefix = "test-prefix" + mock_s3_client.create_bucket(Bucket=bucket) + mock_s3_client.put_object(Bucket=bucket, Key=f"{key_prefix}/spec.json", Body=json.dumps(spec)) + + # Run fast_import + fast_import.set_aws_creds(mock_s3_server, {"RUST_LOG": "aws_config=debug,aws_sdk_kms=debug"}) + fast_import.run_dump_restore(s3prefix=f"s3://{bucket}/{key_prefix}") + + job_status_obj = mock_s3_client.get_object( + Bucket=bucket, Key=f"{key_prefix}/status/fast_import" + ) + job_status = job_status_obj["Body"].read().decode("utf-8") + assert json.loads(job_status) == { + "done": False, + "error": "decrypt source connection string", + }, f"got status: {job_status}" From 3c292484ebe362ef0937609b0c6e4b8ac6511eff Mon Sep 17 00:00:00 2001 From: Gleb Novikov Date: Tue, 18 Mar 2025 19:54:39 +0000 Subject: [PATCH 3/4] Added command to status file --- compute_tools/src/bin/fast_import.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/compute_tools/src/bin/fast_import.rs b/compute_tools/src/bin/fast_import.rs index 5844bcb0e577..891fcfd34874 100644 --- a/compute_tools/src/bin/fast_import.rs +++ b/compute_tools/src/bin/fast_import.rs @@ -45,7 +45,7 @@ mod s3_uri; const PG_WAIT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(600); const PG_WAIT_RETRY_INTERVAL: std::time::Duration = std::time::Duration::from_millis(300); -#[derive(Subcommand, Debug)] +#[derive(Subcommand, Debug, Clone, serde::Serialize)] enum Command { /// Runs local postgres (neon binary), restores into it, /// uploads pgdata to s3 to be consumed by pageservers @@ -649,7 +649,7 @@ pub(crate) async fn main() -> anyhow::Result<()> { Err(e) => return Err(anyhow::Error::new(e).context("create working directory")), } - match args.command { + match args.command.clone() { Command::Pgdata { source_connection_string, interactive, @@ -700,10 +700,11 @@ pub(crate) async fn main() -> anyhow::Result<()> { std::fs::create_dir(&status_dir).context("create status directory")?; } let status_file = status_dir.join("fast_import"); - let res_obj = if res.is_ok() { - serde_json::json!({"done": true}) - } else { - serde_json::json!({"done": false, "error": 
res.unwrap_err().to_string()}) + let res_obj = match res { + Ok(_) => serde_json::json!({"command": args.command, "done": true}), + Err(err) => { + serde_json::json!({"command": args.command, "done": false, "error": err.to_string()}) + } }; std::fs::write(&status_file, res_obj.to_string()).context("write status file")?; aws_s3_sync::upload_dir_recursive( From c4ace56b7ad964db361913ea7049a99daf3b89d8 Mon Sep 17 00:00:00 2001 From: Gleb Novikov Date: Tue, 18 Mar 2025 20:12:48 +0000 Subject: [PATCH 4/4] Retries in s3/kms clients, command as str --- compute_tools/src/bin/fast_import.rs | 26 ++++++++++++++++++++--- test_runner/regress/test_import_pgdata.py | 13 ++++++++++-- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/compute_tools/src/bin/fast_import.rs b/compute_tools/src/bin/fast_import.rs index 891fcfd34874..537028cde1b3 100644 --- a/compute_tools/src/bin/fast_import.rs +++ b/compute_tools/src/bin/fast_import.rs @@ -85,6 +85,15 @@ enum Command { }, } +impl Command { + fn as_str(&self) -> &'static str { + match self { + Command::Pgdata { .. } => "pgdata", + Command::DumpRestore { .. } => "dump-restore", + } + } +} + #[derive(clap::Parser)] struct Args { #[clap(long, env = "NEON_IMPORTER_WORKDIR")] @@ -604,7 +613,18 @@ pub(crate) async fn main() -> anyhow::Result<()> { // Initialize AWS clients only if s3_prefix is specified let (s3_client, kms_client) = if args.s3_prefix.is_some() { - let config = aws_config::load_defaults(BehaviorVersion::v2024_03_28()).await; + // Create AWS config with enhanced retry settings + let config = aws_config::defaults(BehaviorVersion::v2024_03_28()) + .retry_config( + aws_config::retry::RetryConfig::standard() + .with_max_attempts(5) // Retry up to 5 times + .with_initial_backoff(std::time::Duration::from_millis(200)) // Start with 200ms delay + .with_max_backoff(std::time::Duration::from_secs(5)), // Cap at 5 seconds + ) + .load() + .await; + + // Create clients from the config with enhanced retry settings let s3_client = aws_sdk_s3::Client::new(&config); let kms = aws_sdk_kms::Client::new(&config); (Some(s3_client), Some(kms)) @@ -701,9 +721,9 @@ pub(crate) async fn main() -> anyhow::Result<()> { } let status_file = status_dir.join("fast_import"); let res_obj = match res { - Ok(_) => serde_json::json!({"command": args.command, "done": true}), + Ok(_) => serde_json::json!({"command": args.command.as_str(), "done": true}), Err(err) => { - serde_json::json!({"command": args.command, "done": false, "error": err.to_string()}) + serde_json::json!({"command": args.command.as_str(), "done": false, "error": err.to_string()}) } }; std::fs::write(&status_file, res_obj.to_string()).context("write status file")?; diff --git a/test_runner/regress/test_import_pgdata.py b/test_runner/regress/test_import_pgdata.py index 9a58ba538600..a3ef75ddb06f 100644 --- a/test_runner/regress/test_import_pgdata.py +++ b/test_runner/regress/test_import_pgdata.py @@ -158,6 +158,7 @@ def validate_vanilla_equivalence(ep): statusdir = importbucket / "status" statusdir.mkdir() (statusdir / "pgdata").write_text(json.dumps({"done": True})) + (statusdir / "fast_import").write_text(json.dumps({"command": "pgdata", "done": True})) # # Do the import @@ -451,7 +452,10 @@ def handler(request: Request) -> Response: Bucket=bucket, Key=f"{key_prefix}/status/fast_import" ) job_status = job_status_obj["Body"].read().decode("utf-8") - assert json.loads(job_status) == {"done": True}, f"got status: {job_status}" + assert json.loads(job_status) == { + "command": "pgdata", + "done": True, + }, 
f"got status: {job_status}" vanilla_pg.stop() @@ -695,7 +699,10 @@ def encrypt(x: str) -> EncryptResponseTypeDef: Bucket=bucket, Key=f"{key_prefix}/status/fast_import" ) job_status = job_status_obj["Body"].read().decode("utf-8") - assert json.loads(job_status) == {"done": True}, f"got status: {job_status}" + assert json.loads(job_status) == { + "done": True, + "command": "dump-restore", + }, f"got status: {job_status}" vanilla_pg.stop() res = destination_vanilla_pg.safe_psql("SELECT count(*) FROM foo;") @@ -756,6 +763,7 @@ def encrypt(x: str) -> EncryptResponseTypeDef: ) job_status = job_status_obj["Body"].read().decode("utf-8") assert json.loads(job_status) == { + "command": "dump-restore", "done": False, "error": "pg_restore failed", }, f"got status: {job_status}" @@ -805,6 +813,7 @@ def encrypt(x: str) -> EncryptResponseTypeDef: ) job_status = job_status_obj["Body"].read().decode("utf-8") assert json.loads(job_status) == { + "command": "dump-restore", "done": False, "error": "decrypt source connection string", }, f"got status: {job_status}"