diff --git a/olive/passes/onnx/mnb_to_qdq.py b/olive/passes/onnx/mnb_to_qdq.py
index 03cf8aa1f..d5818fc65 100644
--- a/olive/passes/onnx/mnb_to_qdq.py
+++ b/olive/passes/onnx/mnb_to_qdq.py
@@ -7,8 +7,10 @@
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict
 
+import ml_dtypes
 import numpy as np
-import onnx
+from onnxscript import ir
+from onnxscript.rewriter import pattern as orp
 
 from olive.hardware.accelerator import AcceleratorSpec
 from olive.model import ONNXModelHandler
@@ -62,257 +64,125 @@ def _run_for_config(
     ) -> ONNXModelHandler:
         output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name)
 
-        # create a dag from the model
-        dag = OnnxDAG.from_model_path(model.model_path)
-        # remove unnecessary identity nodes
-        dag.remove_identity_nodes()
-
-        # if matmulnbits zero point is the following, then the zero point is not needed in the DQ node
-        default_mnb_zp = 8 if config["use_int4"] else 0
-        int_np_dtype = np.int8 if config["use_int4"] else np.uint8
-        int_elem_type = onnx.TensorProto.INT4 if config["use_int4"] else onnx.TensorProto.UINT4
-
-        num_modified = 0
-        for node_name in dag.get_node_names():
-            op_type = dag.get_node_op_type(node_name)
-            if op_type != "MatMulNBits":
-                continue
-
-            node_inputs = dag.get_node_inputs(node_name)
-            # only deal with the constant matmul case for now
-            if not all(dag.is_initializer(i_name) and not dag.is_input(i_name) for i_name in node_inputs[1:]):
-                continue
-
-            graph_idx = dag.get_graph_idx(node_name)
+        # Two steps:
+        # 1. Pattern replacement
+        # 2. Repacking
+        ir_model = ir.serde.deserialize_model(model.load_model())
+
+        def mat_mul_n_bits_pattern(op, input_A, q_weight, q_scales, q_zeros, g_idx, bias):
+            # bias is an optional input
+            return op.MatMulNBits(
+                input_A,
+                q_weight,
+                q_scales,
+                q_zeros,
+                g_idx,
+                bias,
+                _outputs=["mat_mul_n_bits_out"],  # Bind the output to the name "mat_mul_n_bits_out"
+            )
 
-            # original output proto
-            node_output = dag.get_node_outputs(node_name)[0]
-            is_model_output = dag.is_output(node_output)
-            node_output_proto = None
-            if dag.get_io(node_output).proto:
-                node_output_proto = dag.get_io(node_output).proto[-1]
-            node_attributes = dag.get_node_attributes(node_name)
-            K = node_attributes["K"]  # noqa: N806
-            N = node_attributes["N"]  # noqa: N806
-            block_size = node_attributes["block_size"]
-            num_k_blocks = math.ceil(K / block_size)
+        def _is_initializer(context, value: ir.Value) -> bool:
+            graph: ir.Graph = context.graph
+            return value in graph.initializers.values()
+
+        def mat_mul_n_bits_pattern_check(context, *, q_weight, g_idx, mat_mul_n_bits_out: ir.Value, **_) -> bool:
+            if not _is_initializer(context, q_weight):
+                return False
+            node: ir.Node = mat_mul_n_bits_out.producer()
+            block_size = node.attributes["block_size"].as_int()
+            k = node.attributes["K"].as_int()
+            if not _is_initializer(context, g_idx):
+                return False
+            g_idx = g_idx.const_value.numpy()
+            trivial_g_idx = np.arange(k, dtype=np.int32) // block_size
+            if not np.array_equal(g_idx, trivial_g_idx):
+                # TODO: We can log why the pattern is not matched here
+                return False
+            return True
+
+        def mat_mul_n_bits_replacement(
+            op,
+            *,
+            input_A: ir.Value,
+            q_weight: ir.Value,
+            q_scales: ir.Value,
+            q_zeros: ir.Value,
+            bias: ir.Value,
+            mat_mul_n_bits_out: ir.Value,
+            **_,
+        ):
+            mat_mul_n_bits: ir.Node = mat_mul_n_bits_out.producer()
+            # TODO(justinchuby): Keep the old name of the node
+            k: int = mat_mul_n_bits.attributes["K"].as_int()
+            block_size: int = mat_mul_n_bits.attributes["block_size"].as_int()
+            num_k_blocks = math.ceil(k / block_size)
             # will make this a per-axis DQ if num_k_blocks == 1
             # - originally per-axis K == block_size
             # - originally blockwise but K <= block_size
             is_per_axis = num_k_blocks == 1
 
-            # only deal with 4 bits (int4) for now
-            if node_attributes["bits"] != 4:
-                logger.debug("%s uses %d bits, only 4 bits is supported", node_name, node_attributes["bits"])
-                continue
-
-            # we can only deal with trivial g_idx, dequantize linear does not support g_idx
-            if len(node_inputs) >= 5 and node_inputs[4]:
-                g_idx = dag.get_initializer_np_array(node_inputs[4])
-                trivial_g_idx = np.arange(K, dtype=np.int32) // block_size
-                if not np.array_equal(g_idx, trivial_g_idx):
-                    continue
-
-            # name for the DQ node
-            dq_name = self._get_new_node_name(dag, node_name, "DequantizeLinear")
-            # weight, scales, zeros
-            # (name, new_name, unpacked column size)
-            quant_inputs = [
-                (node_inputs[1], f"{dq_name}.qweight", K),
-                (node_inputs[2], f"{dq_name}.scales", num_k_blocks),
-            ]
-            if len(node_inputs) >= 4 and node_inputs[3]:
-                quant_inputs.append((node_inputs[3], f"{dq_name}.qzeros", num_k_blocks))
-            dq_inputs = []
-
-            for qi_name, new_qi_name, unpacked_col_size in quant_inputs:
-                # get the np array
-                # weight: uint8, scales: float32, zeros: uint8
-                qi = dag.get_initializer_np_array(qi_name)
-                # reshape to 2D
-                qi = qi.reshape(N, -1)
-
-                # there are cases where unpack and repack is not needed: no transpose + no padding
-                # but will still do it for simplicity
-                if qi.dtype == np.uint8:
-                    qi = self._unpack_on_row(qi)
-                    # remove padding if any
-                    qi = qi[:, :unpacked_col_size]
-
-                # Make 1-D scale or qzero if per-axis
-                if new_qi_name.endswith((".scales", ".qzeros")) and is_per_axis:
-                    qi = qi.flatten()
-
-                # skip if is a no-op zero point
-                if not config["add_zero_point"] and new_qi_name.endswith(".qzeros") and np.all(qi == default_mnb_zp):
-                    continue
-
-                if not config["use_transpose_op"]:
-                    # becomes K X N
-                    qi = qi.T
-
-                if qi.dtype == np.uint8:
-                    if config["use_int4"]:
-                        # no worries about making signed since the values only use 4 bits
-                        qi = qi.astype(np.int8)
-                        # subtract 8 to make it signed
-                        # no worries here again since the values are in the range 0-15 and numpy uses 2's complement
-                        qi -= 8
-
-                    # pack in the format expected by onnx and create the tensor
-                    tensor = onnx.helper.make_tensor(
-                        new_qi_name,
-                        int_elem_type,
-                        qi.shape,
-                        self._pack_on_flat(qi).tobytes(),
-                        raw=True,
-                    )
-                else:
-                    tensor = onnx.numpy_helper.from_array(qi, name=new_qi_name)
-
-                # add the initializer
-                dag.add_initializer(tensor, graph_idx)
-                # add the input name
-                dq_inputs.append(new_qi_name)
-
-            # DQ default zp is 0 but MatMulNBits is 8, so we need to add a zero tensor with all 8s
-            # no need to add for int4 if add_zero_point is False
-            if len(dq_inputs) == 2 and (config["add_zero_point"] or not config["use_int4"]):
-                zp_name = f"{dq_name}.qzeros"
-                zp_shape = (
-                    [N] if is_per_axis else ([N, num_k_blocks] if config["use_transpose_op"] else [num_k_blocks, N])
-                )
-                zp_tensor = onnx.helper.make_tensor(
-                    zp_name,
-                    int_elem_type,
-                    zp_shape,
-                    # no zp in matmulnbits is equivalent to 8 uint4 and 0 int4 in DQ
-                    self._pack_on_flat(np.zeros(N * num_k_blocks, dtype=int_np_dtype) + 8 - default_mnb_zp).tobytes(),
-                    raw=True,
-                )
-                dag.add_initializer(zp_tensor, graph_idx)
-                dq_inputs.append(zp_name)
-
-            # onnx dtype for the float tensors (scale, dequantized weight, matmul inputs+outputs)
-            float_elem_type = onnx.helper.np_dtype_to_tensor_dtype(dag.get_initializer_np_array(node_inputs[2]).dtype)
-
-            # new nodes and value infos to add to the graph
-            # ensure that the node names and output names are unique
-            # will add the new nodes, make consumers use the new output and remove the node
-            # if output is a model output, rename it back to the original name
-            new_nodes = []
-            new_value_infos = []
-
-            # DequantizeLinear
-            dq_name = self._get_new_node_name(dag, node_name, "DequantizeLinear")
-            dq_output = f"{dq_name}/output_0"
-            new_nodes.append(
-                onnx.helper.make_node(
-                    "DequantizeLinear",
-                    dq_inputs,
-                    [dq_output],
-                    name=dq_name,
-                    block_size=None if is_per_axis else block_size,
-                    # for some reason block_wise and per-axis appear to use swapped axis
-                    # flip the axis if it is per-axis
-                    axis=(1 if config["use_transpose_op"] else 0) ^ (1 if is_per_axis else 0),
-                )
-            )
-            new_value_infos.append(
-                onnx.helper.make_tensor_value_info(
-                    dq_output, float_elem_type, shape=[N, K] if config["use_transpose_op"] else [K, N]
-                )
+            # DequantizeLinear -> Transpose -> MatMul -> Add (optional)
+            dq: ir.Value = op.DequantizeLinear(
+                q_weight,
+                q_scales,
+                q_zeros,
+                block_size=None if is_per_axis else block_size,
+                # for some reason block_wise and per-axis appear to use swapped axis
+                # flip the axis if it is per-axis
+                axis=config["use_transpose_op"] ^ is_per_axis,
             )
+            # TODO(justinchuby): Improve the way we mark something that needs repacking
+            dq_node = dq.producer()
+            dq_node.meta["needs_repacking"] = True
+            dq_node.meta["K"] = k
+            dq_node.meta["N"] = mat_mul_n_bits.attributes["N"].as_int()
+            if config["use_transpose_op"]:
+                dq = op.Transpose(dq, perm=[1, 0])
+            matmul = op.MatMul(input_A, dq)
+            if bias is not None:
+                matmul = op.Add(matmul, bias)
+            return matmul
+
+        replace_mat_mul_n_bits = orp.RewriteRule(
+            mat_mul_n_bits_pattern,
+            mat_mul_n_bits_replacement,
+            mat_mul_n_bits_pattern_check,
+        )
+        # TODO(justinchuby): Call the rewriter with replace_mat_mul_n_bits
+
+        # 2. Repack the quantized weights
+        for node in ir.traversal.RecursiveGraphIterator(ir_model.graph):
+            if "needs_repacking" not in node.meta:
+                continue
 
-            if config["use_transpose_op"]:
-                # Transpose
-                transpose_name = self._get_new_node_name(dag, node_name, "Transpose")
-                transpose_output = f"{transpose_name}/output_0"
-                new_nodes.append(
-                    onnx.helper.make_node(
-                        "Transpose", [dq_output], [transpose_output], name=transpose_name, perm=[1, 0]
-                    )
-                )
-                new_value_infos.append(
-                    onnx.helper.make_tensor_value_info(transpose_output, float_elem_type, shape=[K, N])
-                )
-                matmul_input = transpose_output
-            else:
-                matmul_input = dq_output
+            # TODO: Add logic to handle a missing zero-point input
 
-            # MatMul
-            matmul_name = self._get_new_node_name(dag, node_name, "MatMul")
-            matmul_output = f"{matmul_name}/output_0"
-            new_nodes.append(
-                onnx.helper.make_node("MatMul", [node_inputs[0], matmul_input], [matmul_output], name=matmul_name)
+            unpacked_weight_arrays = _unpack_weights(
+                node.meta["K"],
+                node.meta["N"],
+                node.inputs[0].const_value.numpy(),
+                node.inputs[1].const_value.numpy(),
+                node.inputs[2].const_value.numpy(),
             )
-            if node_output_proto:
-                # the output shape is the same as the original MatMulNBits node
-                matmul_output_proto = onnx.ValueInfoProto()
-                matmul_output_proto.CopyFrom(node_output_proto)
-                matmul_output_proto.name = matmul_output
-                new_value_infos.append(matmul_output_proto)
-            final_name = matmul_name
-            final_output = matmul_output
-
-            if len(node_inputs) >= 5 and node_inputs[4]:
-                # Bias Add
-                # it has bias
-                bias_i_name = node_inputs[4]
-                new_bias_i_name = bias_i_name.replace("MatMulNBits", "MatMul")
-                bias_initiaizer = onnx.numpy_helper.from_array(
-                    dag.get_initializer_np_array(bias_i_name), name=new_bias_i_name
-                )
-                dag.add_initializer(bias_initiaizer, graph_idx)
-
-                bias_name = self._get_new_node_name(dag, node_name, "Add")
-                bias_output = f"{bias_name}/output_0"
-                new_nodes.append(
-                    onnx.helper.make_node("Add", [matmul_output, new_bias_i_name], [bias_output], name=bias_name)
-                )
-                if node_output_proto:
-                    # the output shape is the same as the original MatMulNBits node
-                    bias_output_proto = onnx.ValueInfoProto()
-                    bias_output_proto.CopyFrom(node_output_proto)
-                    bias_output_proto.name = bias_output
-                    new_value_infos.append(bias_output_proto)
-                final_name = bias_name
-                final_output = bias_output
-
-            for node in new_nodes:
-                dag.add_node(node, graph_idx)
-
-            # change the input of the consumers
-            for consumer in dag.get_consumers(node_name):
-                dag.replace_node_input(consumer, node_output, final_output)
-
-            # add the new value infos
-            for vi in new_value_infos:
-                dag.add_value_info(vi, graph_idx)
-
-            # remove the node
-            if is_model_output:
-                dag.remove_output(node_output)
-            dag.remove_node(node_name)
-
-            # rename to original name if it is a model output
-            if is_model_output:
-                dag.rename_node_output(final_name, final_output, node_output)
-                dag.make_output(node_output)
-
-            num_modified += 1
-
-        if num_modified == 0:
-            logger.info("No MatMulNBits nodes found. Returning the original model.")
-            return model
-
-        dag.update()
-        logger.debug("Modified %d MatMulNBits nodes", num_modified)
-        # this might not work for all models but will just update the opset version to 21
-        # if there is an issue, try the logic in OnnxOpVersionConversion
-        dag.model.opset_import[0].version = max(21, dag.model.opset_import[0].version)
-
-        # save the model to the output path and return the model
-        return model_proto_to_olive_model(dag.model, output_model_path, config)
+            node.inputs[0].const_value = ir.Tensor(unpacked_weight_arrays[0])
+            node.inputs[1].const_value = ir.Tensor(unpacked_weight_arrays[1])
+            if len(unpacked_weight_arrays) == 3:
+                # TODO(justinchuby): Specify a name for the zero-point value
+                zero_point = ir.Value(None)
+                zero_point.const_value = ir.Tensor(unpacked_weight_arrays[2])
+                # TODO(justinchuby): Ensure the node has three inputs
+                node.replace_input_with(2, zero_point)
+                ir_model.graph.register_initializer(zero_point)
+
+            # Clear the meta data
+            del node.meta["needs_repacking"]
+            del node.meta["K"]
+            del node.meta["N"]
+
+        # TODO(justinchuby): Register and remove initializers
+        ir_model.opset_imports[""] = max(21, ir_model.opset_imports[""])
+
+        return model_proto_to_olive_model(ir.serde.serialize_model(ir_model), output_model_path, config)
 
     @staticmethod
     def _get_new_node_name(dag: OnnxDAG, old_name: str, op_type: str):
@@ -346,14 +216,3 @@ def _unpack_on_row(tensor: "NDArray") -> "NDArray":
         # mask out the first 4 bits
         tensor &= 0xF
         return tensor.reshape(tensor.shape[0], -1)
-
-    @staticmethod
-    def _pack_on_flat(tensor: "NDArray") -> "NDArray":
-        """Pack two uint4 into one uint8 on a flattened tensor."""
-        tensor = tensor.flatten()
-
-        if len(tensor) % 2:
-            tensor = np.pad(tensor, (0, 1), mode="constant")
-
-        # right 4 bits are the first uint4
-        return (tensor[0::2] & 0xF) | ((tensor[1::2] & 0xF) << 4)
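Note on the remaining TODO to call the rewriter: the RewriteRule above is constructed but never applied, so the repacking loop would currently find no nodes marked with needs_repacking. A minimal sketch of what the missing call might look like between step 1 and step 2, assuming onnxscript's RewriteRuleSet.apply_to_model API; placement and names are illustrative and not part of this diff:

    # Hypothetical sketch, to be placed inside _run_for_config after the rule is built.
    rule_set = orp.RewriteRuleSet([replace_mat_mul_n_bits])
    # Assumption: apply_to_model mutates ir_model in place and returns the number
    # of pattern applications performed.
    num_rewrites = rule_set.apply_to_model(ir_model)
    logger.debug("Replaced %d MatMulNBits nodes with DequantizeLinear/MatMul", num_rewrites)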
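The _unpack_weights helper used in the repacking loop is not defined in these hunks. Based on the removed _unpack_on_row and _pack_on_flat helpers (two 4-bit values per uint8, low nibble first, padded up to the block size), a rough sketch of the core unpacking it would need for the 4-bit case; the names are hypothetical and the transpose/per-axis layout handling is omitted:

    # Rough sketch under the assumptions above, not the pass's actual implementation.
    import ml_dtypes
    import numpy as np

    def unpack_uint4_rows(packed: np.ndarray, unpacked_cols: int) -> np.ndarray:
        """Unpack a (rows, cols) uint8 array into (rows, unpacked_cols) 4-bit values."""
        low = packed & 0xF          # first 4-bit value of each byte
        high = (packed >> 4) & 0xF  # second 4-bit value of each byte
        interleaved = np.stack([low, high], axis=-1).reshape(packed.shape[0], -1)
        return interleaved[:, :unpacked_cols]  # drop block padding, if any

    def to_int4(unpacked: np.ndarray) -> np.ndarray:
        """Shift MatMulNBits' unsigned 4-bit values by -8 and cast to ml_dtypes int4 for DQ."""
        return (unpacked.astype(np.int8) - 8).astype(ml_dtypes.int4)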