19 changes: 17 additions & 2 deletions torchao/optim/quant_utils.py
@@ -5,6 +5,13 @@
# LICENSE file in the root directory of this source tree.
import torch
from torch import Tensor
try:
    # newer PyTorch exposes DTensor at the public location
    from torch.distributed.tensor import DTensor
except Exception:
    try:
        # fall back to the older, private location
        from torch.distributed._tensor import DTensor
    except Exception:
        # DTensor unavailable: an empty tuple makes isinstance(x, DTensor) always False
        DTensor = tuple()


# https://github.com/TimDettmers/bitsandbytes/blob/dada530149212d64d4b69534716202659ef37ec8/bitsandbytes/functional.py#L339-L391
@@ -117,7 +124,7 @@ def dequant_with_qmap(codes: Tensor, qmap: Tensor, scale: Tensor):
    return out.view(codes.shape)


def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
def _fp32_to_bf16_sr(_x_f32: Tensor) -> Tensor:
    # For an FP32 number [a31, ..., a16, a15, ..., a0] to be converted to BF16
    # - Round towards zero:   [a31, ..., a16,   0, ..., 0]
    # - Round away from zero: [a31, ..., a16+1, 0, ..., 0]
@@ -127,6 +134,9 @@ def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
    # [a15, ..., a0] / 2^16, where the bit pattern [a15, ..., a0] is interpreted as uint16
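    # e.g. if the lower 16 bits are 0x8000 (exactly halfway between the two
    # adjacent BF16 values), the value rounds away from zero with probability
    # 0x8000 / 2^16 = 0.5, so the rounding is unbiased in expectation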
    #
    # we have to use int32 since most arithmetic ops are not implemented for uint32/int16/uint16
    # a DTensor input is unwrapped to its local shard; the result is re-wrapped before returning
    is_dt = isinstance(_x_f32, DTensor)
    x_f32 = _x_f32.to_local() if is_dt else _x_f32

    rand_16bit = torch.randint(
        0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32
    )
@@ -142,4 +152,9 @@ def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
    )
    # alternative, slightly faster
    # x_f32_bits = (x_f32_bits + rand_16bit) & 0xFFFF0000
    return x_f32_bits.view(torch.float32).bfloat16()
    x_bf16_trunc = x_f32_bits.view(torch.float32).bfloat16()

    if not is_dt:
        return x_bf16_trunc
    # re-wrap the local result into a DTensor with the caller's mesh and placements
    return DTensor.from_local(
        x_bf16_trunc, _x_f32.device_mesh, _x_f32.placements,
        run_check=False, shape=tuple(_x_f32.shape), stride=tuple(_x_f32.stride()),
    )
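
A minimal usage sketch (not part of the PR) of the DTensor round-trip this change enables. The single-process gloo group and 1-device CPU mesh are assumptions chosen purely for demonstration; the check is that the stochastic-rounding helper now accepts a DTensor and returns a bfloat16 DTensor with the input's placements.

import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate

from torchao.optim.quant_utils import _fp32_to_bf16_sr

# single-process process group, just to be able to build a device mesh
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
x = DTensor.from_local(torch.randn(128, 64), mesh, [Replicate()])

y = _fp32_to_bf16_sr(x)
assert isinstance(y, DTensor) and y.dtype == torch.bfloat16
assert y.placements == x.placements  # layout of the input is preserved

dist.destroy_process_group()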