Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
396 changes: 140 additions & 256 deletions docker/Dockerfile.rocm_MI350-5

Large diffs are not rendered by default.

41 changes: 5 additions & 36 deletions docker/amd_patch/latest/megatron.patch
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
diff --git a/megatron/legacy/fused_kernels/__init__.py b/megatron/legacy/fused_kernels/__init__.py
index 87cceac3..ac686d74 100644
--- a/megatron/legacy/fused_kernels/__init__.py
+++ b/megatron/legacy/fused_kernels/__init__.py
@@ -3,6 +3,7 @@
Expand All @@ -10,42 +9,12 @@ index 87cceac3..ac686d74 100644

from torch.utils import cpp_extension

@@ -15,23 +16,23 @@ os.environ["TORCH_CUDA_ARCH_LIST"] = ""
@@ -15,6 +16,8 @@


def load(args):
-
- # Check if cuda 11 is installed for compute capability 8.0
- cc_flag = []
- _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(
- cpp_extension.CUDA_HOME
- )
- if int(bare_metal_major) >= 11:
- cc_flag.append('-gencode')
- cc_flag.append('arch=compute_80,code=sm_80')
- if int(bare_metal_minor) >= 8:
+ if torch.cuda.is_available() and torch.version.cuda:
+ # Check if cuda 11 is installed for compute capability 8.0
+ cc_flag = []
+ _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(
+ cpp_extension.CUDA_HOME
+ )
+ if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
- cc_flag.append('arch=compute_90,code=sm_90')
+ cc_flag.append('arch=compute_80,code=sm_80')
+ if int(bare_metal_minor) >= 8:
+ cc_flag.append('-gencode')
+ cc_flag.append('arch=compute_90,code=sm_90')
+ if not torch.version.cuda:
+ return

- # Build path
- srcpath = pathlib.Path(__file__).parent.absolute()
- buildpath = srcpath / "build"
- _create_build_dir(buildpath)
+ # Build path
+ srcpath = pathlib.Path(__file__).parent.absolute()
+ buildpath = srcpath / "build"
+ _create_build_dir(buildpath)

# Helper function to build the kernels.
def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
# Check if cuda 11 is installed for compute capability 8.0
cc_flag = []
38 changes: 0 additions & 38 deletions docker/amd_patch/latest/sglang.patch

This file was deleted.

This file was deleted.

Loading
Loading