Commit a146183

fix tuning bug
Signed-off-by: Mengni Wang <[email protected]>
1 parent 21ff181 commit a146183

2 files changed: 19 additions & 5 deletions

auto_round/compressors/base.py

Lines changed: 11 additions & 5 deletions
@@ -2572,10 +2572,9 @@ def _quantize_layer(
             whole_indices = torch.randperm(nsamples)[:pick_samples]
             if gradient_accumulate_steps != 1:
                 if q_inputs is not None:
-                    current_input = [q_inputs[i] for i in whole_indices]
+                    num_elm = self._get_current_num_elm(q_input_ids, whole_indices)
                 else:
-                    current_input = [inputs[i] for i in whole_indices]
-                num_elm = sum(id.numel() for id in current_input)
+                    num_elm = self._get_current_num_elm(inputs, whole_indices)
             for tmp_step in range(gradient_accumulate_steps):
                 indices = whole_indices[tmp_step * batch_size : (tmp_step + 1) * batch_size]
                 if q_inputs is not None:
@@ -2700,6 +2699,14 @@ def _get_current_q_output(
         output_q = block_forward(block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device)
         return output_q

+    def _get_current_num_elm(
+        self,
+        input_ids: list[torch.Tensor],
+        indices: list[int],
+    ) -> int:
+        current_input_ids = [input_ids[i] for i in indices]
+        return sum(id.numel() for id in current_input_ids)
+
     def _quantize_block(
         self,
         block: torch.nn.Module,
@@ -2840,8 +2847,7 @@ def _quantize_block(
             whole_indices = torch.randperm(nsamples)[:pick_samples]
             # We assume the block input and output shape is same
             if self.gradient_accumulate_steps != 1:
-                current_input_ids = [input_ids[i] for i in whole_indices]
-                num_elm = sum(id.numel() for id in current_input_ids)
+                num_elm = self._get_current_num_elm(input_ids, whole_indices)

             for tmp_step in range(self.gradient_accumulate_steps):
                 indices = whole_indices[tmp_step * self.batch_size : (tmp_step + 1) * self.batch_size]
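For context, the new helper only centralizes the element-count computation that the two removed snippets duplicated: the total number of tensor elements across the sampled inputs, which the tuning loop uses to normalize the accumulated loss over gradient-accumulation steps. Below is a minimal standalone sketch of that logic, assuming input_ids is a plain list of per-sample tensors; the function name and dummy shapes are illustrative, not part of the commit.

import torch

def get_current_num_elm(input_ids: list[torch.Tensor], indices: list[int]) -> int:
    # Sum of element counts over the sampled inputs; used to normalize the
    # accumulated loss across gradient-accumulation steps.
    current_input_ids = [input_ids[i] for i in indices]
    return sum(t.numel() for t in current_input_ids)

# Illustrative usage with dummy calibration data:
inputs = [torch.zeros(1, 32, 768) for _ in range(8)]
whole_indices = torch.randperm(len(inputs))[:4].tolist()
num_elm = get_current_num_elm(inputs, whole_indices)  # 4 * 1 * 32 * 768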

auto_round/compressors/diffusion/compressor.py

Lines changed: 8 additions & 0 deletions
@@ -262,6 +262,14 @@ def _get_block_outputs(

         return output

+    def _get_current_num_elm(
+        self,
+        input_ids: list[torch.Tensor],
+        indices: list[int],
+    ) -> int:
+        current_input_ids = [input_ids["hidden_states"][i] for i in indices]
+        return sum(id.numel() for id in current_input_ids)
+
     def calib(self, nsamples, bs):
         """Perform calibration for quantization.

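For reference, here is a minimal sketch of the diffusion-side override, under the assumption that the cached block inputs are a dict whose "hidden_states" entry holds one tensor per calibration sample; the dict layout, names, and shapes below are assumptions for illustration rather than details taken from this diff.

import torch

def get_current_num_elm(inputs: dict[str, list[torch.Tensor]], indices: list[int]) -> int:
    # Diffusion compressors cache block inputs keyed by name, so the element
    # count is taken from the per-sample "hidden_states" tensors.
    current = [inputs["hidden_states"][i] for i in indices]
    return sum(t.numel() for t in current)

# Illustrative usage with dummy cached inputs:
cache = {"hidden_states": [torch.zeros(2, 64, 128) for _ in range(6)]}
print(get_current_num_elm(cache, [0, 3, 5]))  # 3 * 2 * 64 * 128 = 49152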
