Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
99d11e1
add basic sampling code
Jan 13, 2021
8bb3322
Merge remote-tracking branch 'origin/main' into sample
Jan 13, 2021
b9fd039
add prediction input / output fns
Jan 13, 2021
cf76c6c
get sample_autoregressive working
ConnorJL Jan 14, 2021
c7ff6c4
truncate text tokens properly
ConnorJL Jan 14, 2021
4d51fd9
log model params to tensorboard
ConnorJL Jan 14, 2021
346871f
add vae decoding and write to jpeg
kingoflolz Jan 14, 2021
d13c330
unshift image outputs at decode time
kingoflolz Jan 14, 2021
f8a7449
dirty hack to use vae decoder params when training dalle
kingoflolz Jan 14, 2021
ff56d12
Move initialize_vae_weights to after lowering
leogao2 Jan 17, 2021
4c4e0e0
fix vae checkpoint load in training
ConnorJL Jan 17, 2021
2c14bde
fix parameter count logging
ConnorJL Jan 18, 2021
130c26e
fix image vocab size
ConnorJL Jan 18, 2021
4652ef2
revert to separate embeddings for image and text
ConnorJL Jan 19, 2021
67247cf
Fix masking
leogao2 Jan 19, 2021
a0a2828
Ignore text tokens in loss computation
leogao2 Jan 20, 2021
ca23b85
Fix slicing for mtf
leogao2 Jan 20, 2021
e126b79
Fix sampling
leogao2 Jan 20, 2021
ddeb74d
Implement incremental logits mask
leogao2 Jan 20, 2021
e7eb459
Fix typo
leogao2 Jan 20, 2021
29f3006
revert changes to sample.py
ConnorJL Jan 21, 2021
b69ff72
add mask to bias op
ConnorJL Jan 21, 2021
7e3b6ff
update mask. (still not working :( )
ConnorJL Jan 21, 2021
34d5326
mask changes
ConnorJL Jan 31, 2021
05dec26
fix label shifting
ConnorJL Feb 1, 2021
6c39bdf
add weight decay Adam
ConnorJL Feb 1, 2021
69d7ef9
add eval steps
ConnorJL Feb 1, 2021
2ec46be
add slow sampling
Apr 4, 2021
321b718
add slow sampling
Apr 4, 2021
d564853
add tests
lucidrains Apr 4, 2021
a0d01b9
switch to using <bos> instead of <eos>
lucidrains Apr 4, 2021
58c4d4e
separate logits for text and images
lucidrains Apr 4, 2021
9410ba9
remove offsetted image, since it is no longer needed
lucidrains Apr 4, 2021
f13a567
add unique pad tokens feature, hidden behind a feature flag
lucidrains Apr 4, 2021
dafcb99
Merge branch 'separate_embeddings' of
ConnorJL Apr 4, 2021
0f774e1
remove logits mask code
lucidrains Apr 4, 2021
2972665
fix syntax
lucidrains Apr 4, 2021
2c081f2
variable scope
lucidrains Apr 4, 2021
6a37d0f
remove stop_at_tokens, since it wont be used, to avoid confusion
lucidrains Apr 4, 2021
c60bf54
more sample cleanup
lucidrains Apr 4, 2021
02a51f9
make sure one can sample the non-cached way
lucidrains Apr 4, 2021
beb1c60
fix sampling for cached version (potentially)
lucidrains Apr 4, 2021
c085faf
change max steps to image seq len, since it is counting from the star…
lucidrains Apr 5, 2021
3ffc5d1
fix initial position at 0
lucidrains Apr 5, 2021
ef7864c
make sure axial positional embedding is shifted over by one due to <bos>
lucidrains Apr 6, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# This workflow installs Python dependencies and runs the test suite with pytest on a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Tests

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

jobs:
build:

runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8]

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest
run: |
pytest -s test.py
10 changes: 5 additions & 5 deletions configs/dalle_coco.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,22 @@
},
"train_batch_size": 128,
"eval_batch_size": 128,
"predict_batch_size": 128,
"predict_batch_size": 16,
"steps_per_checkpoint": 5000,
"iterations": 1000,
"train_steps": 100000,
"predict_steps": 0,
"eval_steps": 0,
"n_channels": 3,
"bf_16": false,
"bf_16": true,
"recompute_grad": true,
"lr": 0.0001,
"model_path": "gs://neo-models/dalle_coco/",
"model_path": "gs://neo-models/dalle_coco_sample/",
"mesh_shape": "data:16,model:2",
"layout": "batch_dim:data",
"layout": "batch_dim:data,embed_dim:model",
"n_embd": 1024,
"text_vocab_size": 50258,
"image_vocab_size": 512,
"image_vocab_size": 2048,
"text_seq_len": 256,
"n_layers": 12,
"n_heads": 8,
Expand Down
3 changes: 2 additions & 1 deletion src/dalle_mtf/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .models import DALLE, DiscreteVAE
from .models import DALLE, DiscreteVAE
from .sample import sample_autoregressive
230 changes: 184 additions & 46 deletions src/dalle_mtf/models.py

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions src/dalle_mtf/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,21 @@ def get_variable_dtype(bf_16=True):
return mtf.VariableDType(master_dtype=tf.bfloat16, slice_dtype=tf.float32, activation_dtype=tf.bfloat16)
else:
return mtf.VariableDType(master_dtype=tf.float32, slice_dtype=tf.float32, activation_dtype=tf.float32)

