Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 242 additions & 7 deletions build_tools/packaging/download_prerelease_packages.py
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may want to rename this script?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest renaming it to download_release_artifacts.py, as this is not only for prereleases.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, these are not artifacts; artifacts are rather the tarballs we push to the artifact buckets. Isn't it rather download_python_packages?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that is more accurate. Changing the script name to download_python_packages.py.

Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@
"rocm_sdk_core",
"rocm_sdk_devel",
"rocm_sdk_libraries-*",
"rocm_profiler",
"torch",
"torchaudio",
"torchvision",
Expand All @@ -140,6 +141,24 @@
"jaxlib",
}

# Package base names eligible for the multi-arch flow.
# is_allowed_multi_arch_package() compares each entry against the part of a
# file name that precedes its first "-", accepting names that start with the
# entry (so one entry also covers a family of per-device variants).
PACKAGES_TO_PROMOTE_MULTI_ARCH = {
    "torch",
    "torchaudio",
    "torchvision",
    "triton",
    "apex",
    "rocm",
    "rocm_sdk_core",
    "rocm_sdk_devel",
    "rocm_sdk_libraries",
    "rocm_profiler",
    "rocm_bootstrap",
    # device packages (IMPORTANT)
    "amd_torch_device",
    "amd_torchvision_device",
    "rocm_sdk_device",
    # TODO(review): placeholder for "amd_torchaudio_device" -- not listed yet;
    # revisit if pytorch/audio#4180 gets further along.
}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

drive-by: we may want a placeholder here for amd_torchaudio_device (and in a few other scripts/workflows) so we remember if pytorch/audio#4180 gets further along

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated with a placeholder


# copied from build_tools/third_party/s3_management/update_dependencies.py PACKAGES_PER_PROJECT
# Note: replace - in package names with _ to match the filename patterns in S3
DEPENDENCY_PACKAGES = {
Expand Down Expand Up @@ -167,6 +186,33 @@
}


def is_allowed_multi_arch_package(filename: str, packages=None) -> bool:
    """Return True if ``filename`` belongs to a package allowed in the multi-arch flow.

    The package name is taken to be everything before the first "-" in
    ``filename`` and is compared against every allow-list entry:

    - An entry ending in "*" (legacy wildcard such as "rocm_sdk_libraries-*")
      matches any name that starts with the text before the "*". Any "-" in
      that text is treated as "_", because the name extracted from the file
      name can never contain "-".
    - Any other entry matches names that start with it. This covers exact
      matches as well as device/arch families that share the entry as prefix.

    Args:
        filename: File name without directory part (e.g. a wheel file name).
        packages: Optional iterable of allow-list entries. When omitted, the
            union of PACKAGES_TO_PROMOTE and PACKAGES_TO_PROMOTE_MULTI_ARCH
            is used.

    Returns:
        True if at least one allow-list entry matches, otherwise False.
    """
    if packages is None:
        packages = PACKAGES_TO_PROMOTE | PACKAGES_TO_PROMOTE_MULTI_ARCH

    # Everything before the first "-"; the whole string if there is no "-".
    base = filename.partition("-")[0]

    def matches(pkg: str) -> bool:
        if pkg.endswith("*"):
            # "rocm_sdk_libraries-*" -> "rocm_sdk_libraries_". Without the
            # normalization the stem would end in "-" and could never be a
            # prefix of `base`, making the wildcard entries dead.
            stem = pkg[:-1].replace("-", "_")
            return base.startswith(stem)

        # startswith() is also true for an exact match.
        return base.startswith(pkg)

    return any(matches(pkg) for pkg in packages)


