-
Notifications
You must be signed in to change notification settings - Fork 236
Enable multi-arch package download and filtering in download_prerelease_packages.py #5028
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
694ecf0
f0bdf2d
e23b325
7c3e7d3
005e425
2052a3c
92bf4b5
f03bf77
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -130,6 +130,7 @@ | |||||
| "rocm_sdk_core", | ||||||
| "rocm_sdk_devel", | ||||||
| "rocm_sdk_libraries-*", | ||||||
| "rocm_profiler", | ||||||
| "torch", | ||||||
| "torchaudio", | ||||||
| "torchvision", | ||||||
|
|
@@ -140,6 +141,24 @@ | |||||
| "jaxlib", | ||||||
| } | ||||||
|
|
||||||
| PACKAGES_TO_PROMOTE_MULTI_ARCH = { | ||||||
| "torch", | ||||||
| "torchaudio", | ||||||
| "torchvision", | ||||||
| "triton", | ||||||
| "apex", | ||||||
| "rocm", | ||||||
| "rocm_sdk_core", | ||||||
| "rocm_sdk_devel", | ||||||
| "rocm_sdk_libraries", | ||||||
| "rocm_profiler", | ||||||
| "rocm_bootstrap", | ||||||
| # device packages (IMPORTANT) | ||||||
| "amd_torch_device", | ||||||
| "amd_torchvision_device", | ||||||
| "rocm_sdk_device", | ||||||
| } | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. drive-by: we may want a placeholder here for
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated with a placeholder |
||||||
|
|
||||||
| # copied from build_tools/third_party/s3_management/update_dependencies.py PACKAGES_PER_PROJECT | ||||||
| # Note: replace - in package names with _ to match the filename patterns in S3 | ||||||
| DEPENDENCY_PACKAGES = { | ||||||
|
|
@@ -167,6 +186,33 @@ | |||||
| } | ||||||
|
|
||||||
|
|
||||||
| def is_allowed_multi_arch_package(filename: str) -> bool: | ||||||
| """ | ||||||
| Check if filename belongs to allowed packages for multi-arch flow. | ||||||
| Uses BOTH: | ||||||
| - PACKAGES_TO_PROMOTE | ||||||
| - PACKAGES_TO_PROMOTE_MULTI_ARCH | ||||||
|
|
||||||
| Supports: | ||||||
| - exact match | ||||||
| - wildcard '*' suffix (legacy) | ||||||
| - prefix match (for device packages) | ||||||
| """ | ||||||
| base = filename.split("-")[0] | ||||||
|
|
||||||
| def matches(pkg: str) -> bool: | ||||||
| # Handle wildcard like "rocm_sdk_libraries-*" | ||||||
| if pkg.endswith("*"): | ||||||
| return base.startswith(pkg[:-1]) | ||||||
|
|
||||||
| # Exact match OR prefix match (for device families) | ||||||
| return base == pkg or base.startswith(pkg) | ||||||
|
|
||||||
| return any( | ||||||
| matches(pkg) for pkg in (PACKAGES_TO_PROMOTE | PACKAGES_TO_PROMOTE_MULTI_ARCH) | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| def categorize_package(filename: str) -> str: | ||||||
| """Categorize a package file. | ||||||
|
|
||||||
|
|
@@ -263,6 +309,70 @@ def has_version_in_arch( | |||||
| return False | ||||||
|
|
||||||
|
|
||||||
| def has_version_in_directory(s3_client, bucket, prefix, directory, version): | ||||||
| paginator = s3_client.get_paginator("list_objects_v2") | ||||||
|
|
||||||
| if directory is None: | ||||||
| prefix_to_use = prefix | ||||||
| else: | ||||||
| prefix_to_use = f"{prefix}{directory}/" | ||||||
|
|
||||||
| pages = paginator.paginate( | ||||||
| Bucket=bucket, | ||||||
| Prefix=prefix_to_use, | ||||||
| ) | ||||||
|
|
||||||
| for page in pages: | ||||||
| if "Contents" not in page: | ||||||
| continue | ||||||
|
|
||||||
| for obj in page["Contents"]: | ||||||
| if version in obj["Key"]: | ||||||
| return True | ||||||
|
|
||||||
| return False | ||||||
|
|
||||||
|
|
||||||
| def list_packages_multi_arch_verbose(s3_client, bucket, prefix, version): | ||||||
| paginator = s3_client.get_paginator("list_objects_v2") | ||||||
|
|
||||||
| total_size = 0 | ||||||
| found = False | ||||||
|
|
||||||
| print("\nPackages") | ||||||
| print("-" * 60) | ||||||
|
|
||||||
| pages = paginator.paginate( | ||||||
| Bucket=bucket, | ||||||
| Prefix=prefix, | ||||||
| ) | ||||||
|
|
||||||
| for page in pages: | ||||||
| if "Contents" not in page: | ||||||
| continue | ||||||
|
|
||||||
| for obj in page["Contents"]: | ||||||
| key = obj["Key"] | ||||||
| filename = key.split("/")[-1] | ||||||
|
|
||||||
| if not filename or filename == "index.html": | ||||||
| continue | ||||||
|
|
||||||
| if version in filename and is_allowed_multi_arch_package(filename): | ||||||
| found = True | ||||||
| size = obj["Size"] | ||||||
|
|
||||||
| print(f" - {filename} ({size / BYTES_TO_MB:.2f} MB)") | ||||||
| total_size += size | ||||||
|
|
||||||
| if not found: | ||||||
| print(f"[ERROR]: No packages found for version {version}") | ||||||
| sys.exit(1) | ||||||
|
|
||||||
| print("\n" + "=" * 60) | ||||||
| print(f"TOTAL SIZE: {total_size / BYTES_TO_MB:.2f} MB") | ||||||
|
|
||||||
|
|
||||||
| def list_packages_for_arch( | ||||||
| s3_client, bucket_name: str, bucket_prefix: str, arch: str, version: str | ||||||
| ) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]], List[Tuple[str, int]]]: | ||||||
|
|
@@ -387,6 +497,64 @@ def download_file(s3_client, bucket_name: str, key: str, local_path: Path) -> bo | |||||
| return False | ||||||
|
|
||||||
|
|
||||||
| def download_multi_arch_packages( | ||||||
| s3_client, | ||||||
| bucket, | ||||||
| prefix, | ||||||
| version, | ||||||
| output_dir, | ||||||
| ): | ||||||
| paginator = s3_client.get_paginator("list_objects_v2") | ||||||
|
|
||||||
| wheels_dir = output_dir / "wheels" | ||||||
| wheels_dir.mkdir(parents=True, exist_ok=True) | ||||||
|
|
||||||
| total_success = 0 | ||||||
| total_fail = 0 | ||||||
|
|
||||||
| pages = paginator.paginate( | ||||||
| Bucket=bucket, | ||||||
| Prefix=prefix, | ||||||
| ) | ||||||
|
|
||||||
| print("\nDownloading packages") | ||||||
| print("=" * 80) | ||||||
|
|
||||||
| for page in pages: | ||||||
| if "Contents" not in page: | ||||||
| continue | ||||||
|
|
||||||
| for obj in page["Contents"]: | ||||||
| key = obj["Key"] | ||||||
| filename = key.split("/")[-1] | ||||||
|
|
||||||
| if not filename or filename == "index.html": | ||||||
| continue | ||||||
|
|
||||||
| if version not in filename: | ||||||
| continue | ||||||
|
|
||||||
| if not is_allowed_multi_arch_package(filename): | ||||||
| continue | ||||||
|
|
||||||
| local_path = wheels_dir / filename | ||||||
|
|
||||||
| if local_path.exists(): | ||||||
| print(f" SKIP (exists): {filename}") | ||||||
| total_success += 1 | ||||||
| continue | ||||||
|
|
||||||
| size = obj["Size"] | ||||||
| print(f" Downloading: {filename} ({size / BYTES_TO_MB:.2f} MB)") | ||||||
|
|
||||||
| if download_file(s3_client, bucket, key, local_path): | ||||||
| total_success += 1 | ||||||
| else: | ||||||
| total_fail += 1 | ||||||
|
|
||||||
| return total_success, total_fail | ||||||
|
|
||||||
|
|
||||||
| def download_packages( | ||||||
| s3_client, | ||||||
| bucket_name: str, | ||||||
|
|
@@ -423,7 +591,7 @@ def download_packages( | |||||
| if unknown: | ||||||
| print(f" Unknown packages (skipped): {len(unknown)}") | ||||||
| for key in unknown: | ||||||
| print(f" - {key[0].split("/")[-1]}") | ||||||
| print(f" - {key[0].split('/')[-1]}") | ||||||
| print("") | ||||||
| print("-" * 80) | ||||||
|
|
||||||
|
|
@@ -632,6 +800,20 @@ def parse_arguments(argv): | |||||
| help="List all packages per architecture, do not download", | ||||||
| ) | ||||||
|
|
||||||
| parser.add_argument( | ||||||
| "--multi-arch", | ||||||
| action=argparse.BooleanOptionalAction, | ||||||
| default=False, | ||||||
| help="--multi-arch requires prefix like v4/whl/", | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This help is confusing since the option is boolean.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated the help text here. |
||||||
| ) | ||||||
|
|
||||||
| parser.add_argument( | ||||||
| "--list-packages-multi-arch", | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done updated
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also updated the tests with the argument change https://gist.github.com/araravik-psd/58a570bccd944fc8a5df7a96d6c6abe3 |
||||||
| action=argparse.BooleanOptionalAction, | ||||||
| default=False, | ||||||
| help="List all multi-arch packages matching version", | ||||||
| ) | ||||||
|
|
||||||
| args = parser.parse_args(argv) | ||||||
|
|
||||||
| if args.arch: | ||||||
|
|
@@ -647,9 +829,14 @@ def parse_arguments(argv): | |||||
| if args.tarball_output_dir is not None | ||||||
| else None | ||||||
| ) | ||||||
| if not args.list_archs and not args.list_packages_per_arch and not args.output_dir: | ||||||
| if ( | ||||||
| not args.list_archs | ||||||
| and not args.list_packages_per_arch | ||||||
| and not args.list_packages_multi_arch | ||||||
| and not args.output_dir | ||||||
| ): | ||||||
| parser.error( | ||||||
| "--output-dir is required unless --list-archs or --list-packages-per-arch is specified" | ||||||
| "--output-dir is required unless --list-archs, --list-packages-per-arch, or --list-packages-multi-arch is specified" | ||||||
| ) | ||||||
|
|
||||||
| return args | ||||||
|
|
@@ -727,7 +914,7 @@ def print_packages_per_arch( | |||||
| size_tarball += tarball_size | ||||||
| total_size_tarball += tarball_size | ||||||
| size_mb = tarball_size / BYTES_TO_MB | ||||||
| print(f" - {tarball_name.split("/")[-1]} ({size_mb:.2f} MB)") | ||||||
| print(f" - {tarball_name.split('/')[-1]} ({size_mb:.2f} MB)") | ||||||
| if not tarballs: | ||||||
| print( | ||||||
| f" [WARN]: No tarball found for {arch} with version {version}. Skipping!" | ||||||
|
|
@@ -771,6 +958,8 @@ def download_prerelease_packages( | |||||
| bucket_name: str = "therock-prerelease-python", | ||||||
| bucket_prefix: str = "v3/whl/", | ||||||
| include_dependencies: bool = False, | ||||||
| multi_arch: bool = False, | ||||||
| list_packages_multi_arch: bool = False, | ||||||
| include_tarballs: bool = False, | ||||||
| tarball_bucket_name: str = "therock-prerelease-tarball", | ||||||
| tarball_bucket_prefix: str = "v3/tarball/", | ||||||
|
|
@@ -804,10 +993,14 @@ def download_prerelease_packages( | |||||
| Raises: | ||||||
| SystemExit: If AWS credentials are not configured, no architectures found, or downloads fail | ||||||
| """ | ||||||
| # Validate arguments | ||||||
| if not list_archs and not list_packages_per_arch and output_dir is None: | ||||||
| if ( | ||||||
| not list_archs | ||||||
| and not list_packages_per_arch | ||||||
| and not list_packages_multi_arch | ||||||
| and output_dir is None | ||||||
| ): | ||||||
| print( | ||||||
| "[ERROR]: output_dir is required unless list_archs=True or list_packages_per_arch=True" | ||||||
| "[ERROR]: output_dir is required unless list_archs=True, list_packages_per_arch=True, or list_packages_multi_arch=True" | ||||||
| ) | ||||||
| sys.exit(1) | ||||||
|
|
||||||
|
|
@@ -816,13 +1009,53 @@ def download_prerelease_packages( | |||||
| print("=" * 80) | ||||||
| print(f"Bucket: {bucket_name}") | ||||||
| print(f"Version: {version}") | ||||||
|
|
||||||
| if architectures: | ||||||
| print(f"Architectures: {architectures} (user-specified)") | ||||||
| else: | ||||||
| print(f"Architecture: ALL") | ||||||
| print("=" * 80) | ||||||
|
|
||||||
| s3_client = boto3.client("s3") | ||||||
| if multi_arch: | ||||||
| # Validate that packages exist | ||||||
| if not has_version_in_directory( | ||||||
| s3_client, bucket_name, bucket_prefix, None, version | ||||||
| ): | ||||||
| print(f"[ERROR]: No packages found for version {version}") | ||||||
| sys.exit(1) | ||||||
|
|
||||||
| if list_packages_multi_arch: | ||||||
| list_packages_multi_arch_verbose( | ||||||
| s3_client, | ||||||
| bucket_name, | ||||||
| bucket_prefix, | ||||||
| version, | ||||||
| ) | ||||||
| return | ||||||
|
|
||||||
| output_dir.mkdir(parents=True, exist_ok=True) | ||||||
| print(f"\nOutput directory: {output_dir.absolute()}") | ||||||
|
|
||||||
| success, fail = download_multi_arch_packages( | ||||||
| s3_client, | ||||||
| bucket_name, | ||||||
| bucket_prefix, | ||||||
| version, | ||||||
| output_dir, | ||||||
| ) | ||||||
|
|
||||||
| print("\n" + "=" * 80) | ||||||
| print("DOWNLOAD COMPLETE (MULTI-ARCH)") | ||||||
| print("=" * 80) | ||||||
| print(f"Total successful downloads: {success}") | ||||||
| print(f"Total failed downloads: {fail}") | ||||||
|
|
||||||
| if fail > 0: | ||||||
| sys.exit(1) | ||||||
|
|
||||||
| return | ||||||
|
|
||||||
| # List architectures | ||||||
| if architectures: | ||||||
|
|
||||||
|
|
@@ -937,6 +1170,8 @@ def download_prerelease_packages( | |||||
| output_dir=args.output_dir, | ||||||
| architectures=args.arch, | ||||||
| bucket_name=args.bucket, | ||||||
| multi_arch=args.multi_arch, | ||||||
| list_packages_multi_arch=args.list_packages_multi_arch, | ||||||
| bucket_prefix=args.bucket_prefix, | ||||||
| include_dependencies=args.include_dependencies, | ||||||
| include_tarballs=args.include_tarballs, | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We may want to rename this script?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest renaming to download_release_artifacts.py. As this is not only for prereleases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well it's not artifacts, those are rather the tarballs we push to the artifact buckets. It is rather download_python_packages isn't it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes that is more accurate, changing the script name to download_python_packages.py