Update project build settings
jon-tow committed Dec 28, 2022
1 parent c793bcb commit be43f1d
Showing 3 changed files with 33 additions and 37 deletions.
40 changes: 17 additions & 23 deletions pyproject.toml
@@ -1,22 +1,9 @@
[build-system]
requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["text_sed"]
include = ["text_sed*"]

[tool.black]
line-length = 88

[tool.isort]
profile = "black"

[project]
name = "text-sed"
version = "0.0.1"
requires-python = ">=3.8"
description = "A PyTorch implementation of Self-conditioned Embedding Diffusion"
authors = [{name = "Jonathan Tow", email = "[email protected]"}]
dynamic = ["version"]
requires-python = ">=3.8"
keywords = ["nlp", "pytorch", "machine-learning", "text-generation"]
classifiers = [
"Development Status :: 3 - Alpha",
@@ -31,13 +18,20 @@ dependencies = [
"transformers",
]

[[project.authors]]
author = "Jonathan Tow"
email = "[email protected]"

[project.license]
text = "MIT"

[project.optional-dependencies]
dev = ["black", "flake8", "isort"]
train = ["datasets", "omegaconf", "wandb", "tqdm"]

[build-system]
requires = ["setuptools>=64", "setuptools_scm[toml]>=6.2", "wheel"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["text_sed"]
include = ["text_sed*"]

[tool.black]
line-length = 101

[tool.isort]
profile = "black"
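
Declaring version under dynamic defers the version to the build backend: with setuptools_scm in the build requirements it can be derived from git tags, and setup.py below also passes one explicitly. A minimal sketch of reading whichever version was recorded at install time (importlib.metadata is standard library from Python 3.8, matching requires-python; the fallback string is an assumption):

import torch  # unused here; kept out — stdlib only
from importlib.metadata import PackageNotFoundError, version

try:
    # Resolve the version the build backend recorded at install time
    __version__ = version("text-sed")
except PackageNotFoundError:
    # Not installed, e.g. running from a bare source checkout
    __version__ = "0.0.0+unknown"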
6 changes: 5 additions & 1 deletion setup.py
@@ -1,3 +1,7 @@
from setuptools import setup

setup()
setup(
name="text-sed",
version="0.0.1",
packages=["text_sed"],
)
24 changes: 11 additions & 13 deletions text_sed/diffusion.py
@@ -175,8 +175,8 @@ def ddpm_step(


def corrupt(
inputs: Tensor, # x₀
time: Tensor, # t
inputs: Tensor, # x₀
time: Tensor, # t
schedule: Callable, # ᾱ schedule
) -> Tensor:
"""q sampler: q(xₜ | xₒ) ~ N(xₒ * √ᾱₜ, (1 - ᾱₜ)I)
@@ -200,6 +200,7 @@ def __init__(
self,
model: nn.Module,
embed_mat: NamedTensor["vocab", "embed"],
*,
use_self_cond: bool = True,
noise_schedule: Callable = get_noise_schedule("cosine"),
bottleneck_dim: Optional[int] = None,
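
get_noise_schedule("cosine") is not shown in this diff; a common choice for a cosine ᾱ schedule is the squared-cosine curve of Nichol & Dhariwal (2021), sketched here as an assumption about what the helper returns:

import math
import torch

def cosine_alpha_bar(time: torch.Tensor, s: float = 0.008) -> torch.Tensor:
    # f(t) = cos²(((t + s) / (1 + s)) · π / 2), with ᾱ(t) = f(t) / f(0)
    def f(u: torch.Tensor) -> torch.Tensor:
        return torch.cos((u + s) / (1 + s) * math.pi / 2) ** 2
    return f(time) / f(torch.zeros_like(time))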
@@ -300,6 +301,7 @@ def forward(
)
cond_mask: NamedTensor["batch", "pos", "1"] = cond_mask[..., None]
# Remove padding positions from the conditioning/infilling masks
# TODO (jon-tow): We shouldn't need to do this - remove this once verified
cond_mask = cond_mask * attention_mask
infill_mask = (1 - cond_mask) * attention_mask
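
The two multiplications above split every non-padding position between the conditioning mask (positions held fixed) and the infilling mask (positions to denoise), while zeroing padding in both. A toy example with made-up masks:

import torch

attention_mask = torch.tensor([1, 1, 1, 1, 0])  # last position is padding
cond_mask = torch.tensor([1, 1, 0, 0, 1])       # note the stray 1 on the padding slot

cond_mask = cond_mask * attention_mask          # -> tensor([1, 1, 0, 0, 0])
infill_mask = (1 - cond_mask) * attention_mask  # -> tensor([0, 0, 1, 1, 0])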

@@ -372,8 +374,7 @@ def generate(
use_clamp: Whether to clamp predicted embeddings to the range
[-1, 1] before each diffusion sampling step.
"""
default_cond_mask = torch.zeros(shape[:-1], device=device)[..., None]
cond_mask = utils.default(cond_mask, default_cond_mask).bool()
cond_mask = utils.default(cond_mask, torch.zeros(shape[:-1], device=device)[..., None]).bool()
infill_mask = (~cond_mask).float()

# Sample start embedding from the normal prior eₜ ~ qₜ
@@ -382,15 +383,12 @@
for step in range(num_steps):
# Get time for current and next states. NOTE: (1 - ...) to process in reverse
time_now = torch.tensor([1 - step / num_steps], device=device)
time_next = torch.tensor(
[
torch.maximum(
torch.tensor(1 - (step + 1 + time_delta) / num_steps),
torch.tensor(0.0),
)
],
device=device,
)
time_next = torch.tensor([
torch.maximum(
torch.tensor(1 - (step + 1 + time_delta) / num_steps),
torch.tensor(0.0),
)
], device=device)

# if (
# guide_scale is not None and cond_mask is None
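
The reformatted block above computes the reverse-time pair (time_now, time_next): t runs from 1 down to 0 over num_steps iterations, with time_delta nudging the lookahead time further ahead and the clamp against 0 keeping it non-negative. Isolated as a runnable sketch with assumed toy values (plain max stands in for torch.maximum on scalars):

import torch

num_steps, time_delta, device = 4, 0.0, "cpu"
for step in range(num_steps):
    time_now = torch.tensor([1 - step / num_steps], device=device)
    time_next = torch.tensor(
        [max(1 - (step + 1 + time_delta) / num_steps, 0.0)], device=device
    )
    print(f"t: {time_now.item():.2f} -> {time_next.item():.2f}")
# t: 1.00 -> 0.75, 0.75 -> 0.50, 0.50 -> 0.25, 0.25 -> 0.00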