def categorize_package(filename: str) -> str:
"""Categorize a package file.

Expand Down Expand Up @@ -263,6 +309,70 @@ def has_version_in_arch(
return False


def has_version_in_directory(s3_client, bucket, prefix, directory, version):
    """Report whether any object key under an S3 location contains ``version``.

    Args:
        s3_client: Client object providing ``get_paginator("list_objects_v2")``.
        bucket: Name of the bucket to list.
        prefix: Key prefix to search under.
        directory: Sub-directory appended to ``prefix`` (followed by "/"), or
            None to search directly under ``prefix``.
        version: Substring looked for in each object key.

    Returns:
        True as soon as one key containing ``version`` is seen, else False.
    """
    paginator = s3_client.get_paginator("list_objects_v2")
    search_prefix = prefix if directory is None else f"{prefix}{directory}/"
    listing = paginator.paginate(Bucket=bucket, Prefix=search_prefix)

    # Lazy scan: any() stops pulling pages once a matching key shows up.
    # Pages without a "Contents" entry are skipped.
    return any(
        version in entry["Key"]
        for page in listing
        if "Contents" in page
        for entry in page["Contents"]
    )


def list_packages_multi_arch_verbose(s3_client, bucket, prefix, version):
    """Print every allowed multi-arch package for ``version`` with its size.

    Lists the objects under ``prefix`` in ``bucket`` and prints one line per
    file whose name contains ``version`` and passes
    is_allowed_multi_arch_package(), followed by the accumulated total size.
    Empty file names and "index.html" are ignored.

    Args:
        s3_client: Client object providing ``get_paginator("list_objects_v2")``.
        bucket: Name of the bucket to list.
        prefix: Key prefix to list under.
        version: Substring that a file name must contain to be reported.

    Raises:
        SystemExit: With status 1 if no matching package was found.
    """
    paginator = s3_client.get_paginator("list_objects_v2")

    print("\nPackages")
    print("-" * 60)

    matched = 0
    total_bytes = 0

    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for entry in page.get("Contents", ()):
            # Only the last path component is relevant for matching/printing.
            name = entry["Key"].rsplit("/", 1)[-1]

            if name in ("", "index.html"):
                continue
            if version not in name or not is_allowed_multi_arch_package(name):
                continue

            nbytes = entry["Size"]
            print(f" - {name} ({nbytes / BYTES_TO_MB:.2f} MB)")
            total_bytes += nbytes
            matched += 1

    if matched == 0:
        print(f"[ERROR]: No packages found for version {version}")
        sys.exit(1)

    print("\n" + "=" * 60)
    print(f"TOTAL SIZE: {total_bytes / BYTES_TO_MB:.2f} MB")


def list_packages_for_arch(
s3_client, bucket_name: str, bucket_prefix: str, arch: str, version: str
) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]], List[Tuple[str, int]]]:
Expand Down Expand Up @@ -387,6 +497,64 @@ def download_file(s3_client, bucket_name: str, key: str, local_path: Path) -> bo
return False


def download_multi_arch_packages(
    s3_client,
    bucket,
    prefix,
    version,
    output_dir,
):
    """Download all allowed multi-arch packages for ``version``.

    Files are written to ``output_dir / "wheels"`` (created if missing). A
    file is considered when its name contains ``version`` and passes
    is_allowed_multi_arch_package(); empty names and "index.html" are
    ignored. Files that already exist locally are not downloaded again and
    are counted as successful.

    Args:
        s3_client: Client object providing ``get_paginator("list_objects_v2")``;
            also handed to download_file().
        bucket: Name of the bucket to download from.
        prefix: Key prefix to list under.
        version: Substring that a file name must contain to be downloaded.
        output_dir: Path-like base directory supporting ``/`` and ``mkdir``.

    Returns:
        Tuple ``(successes, failures)`` of download counts.
    """
    paginator = s3_client.get_paginator("list_objects_v2")

    target_dir = output_dir / "wheels"
    target_dir.mkdir(parents=True, exist_ok=True)

    succeeded = 0
    failed = 0

    listing = paginator.paginate(Bucket=bucket, Prefix=prefix)

    print("\nDownloading packages")
    print("=" * 80)

    for page in listing:
        for entry in page.get("Contents", ()):
            key = entry["Key"]
            name = key.rsplit("/", 1)[-1]

            wanted = (
                name
                and name != "index.html"
                and version in name
                and is_allowed_multi_arch_package(name)
            )
            if not wanted:
                continue

            destination = target_dir / name

            # An already present file counts as a success and is left as is.
            if destination.exists():
                print(f" SKIP (exists): {name}")
                succeeded += 1
                continue

            nbytes = entry["Size"]
            print(f" Downloading: {name} ({nbytes / BYTES_TO_MB:.2f} MB)")

            if download_file(s3_client, bucket, key, destination):
                succeeded += 1
            else:
                failed += 1

    return succeeded, failed


def download_packages(
s3_client,
bucket_name: str,
Expand Down Expand Up @@ -423,7 +591,7 @@ def download_packages(
if unknown:
print(f" Unknown packages (skipped): {len(unknown)}")
for key in unknown:
print(f" - {key[0].split("/")[-1]}")
print(f" - {key[0].split('/')[-1]}")
print("")
print("-" * 80)

Expand Down Expand Up @@ -632,6 +800,20 @@ def parse_arguments(argv):
help="List all packages per architecture, do not download",
)

parser.add_argument(
"--multi-arch",
action=argparse.BooleanOptionalAction,
default=False,
help="--multi-arch requires prefix like v4/whl/",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This help is confusing since the option is boolean.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the help text here.

)

parser.add_argument(
"--list-packages-multi-arch",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"--list-packages-multi-arch",
"--list-multi-arch-packages",

?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, updated.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also updated the tests with the argument change

https://gist.github.com/araravik-psd/58a570bccd944fc8a5df7a96d6c6abe3

action=argparse.BooleanOptionalAction,
default=False,
help="List all multi-arch packages matching version",
)

args = parser.parse_args(argv)

if args.arch:
Expand All @@ -647,9 +829,14 @@ def parse_arguments(argv):
if args.tarball_output_dir is not None
else None
)
if not args.list_archs and not args.list_packages_per_arch and not args.output_dir:
if (
not args.list_archs
and not args.list_packages_per_arch
and not args.list_packages_multi_arch
and not args.output_dir
):
parser.error(
"--output-dir is required unless --list-archs or --list-packages-per-arch is specified"
"--output-dir is required unless --list-archs, --list-packages-per-arch, or --list-packages-multi-arch is specified"
)

return args
Expand Down Expand Up @@ -727,7 +914,7 @@ def print_packages_per_arch(
size_tarball += tarball_size
total_size_tarball += tarball_size
size_mb = tarball_size / BYTES_TO_MB
print(f" - {tarball_name.split("/")[-1]} ({size_mb:.2f} MB)")
print(f" - {tarball_name.split('/')[-1]} ({size_mb:.2f} MB)")
if not tarballs:
print(
f" [WARN]: No tarball found for {arch} with version {version}. Skipping!"
Expand Down Expand Up @@ -771,6 +958,8 @@ def download_prerelease_packages(
bucket_name: str = "therock-prerelease-python",
bucket_prefix: str = "v3/whl/",
include_dependencies: bool = False,
multi_arch: bool = False,
list_packages_multi_arch: bool = False,
include_tarballs: bool = False,
tarball_bucket_name: str = "therock-prerelease-tarball",
tarball_bucket_prefix: str = "v3/tarball/",
Expand Down Expand Up @@ -804,10 +993,14 @@ def download_prerelease_packages(
Raises:
SystemExit: If AWS credentials are not configured, no architectures found, or downloads fail
"""
# Validate arguments
if not list_archs and not list_packages_per_arch and output_dir is None:
if (
not list_archs
and not list_packages_per_arch
and not list_packages_multi_arch
and output_dir is None
):
print(
"[ERROR]: output_dir is required unless list_archs=True or list_packages_per_arch=True"
"[ERROR]: output_dir is required unless list_archs=True, list_packages_per_arch=True, or list_packages_multi_arch=True"
)
sys.exit(1)

