fix
xtinkt committed Jul 23, 2024
1 parent 8beaf65 commit 0200664
Showing 1 changed file with 31 additions and 28 deletions.
src/petals/models/llama/speculative_model.py (59 changes: 31 additions & 28 deletions)
@@ -25,54 +25,51 @@ def _sample(
         synced_gpus: bool,
         streamer: Optional["BaseStreamer"],
         logits_warper: Optional[LogitsProcessorList],
-        speculative_batch_size: int = 10,
+        speculative_inference_iteration_size: int = 10,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
-        print(model_kwargs)
-        assert not generation_config.do_sample, "sample is not working for speculative generation now"
-        assert not synced_gpus, "synced_gpus is not working for speculative generation now"
-        assert not generation_config.return_dict_in_generate, "return_dict_in_generate is not working for speculative generation now"
-
-        pad_token_id = generation_config.pad_token_id
         has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+
+        assert not generation_config.do_sample, "sample is not working for speculative generation now"
+
         # keep track of which sequences are already finished
         batch_size = input_ids.shape[0]
-        this_peer_finished = False
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        finished = False
         firsts = True

-        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
-            speculative_batch_size = min(speculative_batch_size, self.active_session._max_length - input_ids.shape[1])
+        while not finished:
+            speculative_inference_iteration_size = min(speculative_inference_iteration_size, self.active_session._max_length - input_ids.shape[1])
             with torch.no_grad():
                 speculative_outputs = self.small_model.generate(
                     input_ids,
-                    max_new_tokens=speculative_batch_size,
+                    max_new_tokens=speculative_inference_iteration_size,
                     do_sample=False,
                     use_cache=False
                 )
-            speculative_tokens = speculative_outputs[:, -speculative_batch_size:]
+            speculative_tokens = speculative_outputs[:, -speculative_inference_iteration_size:]

             full_sequence = torch.cat([input_ids, speculative_tokens], dim=-1)
-            assert input_ids.shape[1] + speculative_batch_size == full_sequence.shape[1]
-
+            assert input_ids.shape[1] + speculative_inference_iteration_size == full_sequence.shape[1]
+
+            input_for_validation = full_sequence
+            if not firsts:
+                self.active_session.position = input_ids.shape[1] - 1
+                input_for_validation = input_for_validation[:, -speculative_inference_iteration_size - 1:]
+            else:
+                firsts = False
+                input_for_validation = input_for_validation[:, :-1]
             with torch.no_grad():
-                real_input = full_sequence
-                if not firsts:
-                    self.active_session.position = input_ids.shape[1] - 1
-                    real_input = real_input[:, -speculative_batch_size - 1:]
-                else:
-                    firsts = False
-                    real_input = real_input[:, :-1]
-
-                precise_model_outputs = self(real_input, return_dict=True)
-                full_token_logits = precise_model_outputs.logits[:, -speculative_batch_size:, :].clone()
+                precise_model_outputs = self(input_for_validation)
+                full_token_logits = precise_model_outputs.logits[:, -speculative_inference_iteration_size:, :].clone()

             all_valid_tokens = []

             first_token = None
-            for i in range(speculative_batch_size):
+            for i in range(speculative_inference_iteration_size):
                 token_logits = full_token_logits[:, i, :]
-                valid_token = torch.argmax(token_logits, dim=-1)
+                token_scores = logits_processor(input_for_validation[:, :-speculative_inference_iteration_size + 1 + i], token_logits)
+                valid_token = torch.argmax(token_scores, dim=-1)

                 if first_token is None:
                     first_token = valid_token
@@ -88,14 +85,20 @@ def _sample(

             # finished sentences should have their next token be a padding token
             if has_eos_stopping_criteria:
-                all_valid_tokens = all_valid_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+                all_valid_tokens = all_valid_tokens * unfinished_sequences + generation_config.pad_token_id * (1 - unfinished_sequences)

             # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, all_valid_tokens], dim=-1)

+            if streamer is not None:
+                streamer.put(all_valid_tokens.cpu())
+
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, None)
-            this_peer_finished = unfinished_sequences.max() == 0
+            finished = unfinished_sequences.max() == 0
+
+            del precise_model_outputs
+
         if streamer is not None:
             streamer.end()

         return input_ids
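
For context, the loop in this file implements greedy speculative decoding: self.small_model drafts speculative_inference_iteration_size tokens, the full model then scores the whole draft in one forward pass, and logits_processor plus argmax pick the "valid" token at each drafted position. The acceptance step that compares the valid tokens with the drafted ones sits in the collapsed part of the diff, so the snippet below is only a minimal, self-contained sketch of how a draft-then-verify iteration typically works under greedy decoding. The names draft_model, target_model and speculative_decode_step, and the accept-longest-agreeing-prefix rule, are illustrative assumptions for this sketch, not the Petals API.

import torch


def speculative_decode_step(draft_model, target_model, input_ids, num_draft_tokens=10):
    """One hypothetical draft-then-verify iteration for greedy decoding.

    draft_model / target_model are assumed to be callables mapping token ids of
    shape (1, seq_len) to logits of shape (1, seq_len, vocab_size).
    """
    with torch.no_grad():
        # 1) Draft: the small model proposes num_draft_tokens greedy tokens,
        #    recomputing the full prefix each step (mirroring use_cache=False above).
        draft_ids = input_ids
        for _ in range(num_draft_tokens):
            next_id = draft_model(draft_ids)[:, -1, :].argmax(dim=-1, keepdim=True)
            draft_ids = torch.cat([draft_ids, next_id], dim=-1)
        draft_tokens = draft_ids[:, input_ids.shape[1]:]

        # 2) Verify: a single forward pass of the large model over prompt + draft
        #    (minus the last token) gives its greedy choice at every drafted position.
        logits = target_model(draft_ids[:, :-1])
        target_tokens = logits[:, input_ids.shape[1] - 1:, :].argmax(dim=-1)

        # 3) Accept the longest prefix on which draft and target agree, then append
        #    the target's token at the first disagreement, so every iteration makes
        #    at least one token of progress. This acceptance rule is an assumption
        #    of the sketch, not necessarily the one used in speculative_model.py.
        agree = (draft_tokens == target_tokens).long()
        num_accepted = int(agree.cumprod(dim=-1).sum())
        accepted = target_tokens[:, : num_accepted + 1]
        return torch.cat([input_ids, accepted], dim=-1)


# Shape-level usage with stand-in "models" that return random logits (batch size 1):
vocab_size = 100
draft_model = lambda ids: torch.randn(1, ids.shape[1], vocab_size)
target_model = lambda ids: torch.randn(1, ids.shape[1], vocab_size)
prompt = torch.randint(0, vocab_size, (1, 5))
print(speculative_decode_step(draft_model, target_model, prompt, num_draft_tokens=4).shape)

Unlike this sketch, the method in the diff also masks out sequences that have already hit a stopping criterion, padding their next token with generation_config.pad_token_id before it is appended.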
