diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a9e8518..099d2d3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -66,19 +66,6 @@ jobs: HUGGINGFACE_CO_STAGING=true uv run pytest --token -m "is_staging_test" tests/ if: matrix.python_version == '3.10' && matrix.torch-version == '2.7.0' - - name: Check kernel conversion - run: | - uv pip install wheel - uv run kernels to-wheel kernels-community/triton-layer-norm 0.0.1 - uv pip install triton_layer_norm-0.0.1*.whl - uv run python -c "import triton_layer_norm" - - - name: Check kernel conversion (flat build) - run: | - uv run kernels to-wheel kernels-test/flattened-build 0.0.1 - uv pip install flattened_build-0.0.1*.whl - uv run python -c "import flattened_build" - - name: Check README generation # For now, just checks that generation doesn't fail. run: | diff --git a/docs/source/cli.md b/docs/source/cli.md index 65a014b..c7588c5 100644 --- a/docs/source/cli.md +++ b/docs/source/cli.md @@ -20,31 +20,6 @@ Checking variant: torch28-cxx11-cu128-aarch64-linux [...] ``` -### kernels to-wheel - -We strongly recommend downloading kernels from the Hub using the `kernels` -package, since this comes with large [benefits](index.md) over using Python -wheels. That said, some projects may require deployment of kernels as -wheels. The `kernels` utility provides a simple solution to this. You can -convert any Hub kernel into a set of wheels with the `to-wheel` command: - -```bash -$ kernels to-wheel drbh/img2grey 1.1.2 -☸ img2grey-1.1.2+torch27cu128cxx11-cp39-abi3-manylinux_2_28_x86_64.whl -☸ img2grey-1.1.2+torch26cu124cxx11-cp39-abi3-manylinux_2_28_x86_64.whl -☸ img2grey-1.1.2+torch26cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl -☸ img2grey-1.1.2+torch27cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl -☸ img2grey-1.1.2+torch26cu126cxx98-cp39-abi3-manylinux_2_28_x86_64.whl -☸ img2grey-1.1.2+torch27cu128cxx11-cp39-abi3-manylinux_2_28_aarch64.whl -☸ img2grey-1.1.2+torch26cu126cxx98-cp39-abi3-manylinux_2_28_aarch64.whl -☸ img2grey-1.1.2+torch27cu126cxx11-cp39-abi3-manylinux_2_28_aarch64.whl -☸ img2grey-1.1.2+torch26cu126cxx11-cp39-abi3-manylinux_2_28_aarch64.whl -☸ img2grey-1.1.2+torch26cu118cxx98-cp39-abi3-manylinux_2_28_x86_64.whl -☸ img2grey-1.1.2+torch26cu124cxx98-cp39-abi3-manylinux_2_28_x86_64.whl -☸ img2grey-1.1.2+torch26cu118cxx11-cp39-abi3-manylinux_2_28_x86_64.whl -☸ img2grey-1.1.2+torch27cu118cxx11-cp39-abi3-manylinux_2_28_x86_64.whl -``` - ### kernels upload Use `kernels upload --repo_id="hub-username/kernel"` to upload diff --git a/src/kernels/cli.py b/src/kernels/cli.py index 633b3e3..25dcad1 100644 --- a/src/kernels/cli.py +++ b/src/kernels/cli.py @@ -12,7 +12,6 @@ from kernels.utils import install_kernel, install_kernel_all_variants from .doc import generate_readme_for_kernel -from .wheel import build_variant_to_wheel BUILD_VARIANT_REGEX = re.compile(r"^(torch\d+\d+|torch-universal)") @@ -92,25 +91,6 @@ def main(): ) lock_parser.set_defaults(func=lock_kernels) - to_wheel_parser = subparsers.add_parser( - "to-wheel", help="Convert a kernel to a wheel file" - ) - to_wheel_parser.add_argument("repo_id", type=str, help="The kernel repo ID") - to_wheel_parser.add_argument("version", type=str, help="The kernel version") - to_wheel_parser.add_argument( - "--python-version", - type=str, - default="3.9", - help="The minimum Python version. Must match the Python version that the kernel was compiled for.", - ) - to_wheel_parser.add_argument( - "--manylinux-version", - type=str, - default="2.28", - help="The manylinux version. Must match the manylinux version that the kernel was compiled for.", - ) - to_wheel_parser.set_defaults(func=kernels_to_wheel) - # Add generate-readme subcommand parser generate_readme_parser = subparsers.add_parser( "generate-readme", @@ -174,24 +154,6 @@ def download_kernels(args): sys.exit(1) -def kernels_to_wheel(args): - variants_path = install_kernel_all_variants( - repo_id=args.repo_id, revision=f"v{args.version}" - ) - for variant_path in variants_path.iterdir(): - if not variant_path.is_dir(): - continue - wheel_path = build_variant_to_wheel( - manylinux_version=args.manylinux_version, - python_version=args.python_version, - repo_id=args.repo_id, - version=args.version, - variant_path=variant_path, - wheel_dir=Path("."), - ) - print(f"☸️ {wheel_path.name}", file=sys.stderr) - - def lock_kernels(args): with open(args.project_dir / "pyproject.toml", "rb") as f: data = tomllib.load(f) diff --git a/src/kernels/wheel.py b/src/kernels/wheel.py deleted file mode 100644 index bc33c61..0000000 --- a/src/kernels/wheel.py +++ /dev/null @@ -1,194 +0,0 @@ -import email.policy -import os -from dataclasses import dataclass -from email.message import Message -from importlib.metadata import PackageNotFoundError, version -from pathlib import Path -from typing import Optional - -try: - KERNELS_VERSION = version("kernels") -except PackageNotFoundError: - KERNELS_VERSION = "unknown" - - -@dataclass -class Metadata: - name: str - version: str - cuda_version: Optional[str] - cxx_abi_version: Optional[str] - torch_version: Optional[str] - os: Optional[str] - platform: Optional[str] - - @property - def is_universal(self) -> bool: - return self.platform is None - - -def build_variant_to_wheel( - repo_id: str, - *, - version: str, - variant_path: Path, - wheel_dir: Path, - manylinux_version: str = "2.28", - python_version: str = "3.9", -) -> Path: - """ - Create a wheel file from the variant path. - """ - name = repo_id.split("/")[-1].replace("_", "-") - metadata = extract_metadata(name, version, variant_path) - return build_wheel( - metadata, - variant_path=variant_path, - wheel_dir=wheel_dir, - manylinux_version=manylinux_version, - python_version=python_version, - ) - - -def extract_metadata(name: str, version: str, variant_path: Path) -> Metadata: - """ - Extract metadata from the variant path. - """ - if variant_path.name == "torch-universal": - return Metadata( - name=name, - version=version, - cuda_version=None, - cxx_abi_version=None, - torch_version=None, - os=None, - platform=None, - ) - - if not variant_path.name.startswith("torch"): - raise ValueError("Currently only conversion of Torch kernels is supported.") - - variant_parts = variant_path.name.removeprefix("torch").split("-") - if len(variant_parts) != 5: - raise ValueError(f"Invalid variant name: {variant_path.name}") - - torch_version = f"{variant_parts[0][:-1]}.{variant_parts[0][-1:]}" - cpp_abi_version = variant_parts[1].removeprefix("cxx") - cuda_version = variant_parts[2].removeprefix("cu") - platform = variant_parts[3].replace("-", "_") - os = variant_parts[4] - - return Metadata( - name=name, - version=version, - cuda_version=cuda_version, - cxx_abi_version=cpp_abi_version, - torch_version=torch_version, - os=os, - platform=platform, - ) - - -def build_wheel( - metadata: Metadata, - *, - variant_path: Path, - wheel_dir: Path, - manylinux_version: str = "2.28", - python_version: str = "3.9", -) -> Path: - """ - Build the wheel file. - """ - try: - from wheel.wheelfile import WheelFile # type: ignore - except ImportError: - raise ImportError( - "The 'wheel' package is required to build wheels. Please install it with: `pip install wheel`" - ) - - name = metadata.name.replace("-", "_") - python_version_flat = python_version.replace(".", "") - - if metadata.is_universal: - python_tag = f"py{python_version_flat}" - abi_tag = "none" - platform_tag = "any" - wheel_filename = ( - f"{name}-{metadata.version}-{python_tag}-{abi_tag}-{platform_tag}.whl" - ) - dist_info_dir_name = f"{name}-{metadata.version}.dist-info" - root_is_purelib = "true" - requires_dist_torch = "torch" - else: - python_tag = f"cp{python_version_flat}" - abi_tag = "abi3" - - if ( - metadata.torch_version is None - or metadata.cuda_version is None - or metadata.cxx_abi_version is None - or metadata.os is None - or metadata.platform is None - ): - raise ValueError( - "Torch version, CUDA version, C++ ABI version, OS, and platform must be specified for non-universal wheels." - ) - - local_version = f"torch{metadata.torch_version.replace('.', '')}cu{metadata.cuda_version}cxx{metadata.cxx_abi_version}" - - if metadata.os == "linux": - platform_tag = ( - f"manylinux_{manylinux_version.replace('.', '_')}_{metadata.platform}" - ) - else: - platform_tag = f"{metadata.os}_{metadata.platform.replace('-', '_')}" - - wheel_filename = f"{name}-{metadata.version}+{local_version}-{python_tag}-{abi_tag}-{platform_tag}.whl" - dist_info_dir_name = f"{name}-{metadata.version}+{local_version}.dist-info" - root_is_purelib = "false" - requires_dist_torch = f"torch=={metadata.torch_version}.*" - - wheel_path = wheel_dir / wheel_filename - - wheel_msg = Message(email.policy.compat32) - wheel_msg.add_header("Wheel-Version", "1.0") - wheel_msg.add_header("Generator", f"kernels ({KERNELS_VERSION})") - wheel_msg.add_header("Root-Is-Purelib", root_is_purelib) - wheel_msg.add_header("Tag", f"{python_tag}-{abi_tag}-{platform_tag}") - - metadata_msg = Message(email.policy.compat32) - metadata_msg.add_header("Metadata-Version", "2.1") - metadata_msg.add_header("Name", name) - metadata_msg.add_header("Version", metadata.version) - metadata_msg.add_header("Summary", f"{name} kernel") - metadata_msg.add_header("Requires-Python", ">=3.9") - metadata_msg.add_header("Requires-Dist", requires_dist_torch) - - # Check if the kernel uses a flat build. - if (variant_path / "__init__.py").exists(): - flat_build = True - source_pkg_dir = variant_path - else: - flat_build = False - source_pkg_dir = variant_path / name - - with WheelFile(wheel_path, "w") as wheel_file: - for root, dirnames, filenames in os.walk(source_pkg_dir): - for filename in filenames: - if filename.endswith(".pyc"): - continue - - abs_filepath = os.path.join(root, filename) - entry_name = os.path.relpath(abs_filepath, variant_path) - if flat_build: - entry_name = os.path.join(name, entry_name) - wheel_file.write(abs_filepath, entry_name) - - wheel_metadata_path = os.path.join(dist_info_dir_name, "WHEEL") - wheel_file.writestr(wheel_metadata_path, str(wheel_msg).encode("utf-8")) - - metadata_path = os.path.join(dist_info_dir_name, "METADATA") - wheel_file.writestr(metadata_path, str(metadata_msg).encode("utf-8")) - - return wheel_path