77 changes: 75 additions & 2 deletions backends/vulkan/custom_ops_lib.py
@@ -187,9 +187,9 @@ def linear_weight_int4_impl(
lib.impl(name, linear_weight_int4_impl, "CompositeExplicitAutograd")
linear_weight_int4_op = getattr(getattr(torch.ops, namespace), name)

#################
##################
## linear_qcs4w ##
#################
##################


def linear_qcs4w(
@@ -234,6 +234,79 @@ def linear_qcs4w(
lib.impl(name, linear_qcs4w, "CompositeExplicitAutograd")
linear_qc4w_op = getattr(getattr(torch.ops, namespace), name)

##################
## linear_q4gsw ##
##################


def unpack_4bit_weight_tensor(
packed_weight_tensor: torch.Tensor, x: torch.Tensor
) -> torch.Tensor:
"""
Reverses the packing performed in quantized_linear.pack_4bit_weight_tensor
"""
# Each packed byte contains two 4-bit values: high nibble and low nibble
K, N_half = packed_weight_tensor.shape
N = N_half * 2

# Unpack high and low nibbles
high_nibble = (packed_weight_tensor >> 4) & 0x0F
low_nibble = packed_weight_tensor & 0x0F

# Stack to shape (K, N)
unpacked = torch.empty(
(K, N), dtype=torch.uint8, device=packed_weight_tensor.device
)
unpacked[:, ::2] = low_nibble
unpacked[:, 1::2] = high_nibble

# Undo the +8 offset and convert to signed 4-bit range [-8, 7]
unpacked = unpacked.to(torch.int8) - 8

in_channels = x.shape[-1]
# Undo any padding that may have been added to input channels
if in_channels != unpacked.shape[-1]:
return unpacked[:, :in_channels]

return unpacked
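For reference, here is a tiny numeric check of the nibble convention assumed by the unpack above (low nibble maps to even output columns, high nibble to odd columns, both stored with a +8 offset). This is an illustrative sketch, not part of the change:

import torch

q_even, q_odd = -3, 5                               # signed 4-bit values in [-8, 7]
packed = ((q_odd + 8) << 4) | (q_even + 8)          # one packed byte, 0xD5 here
packed_t = torch.tensor([[packed]], dtype=torch.uint8)

low = (packed_t & 0x0F).to(torch.int8) - 8          # -> -3, lands in an even column
high = ((packed_t >> 4) & 0x0F).to(torch.int8) - 8  # -> 5, lands in an odd column
assert low.item() == q_even and high.item() == q_odd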


def linear_q4gsw(
x: torch.Tensor,
weights: torch.Tensor,
weight_scales: torch.Tensor,
group_size: int,
bias: Optional[torch.Tensor] = None,
):
# Unpack the packed weights
weights = unpack_4bit_weight_tensor(weights, x)

# Un-transpose the weight scales
weight_scales = weight_scales.transpose(0, 1)
weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)

weights = torch.ops.torchao.dequantize_affine(
weights, [1, group_size], weight_scales, weight_zeros, torch.int8, -8, 7
)

out = torch.nn.functional.linear(x, weights)
return out
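To make the group-wise semantics concrete: with block_size [1, group_size] and all-zero zero points, the dequantize_affine call above should amount to scaling each run of group_size consecutive input channels by its own scale. A plain-PyTorch sketch of that behaviour (an assumption about torchao's op, for illustration only, with the scales already un-transposed to shape (N, K // group_size)):

import torch

def dequantize_groupwise_symmetric(
    q: torch.Tensor, scales: torch.Tensor, group_size: int
) -> torch.Tensor:
    # q: (N, K) int8 values in [-8, 7]; scales: (N, K // group_size)
    N, K = q.shape
    q = q.to(torch.float32).reshape(N, K // group_size, group_size)
    return (q * scales.unsqueeze(-1)).reshape(N, K)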


name = "linear_q4gsw"
lib.define(
f"""
{name}(
Tensor self,
Tensor weights,
Tensor weight_scales,
int group_size,
Tensor? bias = None) -> Tensor
"""
)
lib.impl(name, linear_q4gsw, "CompositeExplicitAutograd")
linear_q4gsw_op = getattr(getattr(torch.ops, namespace), name)
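A hypothetical eager-mode call of the op registered above, assuming the library namespace resolves to et_vk (consistent with the exir_ops.edge.et_vk references later in this PR) and that torchao is available for the dequantize step; the shapes are made up for illustration:

import torch

x = torch.randn(2, 128)                                        # (batch, K)
packed_w = torch.randint(0, 256, (64, 64), dtype=torch.uint8)  # (N, K / 2), two nibbles per byte
scales = torch.randn(4, 64)                                     # (num_groups, N), transposed layout

out = torch.ops.et_vk.linear_q4gsw(x, packed_w, scales, 32)     # group_size = 32 -> out shape (2, 64)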

########################
## linear_qta8a_qga4w ##
########################
3 changes: 2 additions & 1 deletion backends/vulkan/op_registry.py
@@ -334,9 +334,10 @@ def register_int8_mm_op():
@update_features(
[
exir_ops.edge.et_vk.linear_q8ta_q8csw.default,
exir_ops.edge.et_vk.linear_q4gsw.default,
]
)
def register_qa_qw_linear():
def register_quantized_linear_ops():
return OpFeatures(
inputs_storage=utils.CONTIGUOUS_ANY,
supports_prepacking=True,
125 changes: 49 additions & 76 deletions backends/vulkan/patterns/quantized_linear.py
@@ -191,116 +191,94 @@ def find_quantized_linear_patterns(
##


def pack_4bit_weight_tensor(inp: torch.Tensor) -> torch.Tensor:
def pack_4bit_weight_tensor(weight_tensor: torch.Tensor) -> torch.Tensor:
"""
Given an 8-bit weight tensor containing values quantized to 4 bits, create a packed
weight tensor by packing 2 4-bit values in one unsigned 8-bit value.
weight tensor by transposing the weight tensor, then packing 2 4-bit values in one
8-bit value.

An input weight tensor of shape (M, K) will produce a packed weight tensor of shape
(M, K / 2).

The packing implemented here is the same as the packing produced by
backends/vulkan/_passes/int4_weight_only_quantizer.py
An input weight tensor of shape (N, K) will produce a packed weight tensor of shape
(K, N / 2).
"""

# Assert we got a properly quantized tensor.
min, max = inp.min().item(), inp.max().item()
min_val, max_val = weight_tensor.min().item(), weight_tensor.max().item()
assert (
max <= 7 and min >= -8
), f"pack_4bit_weight_tensor: [min,max] out of [-8, 7] range, got [{min}, {max}]"
max_val <= 7 and min_val >= -8
), f"pack_4bit_weight_tensor: [min_val,max_val] out of [-8, 7] range, got [{min_val}, {max_val}]"

# Assuming we have a 2d tensor
if inp.ndim != 2:
inp = inp.squeeze()
if weight_tensor.ndim != 2:
weight_tensor = weight_tensor.squeeze()
assert (
inp.ndim == 2
), f"pack_4bit_weight_tensor: expecting input tensor to be 2d, got {inp.ndim}"
weight_tensor.ndim == 2
), f"pack_4bit_weight_tensor: expecting input tensor to be 2d, got {weight_tensor.ndim}"

# pad ic
if inp.shape[-1] % 2 != 0:
inp = F.pad(input=inp, pad=(0, 1, 0, 0), mode="constant", value=0)
# Need to pad innermost dim to be a multiple of 8, since the minimum load granularity
# is int32 (4 bytes), which contains 8 4-bit values.
if weight_tensor.shape[-1] % 8 != 0:
num_pad = 8 - (weight_tensor.shape[-1] % 8)
weight_tensor = F.pad(input=weight_tensor, pad=(0, num_pad))

# Shape after padding
oc, ic = inp.shape
assert ic % 2 == 0, "convert_to_qc4w: expecting ic to be even"
_, in_channels = weight_tensor.shape
assert in_channels % 8 == 0, "convert_to_qc4w: expecting ic to be divisible by 8"

# Adjust inp tensor for zp
inp = inp.to(dtype=torch.uint8) + 8
# Adjust weight_tensor tensor for zp
weight_tensor = weight_tensor.to(dtype=torch.uint8) + 8
# Pack each 4-bit value into a single 8-bit value
return inp[::, ::2] << 4 | inp[::, 1::2]


def make_combined_scales_and_zeros_tensor(
scales: torch.Tensor, zeros: torch.Tensor
) -> torch.Tensor:
"""
Given a scales and zeros tensor, create a combined tensor by stacking them into a
single tensor.

The scales and zeros tensors are expected to be 2D tensors of shape
(OUTPUT_CHANNELS, NUM_GROUPS). The combined tensor will have the shape
(NUM_GROUPS, OUTPUT_CHANNELS, 2).

This is the scales and zeros format produced by
backends/vulkan/_passes/int4_weight_only_quantizer.py, which in turn is the scales
and zeros format expected by the _weight_int4pack_mm op in ATen.
"""
scales_reshaped = scales.transpose(0, 1).unsqueeze(2)
zeros_reshaped = zeros.transpose(0, 1).unsqueeze(2)

zeros_scaled = zeros_reshaped * scales_reshaped * -1
return torch.cat((scales_reshaped, zeros_scaled), dim=2)
return weight_tensor[::, 1::2] << 4 | weight_tensor[::, ::2]
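This packing stores the even column in the low nibble and the odd column in the high nibble (the reverse of the old layout), which is exactly what unpack_4bit_weight_tensor added in custom_ops_lib.py undoes. A quick round-trip sanity check, illustrative only and assuming the in-channel dimension is already a multiple of 8 so no padding is applied:

import torch

w = torch.randint(-8, 8, (4, 16), dtype=torch.int8)    # (N, K) signed 4-bit values
shifted = (w + 8).to(torch.uint8)
packed = shifted[:, 1::2] << 4 | shifted[:, ::2]        # (N, K / 2)

unpacked = torch.empty_like(w)
unpacked[:, ::2] = (packed & 0x0F).to(torch.int8) - 8
unpacked[:, 1::2] = ((packed >> 4) & 0x0F).to(torch.int8) - 8
assert torch.equal(unpacked, w)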


##
## Pattern Replacement
##


def make_linear_q4ga_op(
def make_linear_q4gsw_op(
ep: ExportedProgram,
graph_module: torch.fx.GraphModule,
match: QuantizedLinearMatch,
weight_tensor: torch.Tensor,
weight_scales_tensor: torch.Tensor,
weight_zeros_tensor: torch.Tensor,
):
packed_quantized_weight_tensor = pack_4bit_weight_tensor(weight_tensor)
utils.update_program_state_dict(
ep, match.weight_node.name, packed_quantized_weight_tensor
)
# Need to make sure corresponding FakeTensor has same size
match.weight_node.meta["val"] = match.weight_node.meta["val"][:, ::2].to(
torch.uint8
)

group_size = weight_tensor.shape[1] // weight_scales_tensor.shape[1]

combined_scales_zeros_tensor = make_combined_scales_and_zeros_tensor(
weight_scales_tensor, weight_zeros_tensor
num_groups = weight_scales_tensor.shape[-1]
in_channels = weight_tensor.shape[-1]
group_size = in_channels // num_groups

weight_tensor = pack_4bit_weight_tensor(weight_tensor)
# Use this function for convenience to update the state dict with the packed
# weight tensor. Alignment will already have been done in the above function.
weight_tensor = utils.align_width_and_update_state_dict(
ep, match.weight_node, weight_tensor, align_to=1, force_update=True
)

combined_scales_zeros_name = f"{match.weight_node.name}_scales_zeros"
graph_module.register_parameter(
combined_scales_zeros_name, torch.nn.Parameter(combined_scales_zeros_tensor)
# Also transpose the weight scales tensor to shape [num_groups, N]
weight_scales_tensor = weight_scales_tensor.transpose(0, 1).contiguous()
# Align to multiple of 8 to ensure that data loads from the weight scales
# tensor do not go out of bounds. Each thread computes 8 output channels.
utils.align_width_and_update_state_dict(
ep,
match.weight_scales_node,
weight_scales_tensor,
align_to=8,
force_update=True,
)

with graph_module.graph.inserting_before(match.output_node):
combined_scales_zeros = graph_module.graph.get_attr(combined_scales_zeros_name)
linear_q4ga_node = graph_module.graph.create_node(
linear_q4gsw_node = graph_module.graph.create_node(
"call_function",
exir_ops.edge.et_vk.linear_weight_int4.default,
exir_ops.edge.et_vk.linear_q4gsw.default,
args=(
match.fp_input_node,
match.weight_node,
match.weight_scales_node,
group_size,
combined_scales_zeros,
1,
),
)

linear_q4ga_node.meta["val"] = match.output_node.meta["val"]
match.output_node.replace_all_uses_with(linear_q4ga_node)
linear_q4gsw_node.meta["val"] = match.output_node.meta["val"]
match.output_node.replace_all_uses_with(linear_q4gsw_node)
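As a made-up numeric illustration of the bookkeeping above (none of these shapes come from the PR):

# Illustrative shapes only; not code from the PR.
in_channels = 4096                       # K
weight_scales_shape = (4096, 32)         # (N, num_groups) before the transpose
num_groups = weight_scales_shape[-1]     # 32
group_size = in_channels // num_groups   # 128
# After the transpose the stored scales are (num_groups, N) = (32, 4096),
# then padded along N to a multiple of 8 for the shader's load granularity.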


def make_linear_q8ta_q8csw_custom_op(
@@ -373,13 +351,8 @@ def replace_quantized_linear_patterns(
and match.is_weight_pergroup_quantized()
and utils.is_in_4bit_range(weight_tensor)
):
make_linear_q4ga_op(
ep,
graph_module,
match,
weight_tensor,
weight_scales_tensor,
weight_zeros_tensor,
make_linear_q4gsw_op(
ep, graph_module, match, weight_tensor, weight_scales_tensor
)
elif (
match.is_input_static_per_tensor_quantized()
15 changes: 15 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh
@@ -29,4 +29,19 @@ int extract_8bit_from_packed_int_le(const int packed, const int i) {
return byte;
}

// Extract a 4-bit value from a packed int (little endian)
// It is assumed that the 4-bit value is in the range [0, 15]
int extract_4bit_from_packed_int_le(const int packed, const int col) {
// Shift the target nibble down and mask off the low 4 bits
int val = packed >> (4 * col) & 0xF;
return val;
}

// Convenience overload for packed uint
int extract_4bit_from_packed_uint_le(const uint packed, const int col) {
// Shift the target nibble down and mask off the low 4 bits
int val = int(packed >> (4 * col)) & 0xF;
return val;
}

#endif // LINEAR_COMMON_GLSLH
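For readers less familiar with GLSL bit twiddling, a Python mirror of the 4-bit extraction helpers added above: the packed 32-bit word holds eight nibbles and col selects one, counting from the least significant end (illustrative sketch, not part of the shader):

def extract_4bit_from_packed_int_le(packed: int, col: int) -> int:
    # Shift the target nibble down to the low 4 bits, then mask it off
    return (packed >> (4 * col)) & 0xF

# The packed byte 0xD5 holds 5 at col 0 and 13 at col 1
assert extract_4bit_from_packed_int_le(0xD5, 0) == 5
assert extract_4bit_from_packed_int_le(0xD5, 1) == 13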
@@ -39,6 +39,21 @@ void initialize(out FPOutTile out_tile) {
#endif
}

void add(inout FPOutTile out_tile, const FPOutTile other_out_tile) {
#if TILE_M > 1 && TILE_N4 == 1
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
out_tile.data[m][0] += other_out_tile.data[m][0];
}

#else
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
out_tile.data[m][n4] += other_out_tile.data[m][n4];
}
}
#endif
}

#ifdef DEBUG_MODE

void printFPOutTile(const FPOutTile tile) {
@@ -123,4 +123,14 @@ void apply_scales_and_biases(
#endif
}

void accumulate_out_tile_with_out_tile(
inout FPOutTile accum,
const FPOutTile other) {
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
accum.data[m][n4] = accum.data[m][n4] + other.data[m][n4];
}
}
}

#endif // LINEAR_FP_OUTPUT_TILE_FP_COMPUTE_GLSLH
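The add and accumulate_out_tile_with_out_tile helpers added in these two shader headers perform the same elementwise accumulation of output tiles. In Python terms, treating an FPOutTile as a TILE_M x TILE_N4 grid of 4-wide vectors, the operation is simply (tile sizes below are assumptions for illustration):

import torch

TILE_M, TILE_N4 = 4, 1
accum = torch.zeros(TILE_M, TILE_N4, 4)
other = torch.randn(TILE_M, TILE_N4, 4)
accum += other  # elementwise add, matching the unrolled loops in the shaders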