diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..d83b8aa --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,33 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: Tests + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.7, 3.8] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install pytest + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: Test with pytest + run: | + pytest -s test.py diff --git a/configs/dalle_coco.json b/configs/dalle_coco.json index 7d2c768..0454b98 100644 --- a/configs/dalle_coco.json +++ b/configs/dalle_coco.json @@ -7,22 +7,22 @@ }, "train_batch_size": 128, "eval_batch_size": 128, - "predict_batch_size": 128, + "predict_batch_size": 16, "steps_per_checkpoint": 5000, "iterations": 1000, "train_steps": 100000, "predict_steps": 0, "eval_steps": 0, "n_channels": 3, - "bf_16": false, + "bf_16": true, "recompute_grad": true, "lr": 0.0001, - "model_path": "gs://neo-models/dalle_coco/", + "model_path": "gs://neo-models/dalle_coco_sample/", "mesh_shape": "data:16,model:2", - "layout": "batch_dim:data", + "layout": "batch_dim:data,embed_dim:model", "n_embd": 1024, "text_vocab_size": 50258, - "image_vocab_size": 512, + "image_vocab_size": 2048, "text_seq_len": 256, "n_layers": 12, "n_heads": 8, diff --git a/src/dalle_mtf/__init__.py b/src/dalle_mtf/__init__.py index a53a710..d9d604b 100644 --- a/src/dalle_mtf/__init__.py +++ b/src/dalle_mtf/__init__.py @@ -1 +1,2 @@ -from .models import DALLE, DiscreteVAE \ No newline at end of file +from .models import DALLE, DiscreteVAE +from .sample import sample_autoregressive \ No newline at end of file diff --git a/src/dalle_mtf/models.py b/src/dalle_mtf/models.py index 7bc7474..be258d7 100644 --- a/src/dalle_mtf/models.py +++ b/src/dalle_mtf/models.py @@ -5,7 +5,7 @@ from collections import defaultdict import math -from .ops import pad, exists, get_variable_dtype +from .ops import pad, exists, get_variable_dtype, expand_tile, mask_to_bias from .layers import gumbel_softmax, mse_loss, norm @@ -140,29 +140,34 @@ def forward(self, features, return_recon_loss=False, return_logits=False, hard_g class DALLE: - def __init__(self, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq_len=256, image_seq_len=1024, + def __init__(self, mesh, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq_len=256, image_seq_len=1024, n_layers=6, n_heads=8, batch_size=32, bf_16=True, attn_mask=None, mode="train", - is_incremental_inference=False, context=None, loss_fn=None, params=None, eos_token_id=None, - activation_fn=None): + is_incremental_inference=False, context=None, loss_fn=None, params=None, padding_id=None, + activation_fn=None, text_loss_weight=0.15, unique_pad_tokens = False): + self.mesh = mesh self.n_embd = n_embd - self.text_vocab_size = text_vocab_size + self.unique_pad_tokens = unique_pad_tokens + self.text_vocab_size = text_vocab_size + (0 if not unique_pad_tokens else text_seq_len) self.image_vocab_size = 
image_vocab_size self.text_seq_len = text_seq_len self.image_seq_len = image_seq_len - self.total_seq_dim = text_seq_len + image_seq_len + self.total_seq_len = text_seq_len + image_seq_len self.n_layers = n_layers self.n_heads = n_heads self.attn_mask = attn_mask - self.total_tokens = text_vocab_size + image_vocab_size + 1 # extra for EOS - self.eos_token_id = self.total_tokens - 1 if eos_token_id is None else eos_token_id + self.logits_mask = None + + self.text_loss_weight = text_loss_weight + self.padding_id = 0 if padding_id is None else padding_id self.dimensions = {"embed_dim": mtf.Dimension("embed_dim", n_embd), "text_vocab_dim": mtf.Dimension("vocab_dim", text_vocab_size), "image_vocab_dim": mtf.Dimension("vocab_dim", image_vocab_size), - "final_vocab_dim": mtf.Dimension("vocab_dim", self.total_tokens), - "total_seq_dim": mtf.Dimension("total_seq_dim", self.total_seq_dim), - "embed_seq_dim": mtf.Dimension("embed_seq_dim", self.total_seq_dim), - "memory_len_dim": mtf.Dimension("memory_len_dim", self.total_seq_dim), + "text_sequence_dim": mtf.Dimension("sequence_dim", text_seq_len), + "image_sequence_dim": mtf.Dimension("sequence_dim", image_seq_len), + "total_seq_dim": mtf.Dimension("sequence_dim", self.total_seq_len), + "embed_seq_dim": mtf.Dimension("embed_seq_dim", self.total_seq_len), + "memory_len_dim": mtf.Dimension("memory_len_dim", self.total_seq_len), "heads_dim": mtf.Dimension("heads", n_heads), "kv_dim": mtf.Dimension("kv_dim", n_embd // n_heads), "batch_dim": mtf.Dimension("batch_dim", batch_size)} @@ -179,13 +184,17 @@ def __init__(self, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq self.activation_fn = activation_fn if self.is_incremental_inference: assert self.context is not None, "must have context in incremental inference" + assert self.context['mode'] == 'incremental' if params is None: # extra params params = {} self.params = defaultdict(lambda: None, params) def embedding(self, x, name): embd_dim = self.dimensions["embed_dim"] - vocab_dim = self.dimensions["final_vocab_dim"] + if "text" in name: + vocab_dim = self.dimensions["text_vocab_dim"] + else: + vocab_dim = self.dimensions["image_vocab_dim"] with tf.variable_scope(name): wte = mtf.get_variable(x.mesh, "wte", mtf.Shape([vocab_dim, embd_dim]), @@ -201,6 +210,10 @@ def embedding(self, x, name): return x def positional_embedding(self, x, name): + if "text" in name: + sequence_dim = self.dimensions["text_sequence_dim"] + else: + sequence_dim = self.dimensions["image_sequence_dim"] with tf.variable_scope(name): # Positional embedding wpe = mtf.get_variable(x.mesh, "wpe", @@ -209,7 +222,7 @@ def positional_embedding(self, x, name): master_dtype=self.variable_dtype.master_dtype, slice_dtype=self.variable_dtype.slice_dtype, activation_dtype=self.variable_dtype.activation_dtype) - position_indices = mtf.range(x.mesh, self.dimensions["total_seq_dim"], tf.int64) if not \ + position_indices = mtf.range(x.mesh, sequence_dim, tf.int64) if not \ self.is_incremental_inference else (self.context.position - 1) pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0]) embed_dropout = self.params.get("embed_dropout", 0) @@ -227,8 +240,13 @@ def get_attn_mask(self, mesh, nd, ns): return self.attn_mask def attention(self, x, n_state, mask, attention_type="global", name="attn"): - # x :: [batch, seq, n_embd] - batch_dim, seq_dim, embd_dim = x_shape = x.shape + if not self.is_incremental_inference: + # x :: [batch, seq, n_embd] + batch_dim, seq_dim, embd_dim = x_shape = x.shape + else: + batch_dim, embd_dim = 
x_shape = x.shape + seq_dim = self.dimensions['total_seq_dim'] + assert n_state.size % self.n_heads == 0, "n_state must be divisible by n_heads" with tf.variable_scope(name): # Compute attention inputs @@ -254,25 +272,7 @@ def attention(self, x, n_state, mask, attention_type="global", name="attn"): self.context.record_new_states([k, v]) with tf.variable_scope("attention"): - if attention_type == "local": - # `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights. - radius = self.params.get("local_attention_radius", 256) - if self.is_incremental_inference: - q *= one_hot - a = mtf_transformer.attention.local_attention_1d( - q, k, v, - length_dim=k.shape[1], - key_dim=self.dimensions["kv_dim"], - value_dim=self.dimensions["kv_dim"], - radius=radius, - length_dim_num_splits=1, - fully_autoregressive=True, - attention_kwargs={}, - ) - if self.is_incremental_inference: - a = mtf.gather(a, self.context.position - 1, seq_dim) - - elif attention_type == "global": + if attention_type == "global": if exists(mask): if not self.is_incremental_inference: broadcasted_mask = mtf.broadcast(mask, @@ -347,7 +347,8 @@ def transformer(self, x, mask): def _loss(self, logits, labels): with tf.variable_scope("loss_final"): - loss_batch = self.loss_fn(logits=logits, targets=labels, + loss_batch = self.loss_fn(logits =mtf.slice(logits, begin=self.text_seq_len, size=self.image_seq_len, slice_dim_name="sequence_dim"), + targets=mtf.slice(labels, begin=self.text_seq_len, size=self.image_seq_len, slice_dim_name="sequence_dim"), vocab_dim=logits.shape[-1], z_loss=0.0) with tf.variable_scope("reduce_mean_final"): @@ -370,6 +371,59 @@ def linear(self, x, new_dim, w_init_stdev=0.02, params=None, scale=False, name=" kernel_initializer=tf.random_normal_initializer(stddev=w_init_stdev), variable_dtype=self.variable_dtype) + def axial_positional_embedding(self, mesh, name): + with tf.variable_scope(name): + axial_dim_side = int(sqrt(self.image_seq_len)) + + embd_dim = self.dimensions["embed_dim"] + axial_dim = mtf.Dimension("axial_dim", self.image_seq_len) + + dim_axials = [mtf.Dimension(f"axial_dim_{i}", t) for i, t in enumerate((axial_dim_side, axial_dim_side))] + + axial_wpe_1 = mtf.get_variable(mesh, "axial_wpe_1", mtf.Shape([dim_axials[0], embd_dim]), + initializer=tf.random_normal_initializer(stddev=0.01), + master_dtype=self.variable_dtype.master_dtype, + slice_dtype=self.variable_dtype.slice_dtype, + activation_dtype=self.variable_dtype.activation_dtype) + + axial_wpe_2 = mtf.get_variable(mesh, "axial_wpe_2", mtf.Shape([dim_axials[1], embd_dim]), + initializer=tf.random_normal_initializer(stddev=0.01), + master_dtype=self.variable_dtype.master_dtype, + slice_dtype=self.variable_dtype.slice_dtype, + activation_dtype=self.variable_dtype.activation_dtype) + + axial_wpe_1, axial_wpe_2 = map(lambda t: mtf.broadcast(t, [dim_axials[0], dim_axials[1], embd_dim]), + (axial_wpe_1, axial_wpe_2)) + wpe = (axial_wpe_1 + axial_wpe_2) / 2 + + wpe = mtf.reshape(wpe, [axial_dim, embd_dim]) + wpe = pad(wpe, [self.text_seq_len + 1, 0], axial_dim.name) + wpe = mtf.slice(wpe, 0, self.total_seq_len, axial_dim.name) + wpe = mtf.replace_dimensions(wpe, wpe.shape[0], self.dimensions["embed_seq_dim"]) + return wpe + + + def absolute_positional_embedding(self, mesh, name): + with tf.variable_scope(name): + # Positional embedding + wpe = mtf.get_variable(mesh, "wpe", + mtf.Shape([self.dimensions["embed_seq_dim"], self.dimensions["embed_dim"]]), + initializer=tf.random_normal_initializer(stddev=0.01), + 
master_dtype=self.variable_dtype.master_dtype, + slice_dtype=self.variable_dtype.slice_dtype, + activation_dtype=self.variable_dtype.activation_dtype) + return wpe + + def apply_positional_embedding(self, x, wpe): + position_indices = mtf.range(x.mesh, self.dimensions["total_seq_dim"], tf.int64) if not \ + self.is_incremental_inference else (self.context.position - 1) + pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0]) + embed_dropout = self.params.get("embed_dropout", 0) + if embed_dropout > 0 and self.mode == "train": + pos_emb = mtf.dropout(pos_emb, rate=embed_dropout, name="wte_dropout") + x += pos_emb + return x + def layer_norm(self, x, name="layer_norm", axis=None, epsilon=1e-5): """Normalize to mean = 0, std = 1, then do a diagonal affine transform.""" if axis is None: @@ -392,25 +446,109 @@ def to_logits(self, x): with tf.variable_scope("to_logits"): logits = self.linear(self.layer_norm(x), self.dimensions["final_vocab_dim"], name="linear_out") # Go to full precision for the logits + if self.is_incremental_inference: + # add seq dim in inference mode + logits = expand_tile(logits, mtf.Dimension("sequence_dim", 1), axis=1) return mtf.cast(logits, tf.float32) + def to_image_logits(self, x): + with tf.variable_scope("to_image_logits"): + if not self.is_incremental_inference: + x = mtf.slice(x, begin = self.text_seq_len, size = self.image_seq_len, slice_dim_name = x.shape[1].name) + + image_logits = self.linear(x, self.dimensions["image_vocab_dim"], name="linear_image_out") + + # Go to full precision for the logits + image_logits = mtf.cast(image_logits, tf.float32) + return image_logits + + def to_text_logits(self, x): + with tf.variable_scope("to_text_logits"): + text_tokens = mtf.slice(x, begin = 0, size = self.text_seq_len, slice_dim_name = x.shape[1].name) + text_logits = self.linear(text_tokens, self.dimensions["text_vocab_dim"], name="linear_text_out") + + # Go to full precision for the logits + text_logits = mtf.cast(text_logits, tf.float32) + return text_logits + + def _loss(self, text_logits, image_logits, text_labels, image_labels): + with tf.variable_scope("loss_final"): + text_loss_batch = self.loss_fn(logits=text_logits, targets=text_labels, + vocab_dim=text_logits.shape[-1], z_loss=0.0) + + image_loss_batch = self.loss_fn(logits=image_logits, targets=image_labels, + vocab_dim=image_logits.shape[-1], z_loss=0.0) + + loss_batch = text_loss_batch * self.text_loss_weight + image_loss_batch + + with tf.variable_scope("reduce_mean_final"): + loss = mtf.reduce_mean(loss_batch) + + loss /= self.params.get("num_microbatches", 1) + # Convert to train dtype + loss = mtf.cast(loss, self.variable_dtype.slice_dtype) + return loss, loss_batch # loss batch must be returned for metric fns + def forward(self, features, return_loss=True, return_logits=False): - inputs = features["tokens"] - tokens = self.positional_embedding(self.embedding(inputs, "embedding"), "positional_embedding") + if features.get('text_inputs') is not None: + text = features["text_inputs"] + + if self.unique_pad_tokens: + input_range = mtf.range(text.mesh, text.shape[1], tf.int32) + pad_mask = mtf.equal(text, 0) + pad_token_ids = input_range + self.text_seq_len # shift to the range of pad token ids, which come after text token ids, and before image token ids + text = mtf.where(pad_mask, pad_token_ids, text) - mask = self.get_attn_mask(tokens.mesh, tokens.shape[1], self.dimensions["memory_len_dim"]) + text_with_bos = pad(text, [1, 0], dim_name = text.shape[1].name, pad_value = self.padding_id) + text_emb = 
self.embedding(text_with_bos, "text_embd") + else: + assert self.is_incremental_inference + + image = features.get("image_inputs", None) + + if not self.is_incremental_inference: + image_input = mtf.slice(image, 0, self.image_seq_len - 1, image.shape[1].name) + image_emb = self.embedding(image_input, "image_embd") + tokens = mtf.concat([text_emb, image_emb], concat_dim_name="sequence_dim") # [batch, seq, n_embd] + else: + # reshape inputs if in inference mode + image = mtf.gather(image, self.context.position - 1, self.dimensions["image_sequence_dim"]) + image = mtf.reshape(image, [self.dimensions["batch_dim"]]) + tokens = self.embedding(image, "image_embd") + + # positional embedding + + abs_pos_emb = self.absolute_positional_embedding(tokens.mesh, "positional_embedding") + axial_pos_emb = self.axial_positional_embedding(tokens.mesh, "axial_positional_embedding") + + tokens = self.apply_positional_embedding(tokens, abs_pos_emb) + tokens = self.apply_positional_embedding(tokens, axial_pos_emb) + + # attention + + mask = self.get_attn_mask(tokens.mesh, self.dimensions["total_seq_dim"], self.dimensions["memory_len_dim"]) out = self.transformer(tokens, mask=mask) - logits = self.to_logits(out) + + # to logits + + image_logits = self.to_image_logits(out) + if not return_loss: + logits = mtf.cast(image_logits, self.variable_dtype.master_dtype) return logits - labels = pad(inputs, [0, 1], dim_name="total_seq_dim", pad_value=self.eos_token_id) - indices = mtf.range(labels.mesh, mtf.Dimension("range", labels.shape[1].size - 1), tf.int32, name="labels_indices") + 1 - labels = mtf.gather(labels, indices, dim=labels.shape[1]) - labels = mtf.rename_dimension(labels, "range", "total_seq_dim") - loss, loss_batch = self._loss(logits, labels) + assert exists(image), 'when training, image must be supplied' + labels = mtf.concat([text, image], concat_dim_name="sequence_dim") + + text_logits = self.to_text_logits(out) + + text_labels = mtf.slice(labels, begin = 0, size = self.text_seq_len, slice_dim_name = labels.shape[1].name) + image_labels = mtf.slice(labels, begin = self.text_seq_len, size = self.image_seq_len, slice_dim_name = labels.shape[1].name) + + loss, loss_batch = self._loss(text_logits, image_logits, text_labels, image_labels) + if return_logits and return_loss: # Cast back to checkpoint dtype - logits = mtf.cast(logits, self.variable_dtype.master_dtype) + logits = mtf.cast(image_logits, self.variable_dtype.master_dtype) return loss, loss_batch, logits return loss, loss_batch diff --git a/src/dalle_mtf/ops.py b/src/dalle_mtf/ops.py index f679170..121b7f3 100644 --- a/src/dalle_mtf/ops.py +++ b/src/dalle_mtf/ops.py @@ -80,3 +80,21 @@ def get_variable_dtype(bf_16=True): return mtf.VariableDType(master_dtype=tf.bfloat16, slice_dtype=tf.float32, activation_dtype=tf.bfloat16) else: return mtf.VariableDType(master_dtype=tf.float32, slice_dtype=tf.float32, activation_dtype=tf.float32) + +def expand_tile(value, newdim, axis=0): + """Add a new axis of given size.""" + new_shape = value.shape.dims + new_shape.insert(axis, newdim) + return mtf.broadcast(value, new_shape) # shape.dims gets us a list which we need in order to concat + +def mask_to_bias(visible, dtype): + """Convert a boolean visibility mask to an attention bias. + The returned Tensor has large negative values in positions where + visible=False. 
+ Args: + visible: a boolean Tensor + dtype: a dtype + Returns: + a Tensor with the given dtype and the same shape as "visible" + """ + return mtf.cast(mtf.logical_not(visible), dtype) * -1e9 diff --git a/src/dalle_mtf/sample.py b/src/dalle_mtf/sample.py new file mode 100644 index 0000000..39abc12 --- /dev/null +++ b/src/dalle_mtf/sample.py @@ -0,0 +1,167 @@ +import mesh_tensorflow as mtf +import tensorflow.compat.v1 as tf +import mesh_tensorflow.transformer as mtf_transformer + + +def sample_autoregressive(inputs, + model, + max_steps=None, + temperature=0.9, + padding_id = 0, + variable_dtype=mtf.VariableDType(tf.float32), + sampling_keep_top_k=-1, + cached=True + ): + """Sample randomly one token at a time. + + The partial_sequences represent partial sequences to be continued. The + first tokens of each sequence are nonzero representing the given partial + sequences and the last tokens of each sequence are zeros, representing what + needs to be filled in. + + If there are no partial sequences (you want to sample from the beginning), + then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and + + Args: + inputs: an input dictionary containing 'text_inputs' and 'image_inputs', + model: DALL-E model + stop_at_token: an optional integer eos id. Stop when we produce it. + max_steps: an optional integer, the max number of steps to decode. + temperature: an optional floating point value between 0.0 and 1.0 0.0 + means argmax, 1.0 means sample according to predicted distribution. + variable_dtype: a mtf.VariableDType + decoding, one per each input layer + the embedding layer + sampling_keep_top_k: an integer - if not -1, only sample from the top k + logits. + + Returns: + a Tensor with shape [, length_dim] + """ + + # with dalle, inputs will be a text sequence of len 256, then the rest image tokens. 
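Aside: the docstring above describes how `temperature` and `sampling_keep_top_k` steer decoding, and the body of `sample_step` (below) implements that by masking everything outside the top-k logits and then sampling with a temperature. A minimal NumPy sketch of that filtering-and-sampling step, outside Mesh TensorFlow and purely for illustration — the helper name is invented, and the masking only approximates `mtf.nth_largest_element` / `mtf.sample_with_temperature`:

```python
# Illustrative NumPy sketch of the top-k + temperature sampling used by sample_step.
import numpy as np

def sample_token(logits, temperature=0.9, sampling_keep_top_k=-1, rng=np.random.default_rng(0)):
    logits = np.asarray(logits, dtype=np.float64).copy()
    if sampling_keep_top_k == -2:                     # -2 means: keep the top 10% of the vocab
        sampling_keep_top_k = int(logits.shape[-1] * 0.1)
    if sampling_keep_top_k != -1:
        if sampling_keep_top_k <= 0:
            raise ValueError("sampling_keep_top_k must either be -1 or positive.")
        kth_largest = np.sort(logits)[-sampling_keep_top_k]
        logits[logits < kth_largest] = -1e6           # mask everything below the k-th largest logit
    if temperature == 0.0:                            # temperature 0 degenerates to argmax
        return int(np.argmax(logits))
    scaled = logits / temperature
    probs = np.exp(scaled - np.max(scaled))
    probs /= probs.sum()
    return int(rng.choice(len(logits), p=probs))

print(sample_token([2.0, 1.0, 0.5, -3.0], temperature=0.9, sampling_keep_top_k=2))
```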
+ # the parts we want to fill in will be <|pad_token|>, which we should assign in the input + + batch_dims = model.dimensions["batch_dim"] + length_dim = model.dimensions["total_seq_dim"] + image_seq_dim = model.dimensions['image_sequence_dim'] + + image_inputs = inputs['image_inputs'] + text_inputs = inputs['text_inputs'] + + # Gets position (in image inputs) where zero padding starts + initial_position = mtf.zeros(text_inputs.mesh, mtf.Shape((batch_dims,)), dtype=tf.int32) + + # initial_position += model.dimensions['text_seq_dim'].size + + length_range = mtf.range(image_inputs.mesh, image_seq_dim, tf.int32) + + # one step of sampling fn + + def sample_step(logits, ids, position, incremental): + nonlocal sampling_keep_top_k + # By default, do top_k sampling of 0.9 + if sampling_keep_top_k == -2: + sampling_keep_top_k = int(logits.shape[-1].size * 0.1) + + if sampling_keep_top_k != -1: + if sampling_keep_top_k <= 0: + raise ValueError("sampling_keep_top_k must either be -1 or positive.") + k_largest = mtf.nth_largest_element( + logits, n=sampling_keep_top_k, + reduced_dim=model.dimensions['image_vocab_dim']) + logits = mtf.where(mtf.less_equal(logits, k_largest), + mtf.ones_like(logits) * -1e6, logits) + + # temperature sampling + ids_this_step = mtf.sample_with_temperature( + logits, model.dimensions['image_vocab_dim'], temperature) + + # reshape & assign results + if incremental: + ids_this_step = mtf.reshape(ids_this_step, ([batch_dims])) + else: + ids_this_step = mtf.shift(ids_this_step, offset=1, dim=image_seq_dim, wrap=False) + + one_hot = mtf.one_hot(position, image_seq_dim, dtype=tf.int32) + one_new_id = ids_this_step * one_hot + new_ids = (1 - one_hot) * ids + one_new_id + new_position = position + 1 + return [new_position, new_ids] + + # Builds context to pass around internally + # The 'first part' context records initial states of k / v / x + if cached: + context_first_part = mtf_transformer.transformer.Context( + model=None, + mesh=image_inputs.mesh, + batch_dims=batch_dims, + length_dim=image_seq_dim, + variable_dtype=variable_dtype, + mode="first_part", + position=length_range, + position_is_default=True, + new_states=[], + initial_position=initial_position, + sequence_id=None, + constant_states=[], + inputs=inputs) + model.context = context_first_part + + with tf.variable_scope('dall-e'): + logits = model.forward(inputs, return_loss=False, return_logits=True) + + initial_states = context_first_part.new_states + + # sample one step to get first image token and then delete logits + + initial_position, image_inputs = sample_step(logits, image_inputs, initial_position, incremental = False) + + del logits + else: + initial_states = [] + + def cond_fn(position, ids, *unused_states): + """Should we run another loop iteration?""" + past_end = mtf.greater_equal(position, image_seq_dim.size) + if max_steps: + past_end = mtf.logical_or( + past_end, mtf.greater_equal(position - initial_position, max_steps)) + + is_done = past_end + all_done = mtf.reduce_all(is_done) + return mtf.logical_not(all_done) + + def body_fn(position, ids, *states): + """One step in the decode loop.""" + + context = mtf_transformer.transformer.Context( + model=None, + mesh=image_inputs.mesh, + batch_dims=batch_dims, + length_dim=image_seq_dim, + variable_dtype=variable_dtype, + mode="incremental", + position=position, + position_is_default=True, + states=states, + new_states=[], + initial_position=position, + sequence_id=None, + inputs=ids) if cached else None + + model.is_incremental_inference = True if cached 
else False + model.context = context + with tf.variable_scope("dall-e", reuse=tf.AUTO_REUSE): + logits = model.forward({'image_inputs': image_inputs, 'text_inputs': (text_inputs if not cached else None)}, return_loss=False, return_logits=True) + + ret = sample_step(logits, ids, position, cached) + + if cached: + ret += context.new_states + return ret + + while_loop_inputs = [initial_position, image_inputs] + initial_states + final_position, outputs = mtf.while_loop( + cond_fn, body_fn, while_loop_inputs)[:2] + del final_position + return outputs diff --git a/src/input_fns.py b/src/input_fns.py index ee35bfc..b1c07a8 100644 --- a/src/input_fns.py +++ b/src/input_fns.py @@ -1,5 +1,7 @@ +import imageio +import numpy as np import tensorflow.compat.v1 as tf - +import os def crop_center_and_resize(img, size): s = tf.shape(img) @@ -38,6 +40,32 @@ def truncate_or_pad_label(label, params): return label +def pred_input(params, tokenizer, prompt='a cat in a hat'): + tokens = tokenizer.encode(prompt) + if len(tokens) > params["text_seq_len"]: + tf.logging.info("The length of your input prompt is longer than the model's text context length - truncating " + "input.") + tokens = tokens[len(tokens) - params["text_seq_len"]:] # TODO: left or right truncate here? + if len(tokens) < params["text_seq_len"]: + tokens = tf.pad(tokens, [[0, params["text_seq_len"] - len(tokens)]], constant_values=params["padding_id"]) + t = tf.broadcast_to(tokens, [params["batch_size"], params["text_seq_len"]]) + dataset = tf.data.Dataset.from_tensors(t) + + def _dummy_labels(x): + return x, x + + dataset = dataset.map(_dummy_labels) + return dataset + + +def pred_output(predictions, out_name='test', output_dir='outputs'): + if not os.path.isdir(output_dir): + os.makedirs(output_dir) + for i, p in enumerate(predictions): + denormalize = lambda x: (((x + 1) / 2) * 255.0).astype(np.uint8) + imageio.imwrite(f"outputs/{out_name}_{i}.jpeg", denormalize(p["predictions_decoded"])) + + def read_labeled_tfrecord(params): def read_fn(example): features = { @@ -103,6 +131,7 @@ def _process_path(file_path): dataset = configure_for_performance(dataset, params, eval) return dataset.repeat() + def dalle_input_fn(params, eval=False): path = params["dataset"]["train_path"] if not eval else params["dataset"]["eval_path"] files = tf.io.gfile.glob(path) @@ -113,7 +142,8 @@ def dalle_input_fn(params, eval=False): if not eval: dataset = dataset.shuffle(file_count, reshuffle_each_iteration=False) - dataset = dataset.apply(tf.data.experimental.parallel_interleave(tf.data.TFRecordDataset, cycle_length=4, sloppy=False)) + dataset = dataset.apply( + tf.data.experimental.parallel_interleave(tf.data.TFRecordDataset, cycle_length=4, sloppy=False)) parse_fn = read_labeled_tfrecord(params) dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) dataset = configure_for_performance(dataset, params, eval) diff --git a/src/model_fns.py b/src/model_fns.py index ba0a5d9..a2b2841 100644 --- a/src/model_fns.py +++ b/src/model_fns.py @@ -1,12 +1,15 @@ import mesh_tensorflow as mtf import tensorflow.compat.v1 as tf + from tensorflow.python.tpu import tpu_estimator import mesh_tensorflow.transformer as mtf_transformer from .optimizers import get_optimizer from .utils import mode_to_str, get_graph_info, create_host_call, simd_mesh_setup, scalar_summary -from .dalle_mtf import DALLE +from .dalle_mtf import DALLE, sample_autoregressive from .vae_tf import DiscreteVAE - +from .dalle_mtf.ops import mask_to_bias +from tensorflow.python.ops import 
resources +import numpy as np def initialize_vae_weights(checkpoint_path, scope="vae"): """ @@ -16,7 +19,7 @@ def initialize_vae_weights(checkpoint_path, scope="vae"): vars_to_restore = tf.get_collection( tf.GraphKeys.GLOBAL_VARIABLES, scope=scope) ckpt_vars = [ - name for name, _ in tf.train.list_variables(checkpoint_path)] + name for name, _ in tf.train.list_variables(checkpoint_path)] tf.logging.info(f"RESTORING {len(vars_to_restore)} VAE VARS FROM CHECKPOINT: ") tf.logging.info(f"CHECKPOINT PATH: {checkpoint_path}") tf.logging.info(f"CHECKPOINT VARS:") @@ -62,19 +65,21 @@ def dalle_model_fn(features, labels, mode, params): # load vae in tensorflow graph before mtf vae, vae_checkpoint_path = load_vae_model(params, mode_str) - initialize_vae_weights(vae_checkpoint_path) - H = W = params["dataset"]["image_size"] - image_seq_len = (vae.H // (2 ** len(vae.convblocks))) ** 2 // (vae.stack_factor ** 2) # TODO: check this is correct batch_size = params[f"{mode_str}_batch_size"] n_channels = params.get("input_channels", 3) + if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]: - with tf.variable_scope("vae"): - vae_logits = vae.forward(features, return_logits=True) + with tf.variable_scope("vae"): + vae_logits = vae.forward(features, return_logits=True) - # TODO: using argmax sampling for now, but is that optimal? - tokens = tf.math.argmax(vae_logits, -1) - img_tokens_reshaped = tf.cast(tf.reshape(tokens, (batch_size, image_seq_len)), tf.int32) + # TODO: using argmax sampling for now, but is that optimal? + tokens = tf.math.argmax(vae_logits, -1) + img_tokens_reshaped = tf.cast(tf.reshape(tokens, (batch_size, params['image_seq_len'])), tf.int32) + + # TODO: get rid of this ugly hack, its just to pull the decoder parameters in during training + with tf.variable_scope('vae'): + vae.decoder(tf.zeros_like(vae_logits)) # Construct mtf graph + mesh from params graph = mtf.Graph() @@ -94,11 +99,12 @@ def dalle_model_fn(features, labels, mode, params): mesh = mtf.Mesh(graph, "my_mesh", var_placer) model = DALLE( + mesh=mesh, n_embd=params["n_embd"], text_vocab_size=params["text_vocab_size"], image_vocab_size=params["image_vocab_size"], text_seq_len=params["text_seq_len"], - image_seq_len=image_seq_len, + image_seq_len=params['image_seq_len'], n_layers=params["n_layers"], n_heads=params["n_heads"], batch_size=batch_size, @@ -107,33 +113,85 @@ def dalle_model_fn(features, labels, mode, params): params=params, ) - # Build mtf_features & seq length dict for getting number of microbatches - # We need to pack inputs into a dict to pass into serialize_training_step - features_dict = {"image_inputs": features, - "text_inputs": labels} - mtf_features = {} - for key, x in features_dict.items(): - if x is not None: - if key == "text_inputs": - text_tokens = tf.reshape(x, [batch_size, params["text_seq_len"]]) - x = tf.concat((text_tokens, img_tokens_reshaped + model.text_vocab_size), axis=1) - mtf_shape = mtf.Shape([model.dimensions["batch_dim"], model.dimensions["total_seq_dim"]]) - - mtf_features["tokens"] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) - - if key == "image_inputs": - mtf_shape = mtf.Shape([ - model.dimensions["batch_dim"], - mtf.Dimension("img_height_dim", vae.H), - mtf.Dimension("img_width_dim", vae.W), - mtf.Dimension("img_channel_dim", vae.num_ch), - ]) - x = tf.reshape(x, [batch_size, H, W, n_channels]) # NHWC - mtf_features["image_inputs"] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) - - scalar_summary("input_image", mtf_features["image_inputs"]) 
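Aside: before the Mesh TensorFlow graph is built, `dalle_model_fn` turns each image into a flat sequence of discrete codebook indices by taking the argmax over the frozen VAE's logits and reshaping to `[batch, image_seq_len]`. A small self-contained sketch of just that tokenisation step — the random tensor stands in for `vae.forward(features, return_logits=True)`, and the grid/vocab sizes are illustrative:

```python
# Sketch of the image-tokenisation step in dalle_model_fn: argmax over the VAE
# codebook logits, then flatten to [batch, image_seq_len].
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

batch_size, grid, num_tokens = 2, 32, 2048            # 32x32 latent grid -> image_seq_len = 1024
vae_logits = tf.random.normal([batch_size, grid, grid, num_tokens])  # stand-in for vae.forward(...)

tokens = tf.math.argmax(vae_logits, -1)               # hard assignment to the most likely codebook entry
img_tokens = tf.cast(tf.reshape(tokens, (batch_size, grid * grid)), tf.int32)

with tf.Session() as sess:
    print(sess.run(tf.shape(img_tokens)))             # [2, 1024]
```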
- if mode == tf.estimator.ModeKeys.PREDICT: - raise NotImplementedError + if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]: + # Build mtf_features & seq length dict for getting number of microbatches + # We need to pack inputs into a dict to pass into serialize_training_step + features_dict = {"image_inputs": img_tokens_reshaped, + "text_inputs": labels} + mtf_features = {} + for key, x in features_dict.items(): + if x is not None: + if key == "text_inputs": + x = tf.reshape(x, [batch_size, params["text_seq_len"]]) + mtf_shape = mtf.Shape([model.dimensions["batch_dim"], model.dimensions["text_sequence_dim"]]) + mtf_features["tokens"] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) + if key == "image_inputs": + mtf_shape = mtf.Shape([ + model.dimensions["batch_dim"], + model.dimensions["image_sequence_dim"], + ]) + mtf_features[key] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) + else: + # Build mtf_features & seq length dict for getting number of microbatches + # We need to pack inputs into a dict to pass into serialize_training_step + features_dict = {"text_inputs": labels, 'image_inputs': 'None'} + mtf_features = {} + for key, x in features_dict.items(): + if x is not None: + if key == "text_inputs": + x = tf.reshape(x, [batch_size, params["text_seq_len"]]) + mtf_shape = mtf.Shape([model.dimensions["batch_dim"], model.dimensions["text_sequence_dim"]]) + mtf_features["tokens"] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) + mtf_features[key] = mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) + if key == "image_inputs": + mtf_shape = mtf.Shape([ + model.dimensions["batch_dim"], + model.dimensions["image_sequence_dim"], + ]) + mtf_features[key] = mtf.zeros(mesh, mtf_shape, tf.int32) + params['padding_id'] + + # Set up the model for prediction + mtf_samples = sample_autoregressive(mtf_features, + model, + max_steps=model.image_seq_len, + temperature=0.9, + variable_dtype=model.variable_dtype, + sampling_keep_top_k=-2, + ) + + mtf_samples = mtf.anonymize(mtf_samples) + lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=params.get('autostack', True)) + + outputs = lowering.export_to_tf_tensor(mtf_samples) + + initialize_vae_weights(vae_checkpoint_path) + + outputs -= model.text_vocab_size + with tf.variable_scope('vae'): + predictions_decoded = vae.decode(outputs) + + predictions = { + "outputs": outputs, + "predictions_decoded": predictions_decoded + } + denormalize = lambda x: (((x + 1) / 2) * 255.0) + def scaffold_fn(): + return tf.train.Scaffold( + local_init_op=tf.group( + tf.train.Scaffold.default_local_init_op(), + lowering.copy_masters_to_slices(), + name="mtf_local_init_op"), + ready_op=tf.concat( + [tf.report_uninitialized_variables(), + resources.report_uninitialized_resources()], + axis=0, + name="mtf_ready_op")) + + return tpu_estimator.TPUEstimatorSpec( + mode=tf.estimator.ModeKeys.PREDICT, + predictions=predictions, + scaffold_fn=scaffold_fn, + prediction_hooks=[mtf.MtfRestoreHook(lowering)]) # We're not predicting, so we better be training or evaluating assert (mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL) @@ -142,7 +200,7 @@ def dalle_model_fn(features, labels, mode, params): # Gets number of microbatches per batch for serialized training # if param tokens_per_mb_per_replica = None, this defaults to 1 and no microbatching is performed num_microbatches = int(mtf_transformer.utils.serialize_num_microbatches(batch_dim=model.dimensions["batch_dim"], - 
sequence_length=model.total_seq_dim, + sequence_length=model.total_seq_len, mesh_shape=mesh_shape, layout_rules=layout_rules, tokens_per_microbatch_per_replica= @@ -156,8 +214,9 @@ def dalle_model_fn(features, labels, mode, params): if num_microbatches > 1: # For serialize_training_step we need to modify the model to output results in a dict def serialized_fn(mtf_features): - loss, loss_batch = model.forward(mtf_features, return_loss=True) - return {"loss": loss, "loss_batch": loss_batch} + with tf.variable_scope('dall-e'): + loss, loss_batch = model.forward(mtf_features, return_loss=True) + return {"loss": loss, "loss_batch": loss_batch} # Serialize the training step - Gradients are accumulated locally and reduced once. var_grads, output_dict = mtf.serialize_training_step(mtf_features, serialized_fn, model.dimensions["batch_dim"], @@ -165,7 +224,8 @@ def serialized_fn(mtf_features): loss = output_dict["loss"] loss_batch = output_dict["loss_batch"] else: - loss, loss_batch = model.forward(mtf_features, return_loss=True) + with tf.variable_scope('dall-e'): + loss, loss_batch = model.forward(mtf_features, return_loss=True) del loss_batch # TODO: may need this for some metrics - otherwise, remove from output @@ -186,11 +246,12 @@ def serialized_fn(mtf_features): get_graph_info(graph) # 'lowers' mtf tensors into a tf graph - this enables us to export results as tf tensors - lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=False) + lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=params.get('autostack', True)) tf_loss = lowering.export_to_tf_tensor(loss) tf_loss = tf.cast(tf_loss, tf.float32) + if mode == tf.estimator.ModeKeys.TRAIN: # Use our patched version until mtf updates theirs host_call = create_host_call(params['model_path']) @@ -200,8 +261,12 @@ def serialized_fn(mtf_features): tf_update_ops = [lowering.lowered_operation(op) for op in update_ops] tf_update_ops.append(tf.assign_add(global_step, 1)) # Need to manually increment global_step train_op = tf.group(tf_update_ops) + with mtf.utils.outside_all_rewrites(): + # only *now* can we initialize vae weights (stupid tensorflow) + initialize_vae_weights(vae_checkpoint_path) + # Copy master variables to slices. Must be called first. 
restore_hook = mtf.MtfRestoreHook(lowering) if mode == tf.estimator.ModeKeys.TRAIN: diff --git a/src/model_fns_tf.py b/src/model_fns_tf.py index 9c16d55..b84608f 100644 --- a/src/model_fns_tf.py +++ b/src/model_fns_tf.py @@ -1,9 +1,8 @@ import tensorflow.compat.v1 as tf import tensorflow.compat.v2 as tf2 from tensorflow.python.tpu import tpu_estimator -from .optimizers import get_optimizer from .vae_tf import DiscreteVAE -from .utils import scalar_summary, mode_to_str, create_host_call +from .utils import mode_to_str def vae_model_fn(features, labels, mode, params): diff --git a/src/optimizers.py b/src/optimizers.py index 7f77c04..42a627c 100644 --- a/src/optimizers.py +++ b/src/optimizers.py @@ -79,7 +79,7 @@ def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None): scalar_summary("lr", learning_rate) if optimizer_name.lower() == "adam": - optimizer = mtf.optimize.AdamWeightDecayOptimizer( + optimizer = AdamWeightDecayOptimizer( learning_rate=learning_rate, weight_decay_rate=params.get("weight_decay", 0.0), beta_1=params.get("beta_1", 0.9), @@ -104,85 +104,85 @@ def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None): return learning_rate, update_ops, var_grads_fp -# class AdamWeightDecayOptimizer(mtf.optimize.Optimizer): -# """A basic Adam optimizer that includes "correct" L2 weight decay.""" - -# def __init__(self, -# learning_rate, -# weight_decay_rate=0.0, -# beta_1=0.9, -# beta_2=0.999, -# epsilon=1e-6, -# exclude_from_weight_decay=None, -# variable_dtype=None): -# """Constructs a AdamWeightDecayOptimizer.""" - -# self.learning_rate = learning_rate -# self.weight_decay_rate = weight_decay_rate -# self.beta_1 = beta_1 -# self.beta_2 = beta_2 -# self.epsilon = epsilon -# self.exclude_from_weight_decay = exclude_from_weight_decay -# self.variable_dtype = variable_dtype - -# def apply_grad(self, grad, var): -# """See base class.""" -# if grad is None: -# tf.logging.warning("Gradient is None for variable %s" % var.name) -# return [] - -# grad = mtf.to_float(grad) - -# assignments = [] - -# m = mtf.get_variable( -# var.mesh, var.name + "/adam_m", var.shape, -# initializer=tf.zeros_initializer(), -# # master_dtype=self.variable_dtype.master_dtype, -# # slice_dtype=self.variable_dtype.slice_dtype, -# # activation_dtype=self.variable_dtype.activation_dtype, -# trainable=False) - -# v = mtf.get_variable( -# var.mesh, var.name + "/adam_v", var.shape, -# initializer=tf.zeros_initializer(), -# # master_dtype=self.variable_dtype.master_dtype, -# # slice_dtype=self.variable_dtype.slice_dtype, -# # activation_dtype=self.variable_dtype.activation_dtype, -# trainable=False) - -# # Standard Adam update. -# next_m = self.beta_1 * m + (1.0 - self.beta_1) * grad -# next_v = self.beta_2 * v + (1.0 - self.beta_2) * mtf.square(grad) - -# update = next_m / (mtf.sqrt(next_v) + self.epsilon) - -# # Just adding the square of the weights to the loss function is *not* -# # the correct way of using L2 regularization/weight decay with Adam, -# # since that will interact with the m and v parameters in strange ways. -# # -# # Instead we want to decay the weights in a manner that doesn't interact -# # with the m/v parameters. This is equivalent to adding the square -# # of the weights to the loss with plain (non-momentum) SGD. 
-# if self._do_use_weight_decay(var.name): -# update += mtf.to_float(var.value) * self.weight_decay_rate - -# update_with_lr = self.learning_rate * update - -# var_update = mtf.assign_sub(var, update_with_lr) - -# assignments.extend( -# [var_update, -# mtf.assign(m, next_m), -# mtf.assign(v, next_v)]) -# return assignments - -# def _do_use_weight_decay(self, param_name): -# """Whether to use L2 weight decay for `param_name`.""" -# if not self.weight_decay_rate: -# return False -# if self.exclude_from_weight_decay: -# for r in self.exclude_from_weight_decay: -# if re.search(r, param_name) is not None: -# return False -# return True +class AdamWeightDecayOptimizer(mtf.optimize.Optimizer): + """A basic Adam optimizer that includes "correct" L2 weight decay.""" + + def __init__(self, + learning_rate, + weight_decay_rate=0.0, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-6, + exclude_from_weight_decay=None, + variable_dtype=None): + """Constructs a AdamWeightDecayOptimizer.""" + + self.learning_rate = learning_rate + self.weight_decay_rate = weight_decay_rate + self.beta_1 = beta_1 + self.beta_2 = beta_2 + self.epsilon = epsilon + self.exclude_from_weight_decay = exclude_from_weight_decay + self.variable_dtype = variable_dtype + + def apply_grad(self, grad, var): + """See base class.""" + if grad is None: + tf.logging.warning("Gradient is None for variable %s" % var.name) + return [] + + grad = mtf.to_float(grad) + + assignments = [] + + m = mtf.get_variable( + var.mesh, var.name + "/adam_m", var.shape, + initializer=tf.zeros_initializer(), + # master_dtype=self.variable_dtype.master_dtype, + # slice_dtype=self.variable_dtype.slice_dtype, + # activation_dtype=self.variable_dtype.activation_dtype, + trainable=False) + + v = mtf.get_variable( + var.mesh, var.name + "/adam_v", var.shape, + initializer=tf.zeros_initializer(), + # master_dtype=self.variable_dtype.master_dtype, + # slice_dtype=self.variable_dtype.slice_dtype, + # activation_dtype=self.variable_dtype.activation_dtype, + trainable=False) + + # Standard Adam update. + next_m = self.beta_1 * m + (1.0 - self.beta_1) * grad + next_v = self.beta_2 * v + (1.0 - self.beta_2) * mtf.square(grad) + + update = next_m / (mtf.sqrt(next_v) + self.epsilon) + + # Just adding the square of the weights to the loss function is *not* + # the correct way of using L2 regularization/weight decay with Adam, + # since that will interact with the m and v parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't interact + # with the m/v parameters. This is equivalent to adding the square + # of the weights to the loss with plain (non-momentum) SGD. 
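Aside: the comment above is the crux of this optimizer — weight decay is applied to the update after the Adam moments are computed, rather than being folded into the gradient. A minimal NumPy sketch of one `apply_grad` step under that scheme (no bias correction, matching the class; names and values are illustrative):

```python
# One step of the decoupled-weight-decay Adam update used in apply_grad above
# (NumPy sketch; no bias correction, matching the class).
import numpy as np

def adamw_step(w, grad, m, v, lr=1e-4, beta_1=0.9, beta_2=0.999,
               epsilon=1e-6, weight_decay_rate=0.0):
    m = beta_1 * m + (1.0 - beta_1) * grad           # first moment
    v = beta_2 * v + (1.0 - beta_2) * grad ** 2      # second moment
    update = m / (np.sqrt(v) + epsilon)
    update += weight_decay_rate * w                  # decay added to the update, not the gradient
    w = w - lr * update
    return w, m, v

w, m, v = np.ones(3), np.zeros(3), np.zeros(3)
w, m, v = adamw_step(w, np.array([0.1, -0.2, 0.3]), m, v, weight_decay_rate=0.01)
print(w)
```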
+ if self._do_use_weight_decay(var.name): + update += mtf.to_float(var.value) * self.weight_decay_rate + + update_with_lr = self.learning_rate * update + + var_update = mtf.assign_sub(var, update_with_lr) + + assignments.extend( + [var_update, + mtf.assign(m, next_m), + mtf.assign(v, next_v)]) + return assignments + + def _do_use_weight_decay(self, param_name): + """Whether to use L2 weight decay for `param_name`.""" + if not self.weight_decay_rate: + return False + if self.exclude_from_weight_decay: + for r in self.exclude_from_weight_decay: + if re.search(r, param_name) is not None: + return False + return True diff --git a/src/utils/utils.py b/src/utils/utils.py index 45178d6..70869d5 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -9,6 +9,8 @@ import logging import sys from mesh_tensorflow.ops import Operation, Tensor +import re + def fetch_model_params(model): model_path = model if model.endswith(".json") else f"./configs/{model}.json" @@ -66,7 +68,7 @@ def get_n_trainable_vars(graph): for dim in shape: variable_parameters *= dim.size total_parameters += variable_parameters - print(f"\n\nN PARAMS:\n{total_parameters:,}\n\n") + tf.logging.info(f"\n\nN PARAMS:\n{total_parameters:,}\n\n") def print_dim_names(graph): @@ -83,10 +85,10 @@ def print_dim_names(graph): # Print all dim names in graph & write to file all_dim_names = [item for sublist in all_dim_names for item in sublist] # Flatten all dims unique_dims = list(set(all_dim_names)) - print("ALL DIM NAMES:") + tf.logging.info("ALL DIM NAMES:") for dim_name in unique_dims: - print(dim_name) - print('\n') + tf.logging.info(dim_name) + tf.logging.info('\n') def get_graph_info(graph): @@ -224,4 +226,37 @@ def scalar_summary(name, x): Returns: a Tensor which is identical in value to x """ - return ScalarSummaryOperation(name, x) \ No newline at end of file + return ScalarSummaryOperation(name, x) + + +def get_image_seq_len(dalle_params): + return (dalle_params["vae_params"]['dataset']['image_size'] // (2 ** len(dalle_params["vae_params"]['convblocks']))) ** 2 // ( + dalle_params.get("vae_params").get("stack_factor", 1) ** 2) + +def save_config(params_dict, logdir): + tf.logging.info(f"Saving config to {logdir}") + text = "{\n\n" + total_params = len(params_dict) + for count, key in enumerate(params_dict): + config_value = str(params_dict[key]) + if re.search('[a-zA-Z]', config_value): + if config_value.lower() != 'true': + if config_value.lower() != 'false': + if config_value[0] != '[': + # TODO: Making a manual exception for parsing epsilon right now since it's the only number in + # scientific notation. Should fix this. 
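Aside: `get_image_seq_len` above derives the number of image tokens from the VAE geometry — each conv block halves the spatial resolution, the remaining grid side is squared, and an optional stack factor divides the result. A worked example with illustrative values (image size 256, three conv blocks, stack factor 1; not read from any config):

```python
# Worked example of the computation in get_image_seq_len above.
def image_seq_len(image_size, n_convblocks, stack_factor=1):
    side = image_size // (2 ** n_convblocks)   # each conv block halves H and W
    return side ** 2 // (stack_factor ** 2)

assert image_seq_len(256, 3) == 32 * 32 == 1024
print(image_seq_len(256, 3))  # 1024
```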
+ if key != "epsilon": + config_value = f'"{config_value}"' + if count == total_params - 1: + text += f'"{str(key)}"' + ' : ' + config_value + '\n\n' + else: + text += f'"{str(key)}"' + ' : ' + config_value + ',\n\n' + text += '\n\n}' + sess = tf.InteractiveSession() + summary_op = tf.summary.text("run_config", tf.convert_to_tensor(text)) + summary_writer = tf.summary.FileWriter(f"{logdir}/config", sess.graph) + text = sess.run(summary_op) + summary_writer.add_summary(text, 0) + summary_writer.flush() + summary_writer.close() + tf.reset_default_graph() \ No newline at end of file diff --git a/src/vae_tf/models.py b/src/vae_tf/models.py index d7dd073..6c8e8f6 100644 --- a/src/vae_tf/models.py +++ b/src/vae_tf/models.py @@ -75,6 +75,8 @@ def __init__(self, self.recompute_grad = recompute_grad self.bf16 = use_bf16 + self.n_hid = convblocks[-1][1] + assert math.log2(stack_factor).is_integer() # maybe you don't actually need this? self.stack_factor = stack_factor @@ -109,7 +111,6 @@ def encoder_block(x, channels=channels): x = x + res_out with tf.variable_scope(f"codebook"): - self.n_hid = x.shape[-1] embedding = tf.get_variable("codebook", shape=[self.n_hid, self.num_tokens], dtype=tf.float32) if self.bf16: @@ -119,9 +120,8 @@ def encoder_block(x, channels=channels): return output - def decoder(self, x): - with tf.variable_scope(f"codebook", reuse=True): + with tf.variable_scope(f"codebook", reuse=tf.AUTO_REUSE): embedding = tf.get_variable("codebook", shape=[self.n_hid, self.num_tokens], dtype=tf.float32) x = tf.matmul(x, embedding, transpose_b=True) @@ -162,6 +162,22 @@ def decoder_block(x, channels=channels): return x + def decode(self, input_indices): + batch, seqlen = input_indices.shape + + print(f"seqlen {seqlen}") + print(f"side expected {self.W // (2 ** len(self.convblocks))}") + + assert seqlen == (self.W // (2 ** len(self.convblocks))) * (self.H // (2 ** len(self.convblocks))) + + input_onehot = tf.one_hot(input_indices, self.num_tokens) + input_reshaped = tf.reshape(input_onehot, [batch, + self.H // (2 ** len(self.convblocks)), + self.W // (2 ** len(self.convblocks)), + self.num_tokens]) # NHWC + + return self.decoder(input_reshaped) + def forward(self, features, return_recon_loss=False, return_logits=False, hard_gumbel=True, temperature=1.): if isinstance(features, dict): img = features["inputs"] diff --git a/test.py b/test.py new file mode 100644 index 0000000..90efc0e --- /dev/null +++ b/test.py @@ -0,0 +1,106 @@ +import pytest +import traceback +import logging +from collections import defaultdict +from contextlib import contextmanager + +import tensorflow as tf +tf.compat.v1.enable_eager_execution() +import mesh_tensorflow as mtf +from mesh_tensorflow import placement_mesh_impl + +from src.dalle_mtf.models import DALLE +from src.dalle_mtf.sample import sample_autoregressive + +# helper functions + +@contextmanager +def not_raises(exception): + try: + yield + except exception: + logging.error(traceback.format_exc()) + raise pytest.fail("DID RAISE {0}".format(exception)) + +# tests + +def test_model(): + graph = mtf.Graph() + mesh = mtf.Mesh(graph, "my_mesh") + + model = DALLE( + mesh = mesh, + batch_size = 1, + n_embd = 16, + n_heads = 2, + bf_16 = False + ) + + batch_dim = model.dimensions["batch_dim"] + sequence_dim = model.dimensions["total_seq_dim"] + + text_inputs = mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32) + image_inputs = mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32) + + features = { + 'text_inputs': mtf.slice(text_inputs, 0, 
model.text_seq_len, sequence_dim.name), + 'image_inputs': mtf.slice(image_inputs, 0, model.image_seq_len, sequence_dim.name) + } + + with not_raises(Exception): + loss, loss_batch, logits = model.forward(features, return_loss = True, return_logits = True) + + mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""]) + lowering = mtf.Lowering(graph, {mesh: mesh_impl}) + logits = lowering.export_to_tf_tensor(logits) + +def test_sampling(): + graph = mtf.Graph() + mesh = mtf.Mesh(graph, "my_mesh") + + model = DALLE( + mesh = mesh, + batch_size = 1, + text_seq_len = 1, + image_seq_len = 4, + n_embd = 16, + n_heads = 2, + bf_16 = False, + unique_pad_tokens = True + ) + + batch_dim = model.dimensions["batch_dim"] + sequence_dim = model.dimensions["total_seq_dim"] + + text_inputs = mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32) + image_inputs = mtf.zeros(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32) + + inputs = { + 'text_inputs': mtf.slice(text_inputs, 0, model.text_seq_len, sequence_dim.name), + 'image_inputs': mtf.slice(image_inputs, 0, model.image_seq_len, sequence_dim.name) + } + + with not_raises(Exception): + cached_samples = sample_autoregressive( + inputs, + model, + variable_dtype=mtf.VariableDType(), + max_steps = sequence_dim.size, + cached = True + ) + + mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""]) + lowering = mtf.Lowering(graph, {mesh: mesh_impl}) + cached_samples = lowering.export_to_tf_tensor(cached_samples) + + noncached_samples = sample_autoregressive( + inputs, + model, + variable_dtype=mtf.VariableDType(), + max_steps = model.image_seq_len, + cached = False + ) + + mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""]) + lowering = mtf.Lowering(graph, {mesh: mesh_impl}) + noncached_samples = lowering.export_to_tf_tensor(noncached_samples) diff --git a/train_dalle.py b/train_dalle.py index e5c4428..1a3ae10 100644 --- a/train_dalle.py +++ b/train_dalle.py @@ -6,7 +6,7 @@ import argparse from src.utils import * from src.model_fns import dalle_model_fn -from src.input_fns import dalle_input_fn +from src.input_fns import dalle_input_fn, pred_input, pred_output from src.data import get_tokenizer def parse_args(): @@ -18,6 +18,8 @@ def parse_args(): parser.add_argument("--model", type=str, default=None, help="JSON file that contains model parameters.") parser.add_argument("--new", action="store_true", help="If set, deletes previous checkpoint, if it exists, and " "starts a new training run") + parser.add_argument('--predict', action='store_true', help='run model in predict mode') + parser.add_argument('--prompt', type=str, default='face') args = parser.parse_args() assert args.model is not None, "Model must be set" return args @@ -29,6 +31,7 @@ def main(): logging = setup_logging(args) params = fetch_model_params(args.model) params["vae_params"] = fetch_model_params(params["vae_model"]) + save_config(params, params['model_dir']) assert params["model_type"].lower() == "dalle", f'model_type {params["model_type"]} not recognized' # Confirm deletion of checkpoint files if --new flag is set @@ -46,6 +49,8 @@ def main(): params["gpu_ids"] = args.gpu_ids tokenizer = get_tokenizer(params["tokenizer"]) assert len(tokenizer) == params["text_vocab_size"], f"tokenizer vocab size {len(tokenizer)} must equal model vocab size {params['text_vocab_size']}" + params['image_seq_len'] = get_image_seq_len(params) + params['total_seq_len'] = params['image_seq_len'] + 
params['text_seq_len'] params["padding_id"] = tokenizer.encode(tokenizer.pad_token)[0] # Set up TPUs and Estimator if args.tpu == "colab": @@ -76,19 +81,27 @@ def main(): eval_batch_size=params["eval_batch_size"], predict_batch_size=params["predict_batch_size"], params=params) + if args.predict: + # Predict + pred_input_fn = partial(pred_input, tokenizer=tokenizer, prompt=args.prompt) + predictions = estimator.predict(input_fn=pred_input_fn) + logging.info("Predictions generated") + pred_output(predictions, 'test') + return has_predict_or_eval_steps = params["predict_steps"] > 0 or params["eval_steps"] > 0 if has_predict_or_eval_steps: # Eval and train - stop and predict and/or eval every checkpoint while current_step < params["train_steps"]: - next_checkpoint = min(current_step + args.steps_per_checkpoint, params["train_steps"]) + next_checkpoint = min(current_step + params["steps_per_checkpoint"], params["train_steps"]) estimator.train(input_fn=partial(dalle_input_fn, eval=False), max_steps=next_checkpoint) current_step = next_checkpoint if params["predict_steps"] > 0: raise NotImplementedError if params["eval_steps"] > 0: - raise NotImplementedError + estimator.evaluate(input_fn=partial(dalle_input_fn, eval=True), + steps=params["eval_steps"]) return else: # Else, just train @@ -98,6 +111,7 @@ def main(): max_steps=params["train_steps"]) + if __name__ == "__main__": tf.disable_v2_behavior() main() diff --git a/train_vae.py b/train_vae.py index 835cbbc..053cc79 100644 --- a/train_vae.py +++ b/train_vae.py @@ -28,6 +28,7 @@ def main(): args = parse_args() logging = setup_logging(args) params = fetch_model_params(args.model) + save_config(params, params['model_dir']) assert params["model_type"].lower() == "vae", f'model_type {params["model_type"]} not recognized' # get current step
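Aside: the new `--predict` path ties together `pred_input` and `pred_output` — the prompt is tokenised, left-truncated or right-padded to `text_seq_len`, broadcast across the batch, and the decoded samples come back in [-1, 1] and are rescaled to uint8 before being written to disk. A rough sketch of just the pre/post-processing, with NumPy stand-ins; the token ids, `padding_id`, and the decoded image are mocked, and only the padding and the `(x + 1) / 2 * 255` denormalisation mirror the real code:

```python
# Sketch of the predict-mode pre/post-processing (mocked tokenizer and decoder output).
import numpy as np

text_seq_len, padding_id, batch_size = 256, 50257, 16   # illustrative values

tokens = [318, 2415, 257]                                # pretend tokenizer.encode("a cat in a hat")
tokens = tokens[-text_seq_len:]                          # left-truncate if longer than the context
tokens = tokens + [padding_id] * (text_seq_len - len(tokens))  # right-pad with the pad id
text_inputs = np.broadcast_to(np.array(tokens, np.int32), (batch_size, text_seq_len))

decoded = np.random.uniform(-1.0, 1.0, size=(256, 256, 3))     # stand-in for vae.decode output
image = (((decoded + 1) / 2) * 255.0).astype(np.uint8)         # same denormalisation as pred_output
print(text_inputs.shape, image.shape, image.dtype)             # (16, 256) (256, 256, 3) uint8
```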