Expand All @@ -816,13 +1009,53 @@ def download_prerelease_packages(
print("=" * 80)
print(f"Bucket: {bucket_name}")
print(f"Version: {version}")

if architectures:
print(f"Architectures: {architectures} (user-specified)")
else:
print(f"Architecture: ALL")
print("=" * 80)

s3_client = boto3.client("s3")
if multi_arch:
# Validate that packages exist
if not has_version_in_directory(
s3_client, bucket_name, bucket_prefix, None, version
):
print(f"[ERROR]: No packages found for version {version}")
sys.exit(1)

if list_packages_multi_arch:
list_packages_multi_arch_verbose(
s3_client,
bucket_name,
bucket_prefix,
version,
)
return

output_dir.mkdir(parents=True, exist_ok=True)
print(f"\nOutput directory: {output_dir.absolute()}")

success, fail = download_multi_arch_packages(
s3_client,
bucket_name,
bucket_prefix,
version,
output_dir,
)

print("\n" + "=" * 80)
print("DOWNLOAD COMPLETE (MULTI-ARCH)")
print("=" * 80)
print(f"Total successful downloads: {success}")
print(f"Total failed downloads: {fail}")

if fail > 0:
sys.exit(1)

return

# List architectures
if architectures:

Expand Down Expand Up @@ -937,6 +1170,8 @@ def download_prerelease_packages(
output_dir=args.output_dir,
architectures=args.arch,
bucket_name=args.bucket,
multi_arch=args.multi_arch,
list_packages_multi_arch=args.list_packages_multi_arch,
bucket_prefix=args.bucket_prefix,
include_dependencies=args.include_dependencies,
include_tarballs=args.include_tarballs,
Expand Down
Loading