def expand_tile(value, newdim, axis=0):
    """Broadcast `value` to its own shape with `newdim` inserted at `axis`.

    Args:
        value: tensor whose `shape.dims` list supplies the existing dimensions.
        newdim: the dimension to add (presumably an mtf.Dimension - confirm
            against callers).
        axis: list index at which `newdim` is placed; 0 puts it in front.

    Returns:
        The result of `mtf.broadcast` of `value` to the expanded dimension list.
    """
    existing_dims = value.shape.dims
    # Build the target dims list by splicing the new dimension in at `axis`.
    target_dims = existing_dims[:axis] + [newdim] + existing_dims[axis:]
    return mtf.broadcast(value, target_dims)

def mask_to_bias(visible, dtype):
    """Turn a boolean visibility mask into an additive attention bias.

    Positions where `visible` is False receive -1e9; positions where it is
    True receive 0.

    Args:
        visible: a boolean Tensor
        dtype: the dtype the result is cast to

    Returns:
        a Tensor with the given dtype and the same shape as `visible`
    """
    hidden = mtf.logical_not(visible)
    # hidden -> 1, visible -> 0; scaling by -1e9 pushes hidden positions
    # to a large negative value.
    return mtf.cast(hidden, dtype) * -1e9
167 changes: 167 additions & 0 deletions src/dalle_mtf/sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf
import mesh_tensorflow.transformer as mtf_transformer


