From 4d9abba943e29974b14dd4907e3152cf4c254684 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 12 Nov 2024 20:28:07 +0000 Subject: [PATCH 1/6] Snap Co-authored-by: Jambay Kinley --- olive/passes/onnx/mnb_to_qdq.py | 96 +++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/olive/passes/onnx/mnb_to_qdq.py b/olive/passes/onnx/mnb_to_qdq.py index 03cf8aa1f..fc18801f6 100644 --- a/olive/passes/onnx/mnb_to_qdq.py +++ b/olive/passes/onnx/mnb_to_qdq.py @@ -9,6 +9,8 @@ import numpy as np import onnx +from onnxscript import ir +from onnxscript.rewriter import pattern as orp from olive.hardware.accelerator import AcceleratorSpec from olive.model import ONNXModelHandler @@ -62,6 +64,100 @@ def _run_for_config( ) -> ONNXModelHandler: output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name) + # 2 Step + # 1. pattern replacement + # 2. Repacking + ir_model = ir.load(model.model_path) + + def mat_mul_n_bits_pattern(op, input_A, qweight, qscales, qzeros, g_idx, bias): + return op.MatMulNBits(input_A, qweight, qscales, qzeros, g_idx, bias) + + def _is_initializer(context, value: ir.Value) -> bool: + graph: ir.Graph = context.graph + return value in graph.initializers.values() + + def mat_mul_n_bits_pattern_check(context, input_A, qweight, qscales, qzeros, g_idx, bias) -> bool: + if not _is_initializer(context, qweight): + return False + node: ir.Node = _get_node(input_A) + block_size = node.attributes["block_size"].value + K = node.attributes["K"].value + g_idx = g_idx.constant_value.numpy() + trivial_g_idx = np.arange(K, dtype=np.int32) // block_size + if not np.array_equal(g_idx, trivial_g_idx): + # Log + return False + return True + + def mat_mul_n_bits_replacement(op, input_A, qweight, qscales, qzeros, g_idx, bias): + node: ir.Node = _get_node(input_A) + # TODO(justinchuby): Keep the old name of the node + K: int = node.attributes["K"].value + block_size: int = node.attributes["block_size"].value + num_k_blocks = math.ceil(K / block_size) + # will make this a per-axis DQ if num_k_blocks == 1 + # - originally per-axis K == block_size + # - originally blockwise but K <= block_size + is_per_axis = num_k_blocks == 1 + + # dequantizelinear -> transpose -> matmul -> add (optional) + dq = op.DequantizeLinear( + qweight, + qscales, + qzeros, + block_size=None if is_per_axis else block_size, + # for some reason block_wise and per-axis appear to use swapped axis + # flip the axis if it is per-axis + axis=(1 if config["use_transpose_op"] else 0) ^ (1 if is_per_axis else 0), + ) + # TODO(justinchuby): Improve the way we mark something that needs repacking + dq.producer().meta["needs_repacking"] = True + dq.producer().meta["K"] = K + dq.producer().meta["N"] = node.attributes["N"].value + if config["use_transpose_node"]: + dq = op.Transpose(dq, perm=[1, 0]) + matmul = op.MatMul(input_A, dq) + if bias is not None: + matmul = op.Add(matmul, bias) + return matmul + + replace_matmul_n_bits = orp.RewriteRule( + mat_mul_n_bits_pattern, + mat_mul_n_bits_pattern_check, + mat_mul_n_bits_replacement, + ) + + # Call the rewriter with replace_matmul_n_bits + + # 2. Repacking + for node in ir_model.graph: + if "needs_repacking" not in node.meta: + continue + + # Add Logic handling input 3 + + unpacked_weight_arrays = _unpack_weights( + node.meta["K"], + node.meta["N"], + node.inputs[1].const_value.numpy(), + node.inputs[2].const_value.numpy(), + node.inputs[3].const_value.numpy(), + ) + array = unpacked_weight_arrays[0].view(ml_dtypes.int4) + node.inputs[1].const_value = ir.Tensor(array) + node.inputs[2].const_value = ir.Tensor(array) + input_3 = ir.Value(None) + input_3.const_value = ir.Tensor(array) + node.replace_input_with(3, input_3) + ir_model.graph.initializers[input_3.name] = input_3 + + # TODO(justinchuby): Register and remove initializers + + ir_model.opset_imports[""] = max(21, ir_model.opset_imports[""]) + + # save the model to the output path and return the model + return ir_model_to_olive_model(ir_model, output_model_path, config) + # create a dag from the model dag = OnnxDAG.from_model_path(model.model_path) # remove unnecessary identity nodes From 2215c97721e6c2e4087036a7892d78e8ae9d1338 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 14 Nov 2024 00:07:48 +0000 Subject: [PATCH 2/6] update --- olive/passes/onnx/mnb_to_qdq.py | 304 +++----------------------------- 1 file changed, 20 insertions(+), 284 deletions(-) diff --git a/olive/passes/onnx/mnb_to_qdq.py b/olive/passes/onnx/mnb_to_qdq.py index fc18801f6..e98b0b44c 100644 --- a/olive/passes/onnx/mnb_to_qdq.py +++ b/olive/passes/onnx/mnb_to_qdq.py @@ -7,8 +7,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Dict +import ml_dtypes import numpy as np -import onnx from onnxscript import ir from onnxscript.rewriter import pattern as orp @@ -67,7 +67,7 @@ def _run_for_config( # 2 Step # 1. pattern replacement # 2. Repacking - ir_model = ir.load(model.model_path) + ir_model = ir.serde.deserialize_model(model.load_model()) def mat_mul_n_bits_pattern(op, input_A, qweight, qscales, qzeros, g_idx, bias): return op.MatMulNBits(input_A, qweight, qscales, qzeros, g_idx, bias) @@ -92,9 +92,9 @@ def mat_mul_n_bits_pattern_check(context, input_A, qweight, qscales, qzeros, g_i def mat_mul_n_bits_replacement(op, input_A, qweight, qscales, qzeros, g_idx, bias): node: ir.Node = _get_node(input_A) # TODO(justinchuby): Keep the old name of the node - K: int = node.attributes["K"].value - block_size: int = node.attributes["block_size"].value - num_k_blocks = math.ceil(K / block_size) + k: int = node.attributes["K"].as_int() + block_size: int = node.attributes["block_size"].as_int() + num_k_blocks = math.ceil(k / block_size) # will make this a per-axis DQ if num_k_blocks == 1 # - originally per-axis K == block_size # - originally blockwise but K <= block_size @@ -108,12 +108,12 @@ def mat_mul_n_bits_replacement(op, input_A, qweight, qscales, qzeros, g_idx, bia block_size=None if is_per_axis else block_size, # for some reason block_wise and per-axis appear to use swapped axis # flip the axis if it is per-axis - axis=(1 if config["use_transpose_op"] else 0) ^ (1 if is_per_axis else 0), + axis=config["use_transpose_op"] or is_per_axis, ) # TODO(justinchuby): Improve the way we mark something that needs repacking dq.producer().meta["needs_repacking"] = True - dq.producer().meta["K"] = K - dq.producer().meta["N"] = node.attributes["N"].value + dq.producer().meta["K"] = k + dq.producer().meta["N"] = node.attributes["N"].as_int() if config["use_transpose_node"]: dq = op.Transpose(dq, perm=[1, 0]) matmul = op.MatMul(input_A, dq) @@ -127,9 +127,8 @@ def mat_mul_n_bits_replacement(op, input_A, qweight, qscales, qzeros, g_idx, bia mat_mul_n_bits_replacement, ) - # Call the rewriter with replace_matmul_n_bits - - # 2. Repacking + # TODO(justinchuby): Call the rewriter with replace_matmul_n_bits + # 2. Repack the quantized weights for node in ir_model.graph: if "needs_repacking" not in node.meta: continue @@ -143,272 +142,20 @@ def mat_mul_n_bits_replacement(op, input_A, qweight, qscales, qzeros, g_idx, bia node.inputs[2].const_value.numpy(), node.inputs[3].const_value.numpy(), ) - array = unpacked_weight_arrays[0].view(ml_dtypes.int4) - node.inputs[1].const_value = ir.Tensor(array) - node.inputs[2].const_value = ir.Tensor(array) - input_3 = ir.Value(None) - input_3.const_value = ir.Tensor(array) - node.replace_input_with(3, input_3) - ir_model.graph.initializers[input_3.name] = input_3 + node.inputs[1].const_value = ir.Tensor(unpacked_weight_arrays[0]) + node.inputs[2].const_value = ir.Tensor(unpacked_weight_arrays[1]) + if len(unpacked_weight_arrays) == 3: + # TODO(justinchuby): Specify a name to input_3 + input_3 = ir.Value(None) + input_3.const_value = ir.Tensor(unpacked_weight_arrays[2]) + # TODO(justinchuby): Ensure the node has three inputs + node.replace_input_with(3, input_3) + ir_model.graph.register_initializer(input_3) # TODO(justinchuby): Register and remove initializers - ir_model.opset_imports[""] = max(21, ir_model.opset_imports[""]) - # save the model to the output path and return the model - return ir_model_to_olive_model(ir_model, output_model_path, config) - - # create a dag from the model - dag = OnnxDAG.from_model_path(model.model_path) - # remove unnecessary identity nodes - dag.remove_identity_nodes() - - # if matmulnbits zero point is the following, then the zero point is not needed in the DQ node - default_mnb_zp = 8 if config["use_int4"] else 0 - int_np_dtype = np.int8 if config["use_int4"] else np.uint8 - int_elem_type = onnx.TensorProto.INT4 if config["use_int4"] else onnx.TensorProto.UINT4 - - num_modified = 0 - for node_name in dag.get_node_names(): - op_type = dag.get_node_op_type(node_name) - if op_type != "MatMulNBits": - continue - - node_inputs = dag.get_node_inputs(node_name) - # only deal with the constant matmul case for now - if not all(dag.is_initializer(i_name) and not dag.is_input(i_name) for i_name in node_inputs[1:]): - continue - - graph_idx = dag.get_graph_idx(node_name) - - # original output proto - node_output = dag.get_node_outputs(node_name)[0] - is_model_output = dag.is_output(node_output) - node_output_proto = None - if dag.get_io(node_output).proto: - node_output_proto = dag.get_io(node_output).proto[-1] - node_attributes = dag.get_node_attributes(node_name) - K = node_attributes["K"] # noqa: N806 - N = node_attributes["N"] # noqa: N806 - block_size = node_attributes["block_size"] - num_k_blocks = math.ceil(K / block_size) - # will make this a per-axis DQ if num_k_blocks == 1 - # - originally per-axis K == block_size - # - originally blockwise but K <= block_size - is_per_axis = num_k_blocks == 1 - - # only deal with 4 bits (int4) for now - if node_attributes["bits"] != 4: - logger.debug("%s uses %d bits, only 4 bits is supported", node_name, node_attributes["bits"]) - continue - - # we can only deal with trivial g_idx, dequantize linear does not support g_idx - if len(node_inputs) >= 5 and node_inputs[4]: - g_idx = dag.get_initializer_np_array(node_inputs[4]) - trivial_g_idx = np.arange(K, dtype=np.int32) // block_size - if not np.array_equal(g_idx, trivial_g_idx): - continue - - # name for the DQ node - dq_name = self._get_new_node_name(dag, node_name, "DequantizeLinear") - # weight, scales, zeros - # (name, new_name, unpacked column size) - quant_inputs = [ - (node_inputs[1], f"{dq_name}.qweight", K), - (node_inputs[2], f"{dq_name}.scales", num_k_blocks), - ] - if len(node_inputs) >= 4 and node_inputs[3]: - quant_inputs.append((node_inputs[3], f"{dq_name}.qzeros", num_k_blocks)) - dq_inputs = [] - - for qi_name, new_qi_name, unpacked_col_size in quant_inputs: - # get the np array - # weight: uint8, scales: float32, zeros: uint8 - qi = dag.get_initializer_np_array(qi_name) - # reshape to 2D - qi = qi.reshape(N, -1) - - # there are cases where unpack and repack is not needed: no transpose + no padding - # but will still do it for simplicity - if qi.dtype == np.uint8: - qi = self._unpack_on_row(qi) - # remove padding if any - qi = qi[:, :unpacked_col_size] - - # Make 1-D scale or qzero if per-axis - if new_qi_name.endswith((".scales", ".qzeros")) and is_per_axis: - qi = qi.flatten() - - # skip if is a no-op zero point - if not config["add_zero_point"] and new_qi_name.endswith(".qzeros") and np.all(qi == default_mnb_zp): - continue - - if not config["use_transpose_op"]: - # becomes K X N - qi = qi.T - - if qi.dtype == np.uint8: - if config["use_int4"]: - # no worries about making signed since the values only use 4 bits - qi = qi.astype(np.int8) - # subtract 8 to make it signed - # no worries here again since the values are in the range 0-15 and numpy uses 2's complement - qi -= 8 - - # pack in the format expected by onnx and create the tensor - tensor = onnx.helper.make_tensor( - new_qi_name, - int_elem_type, - qi.shape, - self._pack_on_flat(qi).tobytes(), - raw=True, - ) - else: - tensor = onnx.numpy_helper.from_array(qi, name=new_qi_name) - - # add the initializer - dag.add_initializer(tensor, graph_idx) - # add the input name - dq_inputs.append(new_qi_name) - # DQ default zp is 0 but MatMulNBits is 8, so we need to add a zero tensor with all 8s - # no need to add for int4 if add_zero_point is False - if len(dq_inputs) == 2 and (config["add_zero_point"] or not config["use_int4"]): - zp_name = f"{dq_name}.qzeros" - zp_shape = ( - [N] if is_per_axis else ([N, num_k_blocks] if config["use_transpose_op"] else [num_k_blocks, N]) - ) - zp_tensor = onnx.helper.make_tensor( - zp_name, - int_elem_type, - zp_shape, - # no zp in matmulnbits is equivalent to 8 uint4 and 0 int4 in DQ - self._pack_on_flat(np.zeros(N * num_k_blocks, dtype=int_np_dtype) + 8 - default_mnb_zp).tobytes(), - raw=True, - ) - dag.add_initializer(zp_tensor, graph_idx) - dq_inputs.append(zp_name) - - # onnx dtype for the float tensors (scale, dequantized weight, matmul inputs+outputs) - float_elem_type = onnx.helper.np_dtype_to_tensor_dtype(dag.get_initializer_np_array(node_inputs[2]).dtype) - - # new nodes and value infos to add to the graph - # ensure that the node names and output names are unique - # will add the new nodes, make consumers use the new output and remove the node - # if output is a model output, rename it back to the original name - new_nodes = [] - new_value_infos = [] - - # DequantizeLinear - dq_name = self._get_new_node_name(dag, node_name, "DequantizeLinear") - dq_output = f"{dq_name}/output_0" - new_nodes.append( - onnx.helper.make_node( - "DequantizeLinear", - dq_inputs, - [dq_output], - name=dq_name, - block_size=None if is_per_axis else block_size, - # for some reason block_wise and per-axis appear to use swapped axis - # flip the axis if it is per-axis - axis=(1 if config["use_transpose_op"] else 0) ^ (1 if is_per_axis else 0), - ) - ) - new_value_infos.append( - onnx.helper.make_tensor_value_info( - dq_output, float_elem_type, shape=[N, K] if config["use_transpose_op"] else [K, N] - ) - ) - - if config["use_transpose_op"]: - # Transpose - transpose_name = self._get_new_node_name(dag, node_name, "Transpose") - transpose_output = f"{transpose_name}/output_0" - new_nodes.append( - onnx.helper.make_node( - "Transpose", [dq_output], [transpose_output], name=transpose_name, perm=[1, 0] - ) - ) - new_value_infos.append( - onnx.helper.make_tensor_value_info(transpose_output, float_elem_type, shape=[K, N]) - ) - matmul_input = transpose_output - else: - matmul_input = dq_output - - # MatMul - matmul_name = self._get_new_node_name(dag, node_name, "MatMul") - matmul_output = f"{matmul_name}/output_0" - new_nodes.append( - onnx.helper.make_node("MatMul", [node_inputs[0], matmul_input], [matmul_output], name=matmul_name) - ) - if node_output_proto: - # the output shape is the same as the original MatMulNBits node - matmul_output_proto = onnx.ValueInfoProto() - matmul_output_proto.CopyFrom(node_output_proto) - matmul_output_proto.name = matmul_output - new_value_infos.append(matmul_output_proto) - final_name = matmul_name - final_output = matmul_output - - if len(node_inputs) >= 5 and node_inputs[4]: - # Bias Add - # it has bias - bias_i_name = node_inputs[4] - new_bias_i_name = bias_i_name.replace("MatMulNBits", "MatMul") - bias_initiaizer = onnx.numpy_helper.from_array( - dag.get_initializer_np_array(bias_i_name), name=new_bias_i_name - ) - dag.add_initializer(bias_initiaizer, graph_idx) - - bias_name = self._get_new_node_name(dag, node_name, "Add") - bias_output = f"{bias_name}/output_0" - new_nodes.append( - onnx.helper.make_node("Add", [matmul_output, new_bias_i_name], [bias_output], name=bias_name) - ) - if node_output_proto: - # the output shape is the same as the original MatMulNBits node - bias_output_proto = onnx.ValueInfoProto() - bias_output_proto.CopyFrom(node_output_proto) - bias_output_proto.name = bias_output - new_value_infos.append(bias_output_proto) - final_name = bias_name - final_output = bias_output - - for node in new_nodes: - dag.add_node(node, graph_idx) - - # change the input of the consumers - for consumer in dag.get_consumers(node_name): - dag.replace_node_input(consumer, node_output, final_output) - - # add the new value infos - for vi in new_value_infos: - dag.add_value_info(vi, graph_idx) - - # remove the node - if is_model_output: - dag.remove_output(node_output) - dag.remove_node(node_name) - - # rename to original name if it is a model output - if is_model_output: - dag.rename_node_output(final_name, final_output, node_output) - dag.make_output(node_output) - - num_modified += 1 - - if num_modified == 0: - logger.info("No MatMulNBits nodes found. Returning the original model.") - return model - - dag.update() - logger.debug("Modified %d MatMulNBits nodes", num_modified) - # this might not work for all models but will just update the opset version to 21 - # if there is an issue, try the logic in OnnxOpVersionConversion - dag.model.opset_import[0].version = max(21, dag.model.opset_import[0].version) - - # save the model to the output path and return the model - return model_proto_to_olive_model(dag.model, output_model_path, config) + return model_proto_to_olive_model(ir.serde.serialize_model(ir_model), output_model_path, config) @staticmethod def _get_new_node_name(dag: OnnxDAG, old_name: str, op_type: str): @@ -442,14 +189,3 @@ def _unpack_on_row(tensor: "NDArray") -> "NDArray": # mask out the first 4 bits tensor &= 0xF return tensor.reshape(tensor.shape[0], -1) - - @staticmethod - def _pack_on_flat(tensor: "NDArray") -> "NDArray": - """Pack two uint4 into one uint8 on a flattened tensor.""" - tensor = tensor.flatten() - - if len(tensor) % 2: - tensor = np.pad(tensor, (0, 1), mode="constant") - - # right 4 bits are the first uint4 - return (tensor[0::2] & 0xF) | ((tensor[1::2] & 0xF) << 4) From 1e1d0f29c6339b4de79d15b71badbd1dc21d035e Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 14 Nov 2024 00:38:27 +0000 Subject: [PATCH 3/6] update --- olive/passes/onnx/mnb_to_qdq.py | 60 +++++++++++++++++++++++---------- 1 file changed, 43 insertions(+), 17 deletions(-) diff --git a/olive/passes/onnx/mnb_to_qdq.py b/olive/passes/onnx/mnb_to_qdq.py index e98b0b44c..f8105c948 100644 --- a/olive/passes/onnx/mnb_to_qdq.py +++ b/olive/passes/onnx/mnb_to_qdq.py @@ -69,28 +69,49 @@ def _run_for_config( # 2. Repacking ir_model = ir.serde.deserialize_model(model.load_model()) - def mat_mul_n_bits_pattern(op, input_A, qweight, qscales, qzeros, g_idx, bias): - return op.MatMulNBits(input_A, qweight, qscales, qzeros, g_idx, bias) + def mat_mul_n_bits_pattern(op, input_A, q_weight, q_scales, q_zeros, g_idx, bias): + # bias is an optional input + return op.MatMulNBits( + input_A, + q_weight, + q_scales, + q_zeros, + g_idx, + bias, + _outputs=["mat_mul_n_bits_out"], # Bind the output to the name "mat_mul_n_bits_out" + ) def _is_initializer(context, value: ir.Value) -> bool: graph: ir.Graph = context.graph return value in graph.initializers.values() - def mat_mul_n_bits_pattern_check(context, input_A, qweight, qscales, qzeros, g_idx, bias) -> bool: - if not _is_initializer(context, qweight): + def mat_mul_n_bits_pattern_check(context, *, q_weight, g_idx, mat_mul_n_bits_out: ir.Value, **_) -> bool: + if not _is_initializer(context, q_weight): + return False + node: ir.Node = mat_mul_n_bits_out.producer() + block_size = node.attributes["block_size"].as_int() + k = node.attributes["K"].as_int() + if not _is_initializer(g_idx, q_weight): return False - node: ir.Node = _get_node(input_A) - block_size = node.attributes["block_size"].value - K = node.attributes["K"].value g_idx = g_idx.constant_value.numpy() - trivial_g_idx = np.arange(K, dtype=np.int32) // block_size + trivial_g_idx = np.arange(k, dtype=np.int32) // block_size if not np.array_equal(g_idx, trivial_g_idx): - # Log + # TODO: We can log why the pattern is not matched here return False return True - def mat_mul_n_bits_replacement(op, input_A, qweight, qscales, qzeros, g_idx, bias): - node: ir.Node = _get_node(input_A) + def mat_mul_n_bits_replacement( + op, + *, + input_A: ir.Value, + q_weight: ir.Value, + q_scales: ir.Value, + q_zeros: ir.Value, + bias: ir.Value, + mat_mul_n_bits_out: ir.Value, + **_, + ): + node: ir.Node = mat_mul_n_bits_out.producer() # TODO(justinchuby): Keep the old name of the node k: int = node.attributes["K"].as_int() block_size: int = node.attributes["block_size"].as_int() @@ -100,11 +121,11 @@ def mat_mul_n_bits_replacement(op, input_A, qweight, qscales, qzeros, g_idx, bia # - originally blockwise but K <= block_size is_per_axis = num_k_blocks == 1 - # dequantizelinear -> transpose -> matmul -> add (optional) + # DequantizeLinear -> Transpose -> MatMul -> Add (optional) dq = op.DequantizeLinear( - qweight, - qscales, - qzeros, + q_weight, + q_scales, + q_zeros, block_size=None if is_per_axis else block_size, # for some reason block_wise and per-axis appear to use swapped axis # flip the axis if it is per-axis @@ -121,13 +142,13 @@ def mat_mul_n_bits_replacement(op, input_A, qweight, qscales, qzeros, g_idx, bia matmul = op.Add(matmul, bias) return matmul - replace_matmul_n_bits = orp.RewriteRule( + replace_mat_mul_n_bits = orp.RewriteRule( mat_mul_n_bits_pattern, mat_mul_n_bits_pattern_check, mat_mul_n_bits_replacement, ) + # TODO(justinchuby): Call the rewriter with replace_mat_mul_n_bits - # TODO(justinchuby): Call the rewriter with replace_matmul_n_bits # 2. Repack the quantized weights for node in ir_model.graph: if "needs_repacking" not in node.meta: @@ -152,6 +173,11 @@ def mat_mul_n_bits_replacement(op, input_A, qweight, qscales, qzeros, g_idx, bia node.replace_input_with(3, input_3) ir_model.graph.register_initializer(input_3) + # Clear the meta data + del node.meta["needs_repacking"] + del node.meta["K"] + del node.meta["N"] + # TODO(justinchuby): Register and remove initializers ir_model.opset_imports[""] = max(21, ir_model.opset_imports[""]) From 48250820f8dc6bd943d8edb07566dbabe86b5e8f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 14 Nov 2024 00:41:12 +0000 Subject: [PATCH 4/6] update --- olive/passes/onnx/mnb_to_qdq.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/olive/passes/onnx/mnb_to_qdq.py b/olive/passes/onnx/mnb_to_qdq.py index f8105c948..f128f1050 100644 --- a/olive/passes/onnx/mnb_to_qdq.py +++ b/olive/passes/onnx/mnb_to_qdq.py @@ -111,10 +111,10 @@ def mat_mul_n_bits_replacement( mat_mul_n_bits_out: ir.Value, **_, ): - node: ir.Node = mat_mul_n_bits_out.producer() + mat_mul_n_bits: ir.Node = mat_mul_n_bits_out.producer() # TODO(justinchuby): Keep the old name of the node - k: int = node.attributes["K"].as_int() - block_size: int = node.attributes["block_size"].as_int() + k: int = mat_mul_n_bits.attributes["K"].as_int() + block_size: int = mat_mul_n_bits.attributes["block_size"].as_int() num_k_blocks = math.ceil(k / block_size) # will make this a per-axis DQ if num_k_blocks == 1 # - originally per-axis K == block_size @@ -122,7 +122,7 @@ def mat_mul_n_bits_replacement( is_per_axis = num_k_blocks == 1 # DequantizeLinear -> Transpose -> MatMul -> Add (optional) - dq = op.DequantizeLinear( + dq: ir.Value = op.DequantizeLinear( q_weight, q_scales, q_zeros, @@ -132,9 +132,10 @@ def mat_mul_n_bits_replacement( axis=config["use_transpose_op"] or is_per_axis, ) # TODO(justinchuby): Improve the way we mark something that needs repacking - dq.producer().meta["needs_repacking"] = True - dq.producer().meta["K"] = k - dq.producer().meta["N"] = node.attributes["N"].as_int() + dq_node = dq.producer() + dq_node.meta["needs_repacking"] = True + dq_node.meta["K"] = k + dq_node.meta["N"] = mat_mul_n_bits.attributes["N"].as_int() if config["use_transpose_node"]: dq = op.Transpose(dq, perm=[1, 0]) matmul = op.MatMul(input_A, dq) From e694d1d3850a5c0a09e52a22ee5279e8d69284f7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 14 Nov 2024 00:48:18 +0000 Subject: [PATCH 5/6] xor --- olive/passes/onnx/mnb_to_qdq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olive/passes/onnx/mnb_to_qdq.py b/olive/passes/onnx/mnb_to_qdq.py index f128f1050..1be681fb8 100644 --- a/olive/passes/onnx/mnb_to_qdq.py +++ b/olive/passes/onnx/mnb_to_qdq.py @@ -129,7 +129,7 @@ def mat_mul_n_bits_replacement( block_size=None if is_per_axis else block_size, # for some reason block_wise and per-axis appear to use swapped axis # flip the axis if it is per-axis - axis=config["use_transpose_op"] or is_per_axis, + axis=config["use_transpose_op"] ^ is_per_axis, ) # TODO(justinchuby): Improve the way we mark something that needs repacking dq_node = dq.producer() From b39d69eedc253c7b984486fae78d813f9eefee42 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 14 Nov 2024 18:02:46 +0000 Subject: [PATCH 6/6] traversal --- olive/passes/onnx/mnb_to_qdq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olive/passes/onnx/mnb_to_qdq.py b/olive/passes/onnx/mnb_to_qdq.py index 1be681fb8..d5818fc65 100644 --- a/olive/passes/onnx/mnb_to_qdq.py +++ b/olive/passes/onnx/mnb_to_qdq.py @@ -151,7 +151,7 @@ def mat_mul_n_bits_replacement( # TODO(justinchuby): Call the rewriter with replace_mat_mul_n_bits # 2. Repack the quantized weights - for node in ir_model.graph: + for node in ir.traversal.RecursiveGraphIterator(ir_model.graph): if "needs_repacking" not in node.meta: continue