Commit a88534a: added kv_cache

manmay-nakhashi committed Jul 15, 2023 (1 parent: 82724cc)
Showing 3 changed files with 201 additions and 66 deletions.

tortoise/api.py: 3 additions & 3 deletions

@@ -27,7 +27,7 @@
 pbar = None

 DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser('~'), '.cache', 'tortoise', 'models')
-MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', DEFAULT_MODELS_DIR)
+MODELS_DIR = MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', DEFAULT_MODELS_DIR)
 MODELS = {
     'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth',
     'classifier.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth',
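
Note: MODELS_DIR in the hunk above is resolved once, at import time, from the TORTOISE_MODELS_DIR environment variable. A minimal sketch of overriding it (the path is a placeholder, not from this commit):

    import os

    os.environ['TORTOISE_MODELS_DIR'] = '/data/tortoise_models'  # placeholder path

    from tortoise.api import TextToSpeech  # reads TORTOISE_MODELS_DIR at import time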
@@ -198,7 +198,7 @@ class TextToSpeech:
     Main entry point into Tortoise.
     """

-    def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, use_deepspeed=False, device=None):
+    def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, kv_cache=False, use_deepspeed=False, device=None):
         """
         Constructor
         :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@@ -229,7 +229,7 @@ def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable
                                           heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
                                           train_solo_embeddings=False).cpu().eval()
         self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)))
-        self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed)
+        self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=kv_cache)

         self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
                                       in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
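
Note: with kv_cache exposed on the constructor, enabling it from user code is a one-liner. A minimal sketch, assuming the existing load_voice and tts_with_preset helpers; the voice name and preset are placeholders:

    # kv_cache=True is the new flag from this commit; it defaults to False,
    # so existing callers keep the old (uncached) behavior.
    from tortoise.api import TextToSpeech
    from tortoise.utils.audio import load_voice

    tts = TextToSpeech(kv_cache=True)
    voice_samples, conditioning_latents = load_voice('tom')  # placeholder voice
    audio = tts.tts_with_preset("KV caching speeds up autoregressive decoding.",
                                voice_samples=voice_samples,
                                conditioning_latents=conditioning_latents,
                                preset='fast')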

tortoise/models/autoregressive.py: 26 additions & 50 deletions

@@ -33,50 +33,23 @@ def forward(self, x):


 class GPT2InferenceModel(GPT2PreTrainedModel):
-    def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear):
+    def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=False):
         super().__init__(config)
         self.transformer = gpt
         self.text_pos_embedding = text_pos_emb
         self.embeddings = embeddings
         self.lm_head = nn.Sequential(norm, linear)
-
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
-        self.cached_mel_emb = None
-
-    def parallelize(self, device_map=None):
-        self.device_map = (
-            get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
-            if device_map is None
-            else device_map
-        )
-        assert_device_map(self.device_map, len(self.transformer.h))
-        self.transformer.parallelize(self.device_map)
-        self.lm_head = self.lm_head.to(self.transformer.first_device)
-        self.model_parallel = True
-
-    def deparallelize(self):
-        self.transformer.deparallelize()
-        self.transformer = self.transformer.to("cpu")
-        self.lm_head = self.lm_head.to("cpu")
-        self.model_parallel = False
-        torch.cuda.empty_cache()
-
-    def get_output_embeddings(self):
-        return self.lm_head
-
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
+        self.kv_cache = kv_cache

     def store_mel_emb(self, mel_emb):
         self.cached_mel_emb = mel_emb

-    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
-
-        token_type_ids = kwargs.get("token_type_ids", None)
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+        token_type_ids = kwargs.get("token_type_ids", None)  # usually None
+        if not self.kv_cache:
+            past_key_values = None
         # only last token for inputs_ids if past is defined in kwargs
-        if past:
+        if past_key_values:
             input_ids = input_ids[:, -1].unsqueeze(-1)
             if token_type_ids is not None:
                 token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
@@ -88,13 +61,13 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
-            if past:
+            if past_key_values:
                 position_ids = position_ids[:, -1].unsqueeze(-1)
         else:
             position_ids = None
         return {
             "input_ids": input_ids,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": kwargs.get("use_cache"),
             "position_ids": position_ids,
             "attention_mask": attention_mask,
@@ -121,7 +94,9 @@ def forward(
         assert self.cached_mel_emb is not None
         assert inputs_embeds is None  # Not supported by this inference model.
         assert labels is None  # Training not supported by this inference model.
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )

         # Create embedding
         mel_len = self.cached_mel_emb.shape[1]
@@ -130,14 +105,17 @@
             text_emb = self.embeddings(text_inputs)
             text_emb = text_emb + self.text_pos_embedding(text_emb)
             if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
-                mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0)
-            else:
+                mel_emb = self.cached_mel_emb.repeat_interleave(
+                    text_emb.shape[0] // self.cached_mel_emb.shape[0], 0
+                )
+            else:  # this outcome only occurs once per loop in most cases
                 mel_emb = self.cached_mel_emb
             emb = torch.cat([mel_emb, text_emb], dim=1)
         else:
             emb = self.embeddings(input_ids)
-            emb = emb + self.text_pos_embedding.get_fixed_embedding(attention_mask.shape[1]-mel_len, attention_mask.device)
-
+            emb = emb + self.text_pos_embedding.get_fixed_embedding(
+                attention_mask.shape[1] - mel_len, attention_mask.device
+            )
         transformer_outputs = self.transformer(
             inputs_embeds=emb,
             past_key_values=past_key_values,
@@ -153,12 +131,6 @@ def forward(
             return_dict=return_dict,
         )
         hidden_states = transformer_outputs[0]
-
-        # Set device for model parallelism
-        if self.model_parallel:
-            torch.cuda.set_device(self.transformer.first_device)
-            hidden_states = hidden_states.to(self.lm_head.weight.device)
-
         lm_logits = self.lm_head(hidden_states)

         if not return_dict:
@@ -181,7 +153,10 @@ def _reorder_cache(past, beam_idx):
         called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
         """
         return tuple(
-            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+            tuple(
+                past_state.index_select(0, beam_idx.to(past_state.device))
+                for past_state in layer_past
+            )
             for layer_past in past
         )
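
Note: the reflowed _reorder_cache above is what keeps the cache valid under beam search: when generation reorders beams, each layer's cached key/value tensors must be permuted the same way along the batch-beam axis. A self-contained sketch with made-up shapes:

    import torch

    beams, heads, seq, dim = 4, 2, 5, 8
    past = (  # one layer for brevity; real caches hold a (key, value) pair per layer
        (torch.randn(beams, heads, seq, dim), torch.randn(beams, heads, seq, dim)),
    )
    beam_idx = torch.tensor([2, 0, 0, 3])  # new beam i continues old beam beam_idx[i]

    reordered = tuple(
        tuple(p.index_select(0, beam_idx.to(p.device)) for p in layer_past)
        for layer_past in past
    )
    assert torch.equal(reordered[0][0][0], past[0][0][2])  # new beam 0 == old beam 2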

@@ -340,7 +315,7 @@ def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_me
             embeddings.append(self.mel_embedding)
         for module in embeddings:
             module.weight.data.normal_(mean=0.0, std=.02)
-    def post_init_gpt2_config(self, use_deepspeed=False):
+    def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False):
         seq_length = self.max_mel_tokens + self.max_text_tokens + 2
         gpt_config = GPT2Config(
             vocab_size=self.max_mel_tokens,
@@ -358,7 +333,8 @@ def post_init_gpt2_config(self, use_deepspeed=False):
             self.mel_pos_embedding,
             self.mel_embedding,
             self.final_norm,
-            self.mel_head
+            self.mel_head,
+            kv_cache=kv_cache,
         )
         if use_deepspeed:
             import deepspeed
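
Note: together these changes enable standard incremental decoding: the first forward pass builds per-layer key/value caches, and each later step feeds only the newest token, which is why prepare_inputs_for_generation trims input_ids to its last column once past_key_values is set. A minimal sketch of that pattern on a tiny randomly initialized Hugging Face GPT-2 (a stand-in, not the tortoise inference model):

    import torch
    from transformers import GPT2Config, GPT2LMHeadModel

    model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=2, n_embd=64, vocab_size=100)).eval()

    input_ids = torch.randint(0, 100, (1, 8))  # an 8-token "prompt"
    with torch.no_grad():
        out = model(input_ids, use_cache=True)  # full pass builds the cache
        past = out.past_key_values
        next_token = out.logits[:, -1].argmax(-1, keepdim=True)
        # Later steps feed only the newest token plus the cache, mirroring the
        # last-column slice in prepare_inputs_for_generation above.
        out = model(next_token, past_key_values=past, use_cache=True)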