def sample_autoregressive(inputs,
                          model,
                          max_steps=None,
                          temperature=0.9,
                          padding_id=0,
                          variable_dtype=mtf.VariableDType(tf.float32),
                          sampling_keep_top_k=-1,
                          cached=True
                          ):
    """Sample image tokens autoregressively, one token per decode step.

    Decoding starts at image position 0 and writes each sampled token into
    the image id tensor at the current position, until the image sequence is
    full (or `max_steps` positions have been decoded).

    Args:
        inputs: an input dictionary containing 'text_inputs' and 'image_inputs'.
        model: DALL-E model. This function reads `model.dimensions`, calls
            `model.forward`, and sets `model.context` and
            `model.is_incremental_inference`.
        max_steps: an optional integer, the max number of steps to decode.
            A falsy value means the only limit is the image sequence length.
        temperature: sampling temperature handed to
            `mtf.sample_with_temperature` (0.0 means argmax, 1.0 means sample
            according to the predicted distribution).
        padding_id: currently unused; kept so existing callers keep working.
        variable_dtype: a mtf.VariableDType used for the decoding contexts.
        sampling_keep_top_k: an integer - if not -1, only sample from the top k
            logits. The sentinel -2 selects k = 10% of the logits' last
            dimension size.
        cached: if True, run one "first_part" pass to collect the initial
            states and then decode incrementally with those cached states;
            if False, run a full forward pass on every step.

    Returns:
        The image id Tensor after decoding (same shape as
        inputs['image_inputs'], presumably [batch_dim, image_sequence_dim]).

    Raises:
        ValueError: if `sampling_keep_top_k` resolves to a value that is
            neither -1 nor positive.
    """

    # with dalle, the text sequence is the conditioning prefix and the image tokens
    # are what gets sampled here, one position at a time.

    batch_dims = model.dimensions["batch_dim"]
    length_dim = model.dimensions["total_seq_dim"]
    image_seq_dim = model.dimensions['image_sequence_dim']

    image_inputs = inputs['image_inputs']
    text_inputs = inputs['text_inputs']

    # Decoding always starts at image position 0 for every batch element
    initial_position = mtf.zeros(text_inputs.mesh, mtf.Shape((batch_dims,)), dtype=tf.int32)

    length_range = mtf.range(image_inputs.mesh, image_seq_dim, tf.int32)

    # one step of sampling fn

    def sample_step(logits, ids, position, incremental):
        nonlocal sampling_keep_top_k
        # -2 is a sentinel: keep the top 10% of the vocabulary
        if sampling_keep_top_k == -2:
            sampling_keep_top_k = int(logits.shape[-1].size * 0.1)

        if sampling_keep_top_k != -1:
            if sampling_keep_top_k <= 0:
                raise ValueError("sampling_keep_top_k must either be -1 or positive.")
            k_largest = mtf.nth_largest_element(
                logits, n=sampling_keep_top_k,
                reduced_dim=model.dimensions['image_vocab_dim'])
            # push everything at or below the threshold far down so it is never sampled
            logits = mtf.where(mtf.less_equal(logits, k_largest),
                               mtf.ones_like(logits) * -1e6, logits)

        # temperature sampling
        ids_this_step = mtf.sample_with_temperature(
            logits, model.dimensions['image_vocab_dim'], temperature)

        # reshape & assign results
        if incremental:
            ids_this_step = mtf.reshape(ids_this_step, ([batch_dims]))
        else:
            ids_this_step = mtf.shift(ids_this_step, offset=1, dim=image_seq_dim, wrap=False)

        # overwrite only the slot at `position`, leave every other id untouched
        one_hot = mtf.one_hot(position, image_seq_dim, dtype=tf.int32)
        one_new_id = ids_this_step * one_hot
        new_ids = (1 - one_hot) * ids + one_new_id
        new_position = position + 1
        return [new_position, new_ids]

    # Builds context to pass around internally
    # The 'first part' context records initial states of k / v / x
    if cached:
        context_first_part = mtf_transformer.transformer.Context(
            model=None,
            mesh=image_inputs.mesh,
            batch_dims=batch_dims,
            length_dim=image_seq_dim,
            variable_dtype=variable_dtype,
            mode="first_part",
            position=length_range,
            position_is_default=True,
            new_states=[],
            initial_position=initial_position,
            sequence_id=None,
            constant_states=[],
            inputs=inputs)
        model.context = context_first_part

        with tf.variable_scope('dall-e'):
            logits = model.forward(inputs, return_loss=False, return_logits=True)

        initial_states = context_first_part.new_states

        # sample one step to get first image token and then delete logits

        initial_position, image_inputs = sample_step(logits, image_inputs, initial_position, incremental=False)

        del logits
    else:
        initial_states = []

    def cond_fn(position, ids, *unused_states):
        """Should we run another loop iteration?"""
        past_end = mtf.greater_equal(position, image_seq_dim.size)
        if max_steps:
            past_end = mtf.logical_or(
                past_end, mtf.greater_equal(position - initial_position, max_steps))

        is_done = past_end
        all_done = mtf.reduce_all(is_done)
        return mtf.logical_not(all_done)

    def body_fn(position, ids, *states):
        """One step in the decode loop."""

        context = mtf_transformer.transformer.Context(
            model=None,
            mesh=image_inputs.mesh,
            batch_dims=batch_dims,
            length_dim=image_seq_dim,
            variable_dtype=variable_dtype,
            mode="incremental",
            position=position,
            position_is_default=True,
            states=states,
            new_states=[],
            initial_position=position,
            sequence_id=None,
            inputs=ids) if cached else None

        model.is_incremental_inference = bool(cached)
        model.context = context
        with tf.variable_scope("dall-e", reuse=tf.AUTO_REUSE):
            # Feed the loop-carried `ids`, not the outer `image_inputs`: the outer
            # tensor is fixed for the whole loop and never contains the tokens
            # sampled on earlier iterations.
            logits = model.forward(
                {'image_inputs': ids, 'text_inputs': (text_inputs if not cached else None)},
                return_loss=False, return_logits=True)

        ret = sample_step(logits, ids, position, cached)

        if cached:
            # carry the updated cached states into the next iteration
            ret += context.new_states
        return ret

    while_loop_inputs = [initial_position, image_inputs] + initial_states
    final_position, outputs = mtf.while_loop(
        cond_fn, body_fn, while_loop_inputs)[:2]
    del final_position
    return outputs
