
Commit 8ed81a4

xwang365 authored and Ubuntu committed
support batch_size>1
clean
1 parent 700ff84 commit 8ed81a4


5 files changed: +365 -97 lines


medusa/inference/inference_test.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+# Adapted from: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py
+"""
+Chat with a model with command line interface.
+
+Usage:
+python3 -m medusa.inference.cli --model <model_name_or_path>
+Other commands:
+- Type "!!exit" or an empty line to exit.
+- Type "!!reset" to start a new conversation.
+- Type "!!remove" to remove the last prompt.
+- Type "!!regen" to regenerate the last message.
+- Type "!!save <filename>" to save the conversation history to a json file.
+- Type "!!load <filename>" to load a conversation history from a json file.
+"""
+import argparse
+import os
+import re
+import sys
+import torch
+from fastchat.serve.cli import SimpleChatIO, RichChatIO, ProgrammaticChatIO
+from fastchat.model.model_adapter import get_conversation_template
+from fastchat.conversation import get_conv_template
+import json
+from medusa.model.medusa_model import MedusaModel
+import pdb
+
+def main(args):
+    prefix = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {0} ASSISTANT:"
+    # prompt = ["你叫什么名字"]  # single-prompt variant: "What is your name?"
+    prompt = ["你叫什么名字", "中国的首都是哪里呢?"]  # "What is your name?", "What is the capital of China?"
+    prompt = [prefix.format(p) for p in prompt]
+    model = MedusaModel.from_pretrained(
+        args.model,
+        torch_dtype=torch.float16,
+        low_cpu_mem_usage=True,
+        device_map="auto",
+        load_in_8bit=args.load_in_8bit,
+        load_in_4bit=args.load_in_4bit,
+    )
+    tokenizer = model.get_tokenizer()
+    # Tokenize the whole batch of prompts at once, with padding
+    encoded_inputs = tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
+    # Move the encoded inputs to the model's device
+    input_ids = encoded_inputs['input_ids'].to(model.base_model.device)
+    attention_mask = encoded_inputs['attention_mask'].to(model.base_model.device)
+    for output in model.medusa_generate(
+        input_ids,
+        attention_mask=attention_mask,
+        temperature=args.temperature,
+        max_steps=args.max_steps,
+    ):
+        print(output['text'])
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, required=True, help="Model name or path.")
+    parser.add_argument(
+        "--load-in-8bit", action="store_true", help="Use 8-bit quantization"
+    )
+    parser.add_argument(
+        "--load-in-4bit", action="store_true", help="Use 4-bit quantization"
+    )
+    parser.add_argument(
+        "--conv-template", type=str, default=None, help="Conversation prompt template."
+    )
+    parser.add_argument(
+        "--conv-system-msg", type=str, default=None, help="Conversation system message."
+    )
+    parser.add_argument("--temperature", type=float, default=0.7)
+    parser.add_argument("--max-steps", type=int, default=512)
+    parser.add_argument("--no-history", action="store_true")
+    parser.add_argument(
+        "--style",
+        type=str,
+        default="simple",
+        choices=["simple", "rich", "programmatic"],
+        help="Display style.",
+    )
+    parser.add_argument(
+        "--multiline",
+        action="store_true",
+        help="Enable multiline input. Use ESC+Enter for newline.",
+    )
+    parser.add_argument(
+        "--mouse",
+        action="store_true",
+        help="[Rich Style]: Enable mouse support for cursor positioning.",
+    )
+    parser.add_argument(
+        "--debug",
+        action="store_true",
+        help="Print useful debug information (e.g., prompts)",
+    )
+    args = parser.parse_args()
+    main(args)
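
For context, medusa_generate now streams a dict whose "text" entry is a list of partial completions, one per batch element. Below is a minimal sketch of a caller that drains the stream and keeps only the final texts; it reuses only the API visible in this file (MedusaModel.from_pretrained, get_tokenizer, medusa_generate), and the checkpoint path and prompts are placeholders, not values from the commit.

import torch
from medusa.model.medusa_model import MedusaModel

model = MedusaModel.from_pretrained(
    "path/to/medusa-checkpoint",  # placeholder checkpoint path
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = model.get_tokenizer()
# Llama-style tokenizers often ship without a pad token; padding a batch needs one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

prompts = ["Tell me a joke.", "Name the largest planet in the solar system."]
enc = tokenizer(prompts, padding=True, truncation=True, return_tensors="pt")

final_texts = None
for step in model.medusa_generate(
    enc["input_ids"].to(model.base_model.device),
    attention_mask=enc["attention_mask"].to(model.base_model.device),
    temperature=0.7,
    max_steps=512,
):
    final_texts = step["text"]  # list with one partial completion per prompt

for prompt, completion in zip(prompts, final_texts or []):
    print(prompt, "->", completion)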

medusa/model/kv_cache.py

Lines changed: 33 additions & 16 deletions
@@ -1,5 +1,5 @@
 import torch
-
+import copy
 
 class KVCache:
     """
@@ -41,32 +41,51 @@ def copy(self, indices: torch.Tensor, prev_length: int, dim: int = 2):
 
         Args:
             indices (torch.Tensor): Indices of the data tensor to be copied.
-            prev_length (int): Previous length before adding new data.
+            prev_length (int): Previous lengths before adding new data.
             dim (int, optional): Dimension along which copying should be performed. Default is 2.
         """
+        # Select the data that needs to be copied
         tgt = self.data.index_select(dim, indices)
-        dst = self.data.narrow(dim, prev_length, tgt.shape[dim])
-        dst.copy_(tgt, non_blocking=True)
-        self.current_length.fill_(prev_length + tgt.shape[dim])
+        prev_len = prev_length
+        start_index = prev_len
+        end_index = start_index + tgt.shape[dim]
+        # Pick the destination region along the requested dimension and copy the data in
+        if dim == 2:
+            dst = self.data[:, :, :, start_index:end_index, :]
+        elif dim == 3:
+            dst = self.data[:, :, :, :, start_index:end_index]
+        else:
+            raise ValueError("Unsupported dimension for copying.")
+        dst.copy_(tgt[:, :], non_blocking=True)
+        self.current_length.fill_(prev_length + tgt.shape[dim])
 
     def cat(self, tensor: torch.Tensor, dim: int = 2):
         """
-        Concatenate the given tensor with the current data.
+        Concatenate the given tensor with the current data for batch_size > 1, and return the tensor
+        truncated to the maximum current length across all batches.
 
         Args:
-            tensor (torch.Tensor): The tensor to be concatenated.
+            tensor (torch.Tensor): The tensor to be concatenated, assuming the first dimension is the batch size.
             dim (int, optional): The dimension along which concatenation should be done. Default is 2.
 
         Returns:
-            torch.Tensor: The data tensor after concatenation up to the current length.
+            torch.Tensor: The data tensor after concatenation and truncation to the maximum current length.
         """
-        dst = self.data.narrow(dim, self.current_length, tensor.shape[dim])
-        dst.copy_(tensor)
+        cur_len = copy.deepcopy(self.current_length)
+        new_len = cur_len + tensor.size(dim)
         self.current_length.add_(tensor.shape[dim])
-        return torch.narrow(self.data, 2, 0, self.current_length)
-
-
-def initialize_past_key_values(model):
+        if dim == 2:
+            self.data[:, :, cur_len:new_len, :] = tensor[:, :, :, :]
+            truncated_data = self.data[:, :, :self.current_length, :]
+        elif dim == 3:
+            self.data[:, :, :, cur_len:new_len] = tensor[:, :, :, :]
+            truncated_data = self.data[:, :, :, :self.current_length]
+        else:
+            raise ValueError("Unsupported dimension for concatenation.")
+        return truncated_data
+
+
+def initialize_past_key_values(model, batch_size=1):
     """
     Initialize past key and value states for a given transformer model.
 
@@ -84,8 +103,6 @@ def initialize_past_key_values(model):
     """
     # Extracting configuration from the model
    config = model.config
-    # Initializing the batch size to 1, this can be modified if different batch sizes are required
-    batch_size = 1
     # Initializing a tensor to store past keys and values for all layers
     past_key_values_data = torch.zeros(
         config.num_hidden_layers * 2,
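
Conceptually, the rewritten cat() swaps the narrow/copy_ pair for a slice assignment along the shared sequence dimension and then returns the buffer truncated to the updated length. The toy snippet below reproduces that pattern on a stand-alone pre-allocated buffer; it is an illustration only, and the shapes (batch, heads, max_seq_len, head_dim) are assumptions rather than the repository's actual cache layout.

import torch

# Pre-allocated toy cache: (batch, heads, max_seq_len, head_dim)
data = torch.zeros(2, 4, 16, 8)
current_length = torch.zeros(1, dtype=torch.long)

def toy_cat(tensor, dim=2):
    # Write incoming keys/values right after the current length,
    # then hand back the buffer truncated to the new length.
    cur_len = int(current_length)
    new_len = cur_len + tensor.size(dim)
    data[:, :, cur_len:new_len, :] = tensor
    current_length.fill_(new_len)
    return data[:, :, :new_len, :]

chunk = torch.randn(2, 4, 5, 8)  # 5 new positions for every batch element
print(toy_cat(chunk).shape)      # torch.Size([2, 4, 5, 8])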

medusa/model/medusa_model.py

Lines changed: 28 additions & 16 deletions
@@ -3,7 +3,7 @@
 from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
 from .modeling_mistral_kv import MistralForCausalLM as KVMistralForCausalLM
 # import transformers
-
+import pdb
 # # monkey patch
 # transformers.models.llama.modeling_llama.LlamaForCausalLM = KVLlamaForCausalLM
 # transformers.models.mistral.modeling_mistral.MistralForCausalLM = KVMistralForCausalLM
@@ -121,6 +121,7 @@ def __init__(
     @property
     def base_model(self):
         return self
+
     @classmethod
     def from_pretrained(
         cls,
@@ -219,6 +220,7 @@ def forward(
         if output_orig:
             return torch.stack(medusa_logits, dim=0), outputs, orig
         return torch.stack(medusa_logits, dim=0)
+
     def get_medusa_choice(self, model_name):
         if 'vicuna' in model_name:
             if '7b' in model_name:
@@ -264,10 +266,11 @@ def medusa_generate(
 
         Warning: Only support batch size 1 for now!!
         """
-        assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
+        # assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
+        batch_size = input_ids.shape[0]
+        valid_length = attention_mask.sum(dim=1)
         # Avoid modifying the input_ids in-place
         input_ids = input_ids.clone()
-
         # Cache medusa buffers (the fixed patterns for tree attention)
         if medusa_choices is None:
             medusa_choices = self.get_medusa_choice(self.base_model_name_or_path)
@@ -295,7 +298,7 @@ def medusa_generate(
                past_key_values,
                past_key_values_data,
                current_length_data,
-            ) = initialize_past_key_values(self.base_model)
+            ) = initialize_past_key_values(self.base_model, batch_size)
             self.past_key_values = past_key_values
             self.past_key_values_data = past_key_values_data
             self.current_length_data = current_length_data
@@ -305,12 +308,11 @@ def medusa_generate(
         reset_medusa_mode(self)
         # Initialize tree attention mask and process prefill tokens
         medusa_logits, logits = initialize_medusa(
-            input_ids, self, medusa_buffers["medusa_attn_mask"], past_key_values
+            input_ids, self, medusa_buffers["medusa_attn_mask"], past_key_values, attention_mask
         )
-
         new_token = 0
         last_round_token = 0
-
+        ends = [input_len] * batch_size
         for idx in range(max_steps):
             # Generate candidates with topk predictions from Medusa heads
             candidates, tree_candidates = generate_candidates(
@@ -324,8 +326,8 @@ def medusa_generate(
                top_p=top_p,
                sampling=sampling,
                fast=fast,
+                valid_length=valid_length
            )
-
             # Use tree attention to verify the candidates and get predictions
             medusa_logits, logits, outputs = tree_decoding(
                 self,
@@ -334,15 +336,14 @@ def medusa_generate(
                medusa_buffers["medusa_position_ids"],
                input_ids,
                medusa_buffers["retrieve_indices"],
+                attention_mask=attention_mask
            )
-
             # Evaluate the posterior of the candidates to select the accepted candidate prefix
             best_candidate, accept_length = evaluate_posterior(
                 logits, candidates, temperature, posterior_threshold, posterior_alpha, top_p=top_p, sampling=sampling, fast=fast
             )
-
             # Update the input_ids and logits
-            input_ids, logits, medusa_logits, new_token = update_inference_inputs(
+            input_ids, logits, medusa_logits, new_token, valid_length, attention_mask = update_inference_inputs(
                 input_ids,
                 candidates,
                 best_candidate,
@@ -354,18 +355,29 @@ def medusa_generate(
                new_token,
                past_key_values_data,
                current_length_data,
+                attention_mask=attention_mask,
+                padding_idx=self.tokenizer.pad_token_id
            )
 
-            yield {
-                "text": self.tokenizer.decode(
-                    input_ids[0, input_len:],
+            decoded_texts = []
+            eos_encountered = [False] * batch_size
+            for i in range(batch_size):
+                # Check whether this batch element already contains an EOS token
+                if self.tokenizer.eos_token_id in input_ids[i, input_len:]:
+                    eos_encountered[i] = True
+                else:
+                    ends[i] = len(input_ids[i])
+                decoded_text = self.tokenizer.decode(
+                    input_ids[i, input_len:ends[i]],
                    skip_special_tokens=True,
                    spaces_between_special_tokens=False,
                    clean_up_tokenization_spaces=True,
                )
-            }
+                decoded_texts.append(decoded_text)
+            yield {"text": decoded_texts}
 
-            if self.tokenizer.eos_token_id in input_ids[0, input_len:]:
+            # Stop once every batch element has produced an EOS token
+            if all(eos_encountered):
                 break
 
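
The new stopping logic keeps two per-element records: ends[i], the position up to which element i is decoded (it stops advancing once that element has emitted EOS), and eos_encountered[i], which gates the early break once every element has finished. The stand-alone sketch below mirrors that bookkeeping on made-up token lists; eos_id, the prompt length, and the sequences are invented for illustration.

eos_id = 2
prompt_len = 3
batch_size = 2

# Toy per-element sequences after two decoding steps
steps = [
    [[1, 4, 9, 5, 7],       [1, 3, 8, 2]],     # step 1: element 1 emits EOS
    [[1, 4, 9, 5, 7, 6, 2], [1, 3, 8, 2, 2]],  # step 2: element 0 emits EOS
]

ends = [prompt_len] * batch_size
for seqs in steps:
    eos_encountered = [False] * batch_size
    for i, seq in enumerate(seqs):
        if eos_id in seq[prompt_len:]:
            eos_encountered[i] = True   # freeze this element's decode window
        else:
            ends[i] = len(seq)          # still generating: decode everything so far
    print(ends, eos_encountered)        # step 1: [5, 3] [False, True]; step 2: [5, 3] [True, True]
    if all(eos_encountered):
        break                           # every element has finished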

medusa/model/modeling_llama_kv.py

Lines changed: 3 additions & 4 deletions
@@ -32,7 +32,7 @@
 if is_flash_attn_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-
+import pdb
 
 logger = logging.get_logger(__name__)
 
@@ -315,7 +315,6 @@ def forward(
         padding_mask: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
-
         if self.config.pretraining_tp > 1:
             key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
             query_slices = self.q_proj.weight.split(
@@ -815,6 +814,8 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
         # [MODIFIED] add medusa mask
         if hasattr(self, "medusa_mask") and self.medusa_mask is not None:
             medusa_mask = self.medusa_mask
+            bs = combined_attention_mask.shape[0]
+            medusa_mask = medusa_mask.repeat(bs, 1, 1, 1)
             medusa_len = medusa_mask.size(-1)
             combined_attention_mask[:, :, -medusa_len:, -medusa_len:][
                 medusa_mask == 0
@@ -886,7 +887,6 @@ def forward(
             padding_mask = attention_mask
         else:
             padding_mask = None
-
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
         )
@@ -1038,7 +1038,6 @@ def forward(
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
             input_ids=input_ids,
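
The functional change in this file is that the cached tree ("medusa") attention mask, previously shaped for a single sequence, is tiled across the batch dimension before being merged into the combined attention mask. The shape-level sketch below illustrates the same tiling and masked assignment; the mask sizes and the fill value (torch.finfo(...).min) are illustrative assumptions, not values taken from the file.

import torch

batch_size, medusa_len, seq_len = 2, 4, 10

# Tree mask for one sequence: nonzero where a tree position may attend
medusa_mask = torch.tril(torch.ones(medusa_len, medusa_len))[None, None]  # (1, 1, L, L)

# Combined causal/padding mask for the whole batch
combined = torch.zeros(batch_size, 1, seq_len, seq_len)

# Tile the per-sequence tree mask across the batch, as the commit does
medusa_mask = medusa_mask.repeat(batch_size, 1, 1, 1)                     # (B, 1, L, L)
combined[:, :, -medusa_len:, -medusa_len:][medusa_mask == 0] = torch.finfo(combined.dtype).min

print(medusa_mask.shape, combined.shape)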
