diff --git a/test/test_low_bit_optim.py b/test/test_low_bit_optim.py
index b0edfc7fc5..b6a0f54662 100644
--- a/test/test_low_bit_optim.py
+++ b/test/test_low_bit_optim.py
@@ -119,6 +119,50 @@ def test_bf16_stochastic_round(self, device, compile):
         # must cast BF16 tensor back to FP32 so that .mean() is accurate
         torch.testing.assert_close(x_rep_bf16.float().mean(1), x, atol=3e-5, rtol=3e-5)
 
+    @parametrize("device", _DEVICES)
+    @parametrize("compile", [False, True])
+    def test_bf16_stochastic_round_dtensor(self, device, compile):
+        pytest.importorskip("torch.distributed")
+        import torch.distributed as dist
+        from torch.distributed.device_mesh import init_device_mesh
+        from torch.distributed.tensor import DTensor, Replicate
+
+        if not dist.is_available():
+            pytest.skip("torch.distributed is not available")
+
+        created_pg = False
+        if not dist.is_initialized():
+            store = dist.TCPStore("127.0.0.1", 29500, 1, True)
+            dist.init_process_group(
+                backend="gloo",
+                store=store,
+                rank=0,
+                world_size=1,
+            )
+            created_pg = True
+
+        try:
+            torch.manual_seed(common_utils.SEED)
+            x = torch.rand(32, device=device) * 100
+            x_rep = x.view(-1, 1).repeat(1, 100_000)
+
+            func = torch.compile(
+                _fp32_to_bf16_sr, fullgraph=True, dynamic=False, disable=not compile
+            )
+            out_plain = func(x_rep)
+
+            mesh = init_device_mesh(device, (1,))
+            x_dt = DTensor.from_local(x_rep, mesh, [Replicate()], run_check=False)
+            # reset the RNG so the DTensor path draws the same random bits
+            torch.manual_seed(common_utils.SEED)
+            out_dt = func(x_dt)
+
+            assert isinstance(out_dt, DTensor)
+            torch.testing.assert_close(out_dt.to_local(), out_plain)
+        finally:
+            if created_pg:
+                dist.destroy_process_group()
+
 
 class TestOptim(TestCase):
     @parametrize(
diff --git a/torchao/optim/quant_utils.py b/torchao/optim/quant_utils.py
index a4035fde1c..bef297545d 100644
--- a/torchao/optim/quant_utils.py
+++ b/torchao/optim/quant_utils.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import torch
 from torch import Tensor
+from torch.distributed.tensor import DTensor
 
 
 # https://github.com/TimDettmers/bitsandbytes/blob/dada530149212d64d4b69534716202659ef37ec8/bitsandbytes/functional.py#L339-L391
@@ -117,7 +118,7 @@ def dequant_with_qmap(codes: Tensor, qmap: Tensor, scale: Tensor):
     return out.view(codes.shape)
 
 
-def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
+def _fp32_to_bf16_sr(_x_f32: Tensor) -> Tensor:
     # For an FP32 number [a31, ..., a16, a15, ..., a0] to be converted to BF16
     # - Round towards zero: [a31, ..., a16, 0, ..., 0]
     # - Round away from zero: [a31, ..., a16+1, 0, ..., 0]
@@ -127,6 +128,10 @@ def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
     # [a15, ..., a0] / 2^16, where the bit pattern [a15, ..., a0] is interpreted as uint16
     #
     # we have to use int32 since most arithmetic ops are not implemented for uint32/int16/uint16
+    # unwrap DTensor and operate on the local shard; re-wrap before returning
+    is_dt = isinstance(_x_f32, DTensor)
+    x_f32 = _x_f32.to_local() if is_dt else _x_f32
+
     rand_16bit = torch.randint(
         0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32
     )
@@ -142,4 +147,17 @@ def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
     )  # alternative, slightly faster
     # x_f32_bits = (x_f32_bits + rand_16bit) & 0xFFFF0000
 
-    return x_f32_bits.view(torch.float32).bfloat16()
+    x_bf16 = x_f32_bits.view(torch.float32).bfloat16()
+
+    if is_dt:
+        # re-wrap with the input's mesh and placements; passing the global
+        # shape/stride explicitly avoids recomputing the distributed metadata
+        return DTensor.from_local(
+            x_bf16,
+            _x_f32.device_mesh,
+            _x_f32.placements,
+            run_check=False,
+            shape=_x_f32.shape,
+            stride=_x_f32.stride(),
+        )
+    return x_bf16
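
For intuition, here is a minimal sketch (not part of the diff) of what the stochastic rounding guarantees: for an FP32 value strictly between two adjacent BF16 numbers, every sample lands on one of the two neighbors, and the mean over many samples recovers the FP32 input in expectation. It assumes torchao is installed and that `_fp32_to_bf16_sr` is still importable from the path shown in the diff (it is a private helper, so the path may move).

```python
import torch

from torchao.optim.quant_utils import _fp32_to_bf16_sr

# 1.0 + 2**-10 sits between the adjacent BF16 values 1.0 and 1.0078125
# (BF16 spacing in [1, 2) is 2**-7), so 1/8 of samples should round up.
torch.manual_seed(0)
x = torch.full((100_000,), 1.0 + 2**-10)

out = _fp32_to_bf16_sr(x)
print(out.float().unique())       # tensor([1.0000, 1.0078])
print(out.float().mean().item())  # ~1.0009765625, i.e. the FP32 input
```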
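The commented-out "alternative, slightly faster" line in the diff is the essence of the trick: add a random 16-bit integer to the FP32 bit pattern and zero the low 16 bits, so the carry into bit 16 happens with probability equal to the discarded fraction. Below is a pure-Python reference of that variant for a single scalar, purely for illustration; the helper name `bf16_sr_scalar` is made up and is not torchao API.

```python
import struct
from random import randrange

def bf16_sr_scalar(x: float, rand16: int) -> float:
    # reinterpret the float as its 32-bit pattern
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # adding a random 16-bit value carries into bit 16 with probability
    # (low 16 bits) / 2**16; masking then truncates to BF16 precision
    bits = (bits + rand16) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

samples = [bf16_sr_scalar(1.0 + 2**-10, randrange(1 << 16)) for _ in range(100_000)]
print(sum(samples) / len(samples))  # ~1.0009765625
```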
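The DTensor handling in the diff follows a general unwrap/compute/re-wrap pattern: run the op on the local shard, then re-attach the input's mesh and placements, passing the global shape and stride so `from_local` does not have to infer them. This is only sound because stochastic rounding is elementwise and shape-preserving, so each shard can be converted independently (each rank drawing its own random bits is fine). A generic sketch of the pattern; the helper name `_apply_elementwise_local` is hypothetical, not torchao API.

```python
import torch
from torch.distributed.tensor import DTensor

def _apply_elementwise_local(fn, t: torch.Tensor) -> torch.Tensor:
    # hypothetical helper: only correct when `fn` is elementwise and
    # shape-preserving, so shards can be transformed independently
    if isinstance(t, DTensor):
        return DTensor.from_local(
            fn(t.to_local()),
            t.device_mesh,
            t.placements,
            run_check=False,   # skip cross-rank metadata checks
            shape=t.shape,     # global shape is unchanged by an elementwise op
            stride=t.stride(), # likewise the global stride
        )
    return fn(t)
```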