diff --git a/bark/generation.py b/bark/generation.py
index 54f98709..daf346da 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -445,16 +445,16 @@ def generate_text_semantic(
     with _inference_mode():
         x = x.to(device)
         n_tot_steps = 768
+        # preallocate tensor
+        x_initial = x.shape[1]
+        x = torch.hstack([x, torch.empty([1, n_tot_steps], dtype=torch.int32, device=device)])
         # custom tqdm updates since we don't know when eos will occur
         pbar = tqdm.tqdm(disable=silent, total=n_tot_steps)
         pbar_state = 0
         tot_generated_duration_s = 0
         kv_cache = None
         for n in range(n_tot_steps):
-            if use_kv_caching and kv_cache is not None:
-                x_input = x[:, [-1]]
-            else:
-                x_input = x
+            x_input = x[:, [x_initial + n - 1]] if use_kv_caching and kv_cache is not None else x[:, :x_initial + n]
             logits, kv_cache = model(
                 x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache
             )
@@ -485,10 +485,11 @@
                 item_next == SEMANTIC_VOCAB_SIZE
                 or (min_eos_p is not None and probs[-1] >= min_eos_p)
             ):
+                n -= 1  # backtrack 1
                 # eos found, so break
                 pbar.update(n - pbar_state)
                 break
-            x = torch.cat((x, item_next[None]), dim=1)
+            x[0][x_initial + n] = item_next
             tot_generated_duration_s += 1 / SEMANTIC_RATE_HZ
             if max_gen_duration_s is not None and tot_generated_duration_s > max_gen_duration_s:
                 pbar.update(n - pbar_state)
@@ -496,7 +497,6 @@
             if n == n_tot_steps - 1:
                 pbar.update(n - pbar_state)
                 break
-            del logits, relevant_logits, probs, item_next

             if n > pbar_state:
                 if n > pbar.total:
@@ -506,7 +506,7 @@
             pbar.total = n
         pbar.refresh()
     pbar.close()
-    out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :]
+    out = x.detach().cpu().numpy().squeeze()[x_initial : x_initial + n + 1]
     if OFFLOAD_CPU:
         model.to("cpu")
     assert all(0 <= out) and all(out < SEMANTIC_VOCAB_SIZE)
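
For reference, a minimal, self-contained sketch of the pattern the patch adopts: preallocate the output buffer once, write each sampled token in place, and slice off only the filled portion at the end, instead of growing x with torch.cat on every step. fake_sample, the prompt length of 513, and the EOS-at-step-100 condition are invented for illustration; only the preallocation and slicing logic mirrors the diff.

import torch

def fake_sample(step: int) -> torch.Tensor:
    # Hypothetical stand-in for sampling one semantic token from the model.
    return torch.tensor([step % 10_000], dtype=torch.int32)

prompt = torch.zeros([1, 513], dtype=torch.int32)  # stands in for the 256 + 256 + 1 context tokens
n_tot_steps = 768

x_initial = prompt.shape[1]
# One allocation up front; the tail has room for up to n_tot_steps generated tokens.
x = torch.hstack([prompt, torch.empty([1, n_tot_steps], dtype=torch.int32)])

for n in range(n_tot_steps):
    item_next = fake_sample(n)
    if n == 100:  # pretend EOS fired on this step
        n -= 1    # backtrack so the slice below excludes the unwritten slot
        break
    x[0, x_initial + n] = item_next  # write in place instead of torch.cat

# Only the filled portion after the prompt is the output.
out = x[0, x_initial : x_initial + n + 1]
assert out.shape[0] == 100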