34 changes: 32 additions & 2 deletions src/input_fns.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import imageio
import numpy as np
import tensorflow.compat.v1 as tf

import os

def crop_center_and_resize(img, size):
s = tf.shape(img)
Expand Down Expand Up @@ -38,6 +40,32 @@ def truncate_or_pad_label(label, params):
return label


def pred_input(params, tokenizer, prompt='a cat in a hat'):
    """Build a one-element dataset holding the tokenized prompt for prediction.

    The prompt is encoded with `tokenizer.encode`, cut down to its trailing
    `params["text_seq_len"]` tokens if too long, or right-padded with
    `params["padding_id"]` if too short, then broadcast to
    [params["batch_size"], params["text_seq_len"]].

    Args:
        params: dict providing "text_seq_len", "batch_size" and (when padding
            is needed) "padding_id".
        tokenizer: object with an `encode(prompt)` method whose result supports
            `len()` and slicing.
        prompt: the text to condition on.

    Returns:
        A tf.data.Dataset with a single element `(tokens, tokens)`; the second
        entry is a dummy label.
    """
    tokens = tokenizer.encode(prompt)
    seq_len = params["text_seq_len"]
    overflow = len(tokens) - seq_len

    if overflow > 0:
        tf.logging.info("The length of your input prompt is longer than the model's text context length - truncating "
                        "input.")
        # Drops the leading tokens and keeps the tail.
        # TODO: left or right truncate here?
        tokens = tokens[overflow:]
    elif overflow < 0:
        # Right-pad up to the text context length.
        tokens = tf.pad(tokens, [[0, -overflow]], constant_values=params["padding_id"])

    batch = tf.broadcast_to(tokens, [params["batch_size"], seq_len])
    # The prediction pipeline expects (features, labels); reuse the tokens as labels.
    return tf.data.Dataset.from_tensors(batch).map(lambda x: (x, x))


def pred_output(predictions, out_name='test', output_dir='outputs'):
    """Write each decoded prediction as a JPEG file inside `output_dir`.

    Files are named `{out_name}_{i}.jpeg`, where `i` is the index of the
    prediction. `output_dir` is created if it does not exist yet.

    Args:
        predictions: iterable of dicts, each with a "predictions_decoded"
            array (assumed to hold values in [-1, 1] - confirm against the
            VAE decoder output).
        out_name: filename prefix for the written images.
        output_dir: directory the images are written to.
    """
    os.makedirs(output_dir, exist_ok=True)

    def _denormalize(x):
        # Rescale from [-1, 1] to [0, 255] and convert to uint8.
        return (((x + 1) / 2) * 255.0).astype(np.uint8)

    for i, p in enumerate(predictions):
        # Write into the requested directory rather than a hard-coded "outputs/".
        target_path = os.path.join(output_dir, f"{out_name}_{i}.jpeg")
        imageio.imwrite(target_path, _denormalize(p["predictions_decoded"]))


def read_labeled_tfrecord(params):
def read_fn(example):
features = {
Expand Down Expand Up @@ -103,6 +131,7 @@ def _process_path(file_path):
dataset = configure_for_performance(dataset, params, eval)
return dataset.repeat()


def dalle_input_fn(params, eval=False):
path = params["dataset"]["train_path"] if not eval else params["dataset"]["eval_path"]
files = tf.io.gfile.glob(path)
Expand All @@ -113,7 +142,8 @@ def dalle_input_fn(params, eval=False):

if not eval:
dataset = dataset.shuffle(file_count, reshuffle_each_iteration=False)
dataset = dataset.apply(tf.data.experimental.parallel_interleave(tf.data.TFRecordDataset, cycle_length=4, sloppy=False))
dataset = dataset.apply(
tf.data.experimental.parallel_interleave(tf.data.TFRecordDataset, cycle_length=4, sloppy=False))
parse_fn = read_labeled_tfrecord(params)
dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = configure_for_performance(dataset, params, eval)
Expand Down
Loading