diff --git a/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py b/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py index a7803cf1b0..eada14b2a5 100644 --- a/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py +++ b/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py @@ -28,6 +28,9 @@ # Needed since changing args to function causes recompiles torch._dynamo.config.cache_size_limit = 1000 +# Dynamic shapes hurt performance +torch._dynamo.config.automatic_dynamic_shapes = False + @dataclass(frozen=True) class ExperimentConfig: diff --git a/test/prototype/moe_training/test_kernels.py b/test/prototype/moe_training/test_kernels.py index 377d86c7c9..9c044b9fef 100644 --- a/test/prototype/moe_training/test_kernels.py +++ b/test/prototype/moe_training/test_kernels.py @@ -213,7 +213,7 @@ def test_fp8_rowwise_3d_transpose_rhs_reduction(round_scales_to_power_of_2: bool @pytest.mark.parametrize( "m,k,n_groups", [(256, 256, 4), (16640, 5120, 16), (16640, 8192, 16)] ) -def test_mxfp8_per_group_blocked_scales_2d( +def test_triton_mx_block_rearrange_2d_M_groups( m: int, k: int, n_groups: int, @@ -272,10 +272,10 @@ def test_mxfp8_per_group_blocked_scales_3d( @skip_if_rocm("ROCm enablement in progress") -@pytest.mark.parametrize("m", [256, 512, 1024, 5120]) -@pytest.mark.parametrize("total_k", [512, 1024, 2048, 4096, 8192, 16384]) -@pytest.mark.parametrize("n_groups", [1, 4, 8, 16]) -def test_mxfp8_per_group_blocked_scales_2d2d( +@pytest.mark.parametrize("m", [256]) +@pytest.mark.parametrize("total_k", [512]) +@pytest.mark.parametrize("n_groups", [1]) +def test_triton_mx_block_rearrange_2d_K_groups( m: int, total_k: int, n_groups: int, diff --git a/test/prototype/moe_training/test_training.py b/test/prototype/moe_training/test_training.py index 4aef7d3e92..95271dc2cb 100644 --- a/test/prototype/moe_training/test_training.py +++ b/test/prototype/moe_training/test_training.py @@ -34,10 +34,7 @@ @pytest.mark.parametrize( 
"target_fqns", - [ - ["experts"], - ["does.not.exist"], - ], + [["experts"], ["experts,shared_expert"], ["invalid.fqns"]], ) @pytest.mark.parametrize("compile", [False, True]) @pytest.mark.parametrize( diff --git a/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py b/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py index e5d5cf439a..9ba1589d13 100644 --- a/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py +++ b/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py @@ -4,6 +4,7 @@ import triton import triton.language as tl from torch import Tensor +from torch.library import triton_op, wrap_triton from torchao.prototype.mx_formats.utils import to_blocked from torchao.utils import ceil_div @@ -29,7 +30,14 @@ def torch_to_blocked_2d_M_groups( assert x_scales.ndim == 2, "x_scales must be 2D" assert block_size == 32, "Only block_size=32 is supported for now" - blocked_scales_list = [] + total_M, _ = x_scales.shape + num_groups = group_offs.shape[0] + + # Each group will require a variable amount of padding, so to avoid d2h sync causing by iterating over each group, + # the Triton kernenl will use an upper bound of adding 128 padding rows to each group. + # (This torch impl is used as a reference for correctness, so we must match the triton kernel's impl). 
+ total_M_padded = total_M + num_groups * 128 + blocked_scales = x_scales.new_zeros(total_M_padded, K // block_size) start_row_after_padding_list = [0] group_start_idx = 0 for i, group_end_idx in enumerate(group_offs.tolist()): @@ -42,7 +50,6 @@ def torch_to_blocked_2d_M_groups( # Convert group scales to blocked format group_scales = x_scales[group_start_idx:group_end_idx] group_scales_blocked = to_blocked(group_scales) - blocked_scales_list.append(group_scales_blocked) # Calculate the start row after padding scaling_groups_per_row = K // block_size @@ -50,11 +57,17 @@ def torch_to_blocked_2d_M_groups( new_start_row = prev_start_row_after_padding + rows_for_group start_row_after_padding_list.append(new_start_row) + # Write output to subtensor + group_rows_padded = ceil_div(group_size, 128) * 128 + blocked_scales[ + prev_start_row_after_padding : prev_start_row_after_padding + + group_rows_padded, + :, + ] = group_scales_blocked.reshape(-1, K // block_size) + # Update next group start index group_start_idx = group_end_idx - blocked_scales = torch.cat(blocked_scales_list, dim=0).contiguous() - blocked_scales = blocked_scales.reshape(-1, K // 32) start_row_after_padding = torch.tensor( start_row_after_padding_list, device=x_scales.device, dtype=torch.int64 ) @@ -79,34 +92,44 @@ def torch_to_blocked_2d_K_groups( """ assert x_scales.ndim == 2, "x_scales must be 2D" assert block_size == 32, "Only block_size=32 is supported for now" - blocked_scales_list = [] + M, total_K = x_scales.shape + padded_M = ceil_div(M, 128) * 128 + num_groups = group_offs.shape[0] + + # Each group will require a variable amount of padding, so to avoid a d2h sync caused by iterating over each group, + # the Triton kernel will use an upper bound of adding 4 padding cols to each group. + # (This torch impl is used as a reference for correctness, so we must match the triton kernel's impl). 
+ total_K_padded = total_K + num_groups * 4 + blocked_scales = x_scales.new_zeros(padded_M, total_K_padded) + start_col_after_padding_list = [0] group_start_idx = 0 for i, group_end_idx in enumerate(group_offs.tolist()): group_size = group_end_idx - group_start_idx - prev_start_row_after_padding = start_col_after_padding_list[i] + prev_start_col_after_padding = start_col_after_padding_list[i] if group_size == 0: - start_col_after_padding_list.append(prev_start_row_after_padding) + start_col_after_padding_list.append(prev_start_col_after_padding) continue # Convert group scales to blocked format group_scales = x_scales[:, group_start_idx:group_end_idx] group_scales_blocked = to_blocked(group_scales) cols_after_padding = ceil_div(group_size, 4) * 4 - blocked_scales_list.append(group_scales_blocked) + + # Write output to subtensor + blocked_scales[ + :, + prev_start_col_after_padding : prev_start_col_after_padding + + cols_after_padding, + ] = group_scales_blocked.reshape(-1, cols_after_padding) # Calculate the start row after padding - new_start_col = prev_start_row_after_padding + cols_after_padding + new_start_col = prev_start_col_after_padding + cols_after_padding start_col_after_padding_list.append(new_start_col) # Update next group start index group_start_idx = group_end_idx - # blocked_scales = torch.cat(blocked_scales_list, dim=1) - M = x_scales.shape[0] - padded_M = ceil_div(M, 128) * 128 - blocked_scales = torch.cat(blocked_scales_list) - blocked_scales = blocked_scales.reshape(padded_M, -1) start_cols_after_padding = torch.tensor( start_col_after_padding_list, device=x_scales.device, dtype=torch.int64 ) @@ -192,6 +215,7 @@ def compute_blocked_scale_offsets_for_K_groups( return group_sizes, starting_col_after_padding +@triton_op("torchao::triton_mx_block_rearrange_2d_M_groups", mutates_args={}) def triton_mx_block_rearrange_2d_M_groups( scales_tensor: torch.Tensor, input_group_end_offsets: torch.Tensor, @@ -216,14 +240,14 @@ def 
triton_mx_block_rearrange_2d_M_groups( "Expected element size to be 1 byte (8 bits)" ) rows, cols = scales_tensor.shape - num_groups = input_group_end_offsets.numel() + num_groups = input_group_end_offsets.shape[0] # Final offset is the total number of rows in the tensor - padded_rows = output_group_start_offsets[-1] + padded_rows = rows + num_groups * 128 num_col_blocks = ceil_div(cols, 4) padded_cols = num_col_blocks * 4 - output = scales_tensor.new_empty((padded_rows, padded_cols)) + output = scales_tensor.new_zeros((padded_rows, padded_cols)) # Output block stride for the rearranged format BLOCK_ROWS, BLOCK_COLS = 128, 4 @@ -238,7 +262,7 @@ def triton_mx_block_rearrange_2d_M_groups( num_groups, num_col_blocks, ) - triton_scale_swizzle_M_groups[grid]( + wrap_triton(triton_scale_swizzle_M_groups)[grid]( # Input scales scales_tensor.view(torch.uint8), scales_tensor.stride(0), @@ -336,6 +360,7 @@ def triton_scale_swizzle_M_groups( current_start_row += BLOCK_ROWS +@triton_op("torchao::triton_mx_block_rearrange_per_group_3d", mutates_args={}) def triton_mx_block_rearrange_per_group_3d(scale_tensor: torch.Tensor) -> torch.Tensor: """ Rearranges an E8M0 tensor scale to block-scaled swizzle format. @@ -379,7 +404,7 @@ def triton_mx_block_rearrange_per_group_3d(scale_tensor: torch.Tensor) -> torch. 
num_col_blocks, ) - triton_scale_swizzle_per_group_3d[grid]( + wrap_triton(triton_scale_swizzle_per_group_3d)[grid]( scale_tensor.view(torch.uint8), input_stride_dim0, input_stride_dim1, @@ -454,6 +479,7 @@ def triton_scale_swizzle_per_group_3d( ) +@triton_op("torchao::triton_mx_block_rearrange_2d_K_groups", mutates_args={}) def triton_mx_block_rearrange_2d_K_groups( scales_tensor: torch.Tensor, input_group_end_offsets: torch.Tensor, @@ -479,13 +505,13 @@ def triton_mx_block_rearrange_2d_K_groups( ) rows, cols = scales_tensor.shape # Calculate blocks needed - num_groups = input_group_end_offsets.numel() + num_groups = input_group_end_offsets.shape[0] num_row_blocks = ceil_div(rows, 128) padded_rows = num_row_blocks * 128 # output_group_start_offsets always starts with 0 and ends with the total number of cols - padded_cols = output_group_start_offsets[-1] - output = scales_tensor.new_empty((padded_rows, padded_cols)) + padded_cols = cols + num_groups * 4 + output = scales_tensor.new_zeros((padded_rows, padded_cols)) # Output block stride for the rearranged format BLOCK_ROWS, BLOCK_COLS = 128, 4 @@ -497,7 +523,7 @@ def triton_mx_block_rearrange_2d_K_groups( num_groups, num_row_blocks, ) - triton_scale_swizzle_2d_K_groups[grid]( + wrap_triton(triton_scale_swizzle_2d_K_groups)[grid]( # Input scales scales_tensor.view(torch.uint8), scales_tensor.stride(0), diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py index 996874a42b..bbd86dc5f1 100644 --- a/torchao/prototype/moe_training/scaled_grouped_mm.py +++ b/torchao/prototype/moe_training/scaled_grouped_mm.py @@ -58,7 +58,7 @@ def _scaled_grouped_mm( """ # TODO: Remove logging once prototype is more mature. This is currently very useful for development and debugging. 
if scaling_type == MoEScalingType.FP8_ROWWISE: - logger.info("Using fp8 rowwise for _scaled_grouped_mm") + logger.debug("Using fp8 rowwise for _scaled_grouped_mm") return _Float8GroupedMM.apply( A, B_t, @@ -66,7 +66,7 @@ def _scaled_grouped_mm( out_dtype, ) elif scaling_type == MoEScalingType.MXFP8: - logger.info("Using mxfp8 for _scaled_grouped_mm") + logger.debug("Using mxfp8 for _scaled_grouped_mm") block_size = 32 # TODO: should we make this configurable? plumb it through in a config somehow? return _MXFP8GroupedMM.apply( A, diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py index a861aa6533..0bbbda850e 100644 --- a/torchao/prototype/moe_training/tensor.py +++ b/torchao/prototype/moe_training/tensor.py @@ -27,7 +27,7 @@ torch.ops.aten.copy_.default, torch.ops.aten.view.default, torch.ops.aten.as_strided.default, - torch.ops.aten._to_copy.default, + torch.ops.aten._to_copy.default, # for *.to(dtype) torch.ops.aten._pin_memory.default, torch.ops.aten.split.Tensor, torch.ops.aten.clone.default, @@ -94,11 +94,11 @@ def __torch_function__(cls, func, types, args, kwargs={}): "B should be a ScaledGroupedMMTensor" ) scaling_type = B.scaling_type - A_is_2d = A.dim() == 2 - B_is_3d = B.dim() == 3 + A_is_2d = A.ndim == 2 + B_is_2d_or_3d = B.ndim == 2 or B.ndim == 3 has_offs = kwargs.get(cls.offs_arg_name) is not None other_args = args[2:] - if A_is_2d and B_is_3d and has_offs: + if A_is_2d and B_is_2d_or_3d and has_offs: return _scaled_grouped_mm( A, B, @@ -125,17 +125,19 @@ def unwrap(t): assert t.scaling_type == scaling_type return t._data - args, kwargs = pytree.tree_map_only( + args_unwrapped, kwargs_unwrapped = pytree.tree_map_only( ScaledGroupedMMTensor, unwrap, (args, kwargs or {}) ) - assert scaling_type is not None + assert scaling_type is not None, ( + f"__torch_dispatch__ called on {func.__name__} without any ScaledGroupedMMTensor arguments" + ) # detach is special case if func == torch.ops.aten.detach.default: - 
return ScaledGroupedMMTensor(args[0], scaling_type) + return ScaledGroupedMMTensor(args_unwrapped[0], scaling_type) # perform op - out = func(*args, **kwargs) + out = func(*args_unwrapped, **kwargs_unwrapped) # return regular tensors for ops that don't preserve subclass if func not in _ops_to_preserve_subclass: