69 changes: 54 additions & 15 deletions fms_mo/quant/quantizers.py
@@ -123,23 +123,28 @@ def get_activation_quantizer(
)
elif qa_mode == "dorefa":
act_quantizer = dorefa_quantize_activation
elif (
qa_mode == "max"
): # NOTE Need to be careful using this for activation, particular to 1 sided.
act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
elif qa_mode == "minmax":
act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)

elif "max" in qa_mode:
# NOTE Need to be careful using this for activation, particularly for 1-sided.
if "min" in qa_mode:
act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
elif "pertoken" in qa_mode or "perToken" in qa_mode:
act_quantizer = QMaxDynamic(nbits, dim=-1)
elif "per_channel" in qa_mode or "perCh" in qa_mode:
act_quantizer = QMaxDynamic(nbits, dim=-2)
elif "sym" in qa_mode:
act_quantizer = Qmax(
nbits,
align_zero=True,
minmax=False,
extend_act_range=extend_act_range,
)
else:
act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
elif qa_mode == "fix":
act_quantizer = QFixSymmetric(
nbits, init_clip_val=clip_val, align_zero=align_zero
)
elif qa_mode == "maxsym":
act_quantizer = Qmax(
nbits,
align_zero=True,
minmax=False,
extend_act_range=extend_act_range,
)
elif qa_mode == "pactsym":
act_quantizer = PACT2Sym(
nbits,
@@ -179,8 +184,6 @@ def get_activation_quantizer(
perToken=perToken,
emulate=True,
)
elif qa_mode == "pertokenmax":
act_quantizer = PerTokenMax(nbits)
else:
raise ValueError(f"unrecognized activation quantization mode {qa_mode}")
else: # swcap-compatible activation quantizers
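
As a side note on the reduce-dim convention used by the QMaxDynamic branches above (dim=-1 for per-token scales, dim=-2 for per-channel scales), the following standalone sketch with plain torch shows the shapes involved; it is illustration only, not part of this PR, and the tensor shape and variable names are arbitrary.

import torch

x = torch.randn(4, 8)                               # (tokens, channels)
per_token = x.abs().max(dim=-1, keepdim=True)[0]    # shape (4, 1): one scale source per token
per_channel = x.abs().max(dim=-2, keepdim=True)[0]  # shape (1, 8): one scale source per channel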
@@ -3488,6 +3491,42 @@ def __repr__(self):
return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"


class QMaxDynamic(nn.Module):
    def __init__(self, num_bits, dim=-1):
        """
        Per-token or per-channel quantization using abs().max() as the scale, usually for
        activations; could also be used for Qbmm M2.
        (reduce) dim = -1 -> abs().max() outputs a column vector (if the input is 2D) => per-token
                 dim = -2 -> per-channel
        Zero is aligned so that the levels are symmetric around zero (losing one level).
        Since the token length is unknown before running, the quantizer can only calculate the
        scales dynamically at runtime, meaning no trainable quantization scale is allowed
        (unless the input seq length is always the same, not just padded to a fixed length).
        """
        super().__init__()
        self.num_bits = num_bits
        self.levels = 2 ** (self.num_bits - 1) - 1
        # Accept string aliases and normalize them to a reduce dimension.
        if isinstance(dim, str):
            if "perCh" in dim or "per_channel" in dim:
                dim = -2
            elif "perToken" in dim or "per_token" in dim or "per_Token" in dim:
                dim = -1
        if dim in [-1, -2]:
            self.reduce_dim = dim
        else:
            raise ValueError(
                f"Reduce dim can only be [-1, -2] or ['perCh', 'perToken'] but found {dim}"
            )

    def forward(self, input_tensor):
        # One scale per token (dim=-1) or per channel (dim=-2), computed from abs().max(),
        # followed by symmetric round-to-nearest fake quantization.
        amax_dim = input_tensor.abs().max(dim=self.reduce_dim, keepdim=True)[0]
        scales = amax_dim.clamp(min=1e-5).div(self.levels)
        return input_tensor.div(scales).round().mul(scales)

    def __repr__(self):
        return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"
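
A minimal usage sketch of the class above (illustration only, not part of this PR; the tensor shape and variable names are arbitrary): instantiate with a bit-width and reduce dim, then call it like any nn.Module to get a fake-quantized tensor with one dynamically computed scale per token.

import torch

x = torch.randn(4, 16)                       # (tokens, channels)
quantizer = QMaxDynamic(num_bits=8, dim=-1)  # one scale per token, 2**7 - 1 = 127 levels per side
xq = quantizer(x)                            # same shape as x, values snapped to each token's grid
assert xq.shape == x.shape
print((xq - x).abs().max())                  # worst-case rounding error, about half a scale step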


class Qdynamic(nn.Module):
def __init__(
self,