Skip to content

Commit c1689e1

Browse files
committed
some scaffold, copy in essentials (attention, what else)
1 parent 8b07074 commit c1689e1

File tree

4 files changed

+190
-0
lines changed

4 files changed

+190
-0
lines changed

Diff for: .github/workflows/python-publish.yml

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
2+
3+
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

# Runs only when a GitHub release is published.
on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    # checkout@v4 / setup-python@v5 replace the v2 releases, which target
    # Node.js runtimes that GitHub-hosted runners have deprecated.
    - uses: actions/checkout@v4
    - name: Set up Python
      uses: actions/setup-python@v5
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    # Third-party action pinned to a full commit SHA; authenticates to PyPI
    # with an API token stored in the PYPI_API_TOKEN repository secret.
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}

Diff for: musiclm_pytorch/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from musiclm_pytorch.musiclm_pytorch import MuLaN

Diff for: musiclm_pytorch/musiclm_pytorch.py

+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import torch
2+
import torch.nn.functional as F
3+
from torch import nn, einsum
4+
5+
from einops import rearrange, repeat, reduce
6+
7+
# functions
8+
9+
def exists(val):
    """Return True when `val` is anything other than None (falsy values count as existing)."""
    is_missing = val is None
    return not is_missing
11+
12+
# biasless layernorm
13+
14+
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learned scale only.

    `gamma` is a trainable parameter initialised to ones. `beta` is a
    zero-filled buffer rather than a parameter, so it is part of the module's
    state but is not returned by `parameters()` and therefore never trained.
    """

    def __init__(self, dim):
        super().__init__()
        scale = torch.ones(dim)
        shift = torch.zeros(dim)
        self.gamma = nn.Parameter(scale)
        self.register_buffer("beta", shift)

    def forward(self, x):
        # normalize across the trailing feature dimension only
        normalized_shape = x.shape[-1:]
        return F.layer_norm(x, normalized_shape, weight = self.gamma, bias = self.beta)
22+
23+
# attention
24+
25+
class Attention(nn.Module):
    """Multi-head scaled dot-product attention with a pre-norm on `x`.

    Queries are always projected from the normalized `x`. Keys and values are
    projected from `context` when one is passed (cross attention) and from the
    normalized `x` otherwise (self attention).

    Constructor args:
        dim: feature size of `x` (and of `context`, which goes through the
            same `to_kv` projection).
        causal: when True, each query position is blocked from attending to
            key positions that lie after it (aligned to the end of the keys).
        dim_head: feature size per head; queries are scaled by dim_head ** -0.5.
        heads: number of attention heads.
        num_null_kv: accepted but not used by this implementation.
        dropout: dropout probability, applied to the attention weights and to
            the output projection.

    NOTE: `prefix_context` and `prefix_context_mask` are accepted by `forward`
    but are not consumed yet.
    """

    def __init__(
        self,
        dim,
        causal = False,
        dim_head = 64,
        heads = 8,
        num_null_kv = 0,
        dropout = 0.1
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.causal = causal
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)

        self.attn_dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_bias = None,
        prefix_context = None,
        prefix_context_mask = None
    ):
        """Attend from `x` to itself, or to `context` when given.

        Args:
            x: tensor of shape (batch, n, dim).
            context: optional tensor of shape (batch, j, dim) to draw keys and
                values from; it is projected as-is (no normalization).
            mask: optional boolean tensor of shape (batch, j) over key
                positions; False entries are excluded from attention.
            attn_bias: optional tensor broadcastable to (batch, heads, n, j),
                added to the attention logits before masking.
            prefix_context, prefix_context_mask: currently unused.

        Returns:
            Tensor of shape (batch, n, dim).
        """

        # prenorm

        x = self.norm(x)

        # keys / values come from the context when cross attending

        kv_input = context if exists(context) else x

        # project for queries, keys, values

        q = self.to_q(x)
        k, v = self.to_kv(kv_input).chunk(2, dim = -1)

        # split for multi-headed attention

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        q = q * self.scale

        # similarities

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # bias goes in before masking so masked positions stay masked

        if exists(attn_bias):
            sim = sim + attn_bias

        mask_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, mask_value)

        if self.causal:
            i, j = sim.shape[-2:]
            # offset of j - i + 1 aligns the queries with the last i keys
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, mask_value)

        # attention

        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # aggregate

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge heads

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
104+
105+
# main classes
106+
107+
class MuLaN(nn.Module):
    """Placeholder module: holds no parameters and `forward` returns its input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # identity for now: the same object is handed straight back
        passthrough = x
        return passthrough
113+
114+
class MusicLM(nn.Module):
    """Placeholder module: holds no parameters and `forward` returns its input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # identity for now: the same object is handed straight back
        passthrough = x
        return passthrough

Diff for: setup.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from setuptools import setup, find_packages
2+
3+
# Search terms attached to the distribution's metadata.
keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'text to music',
    'contrastive learning'
]

# Runtime dependencies installed alongside the package.
install_requires = [
    'einops>=0.4',
    'torch>=1.6',
]

# Trove classifiers attached to the distribution's metadata.
classifiers = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
]

# Register the distribution; every package found by find_packages is included.
setup(
    name = 'musiclm-pytorch',
    packages = find_packages(exclude=[]),
    version = '0.0.1',
    license = 'MIT',
    description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
    author = 'Phil Wang',
    author_email = '[email protected]',
    long_description_content_type = 'text/markdown',
    url = 'https://github.com/lucidrains/musiclm-pytorch',
    keywords = keywords,
    install_requires = install_requires,
    classifiers = classifiers,
)

0 commit comments

Comments
 (0)