Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

## Slides

- [ismir2019-tutorial-slides.pdf](pdf/ismir2019-tutorial-slides.pdf)
- [ismir2019-tutorial-slides.pdf](docs/pdf/ismir2019-tutorial-slides.pdf)

## Notebooks

Expand Down
24 changes: 12 additions & 12 deletions gan.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
"id": "CJ_cVBuhk13r"
},
"source": [
"!pip install -q torch torchvision matplotlib tqdm livelossplot"
"!pip install -q livelossplot"
],
"execution_count": 1,
"outputs": []
Expand All @@ -79,13 +79,12 @@
},
"source": [
"from IPython.display import clear_output\n",
"from ipywidgets import interact, IntSlider\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"import torchvision\n",
"from tqdm import tqdm\n",
"from tqdm.notebook import tqdm\n",
"from livelossplot import PlotLosses\n",
"from livelossplot.outputs import MatplotlibPlot"
],
Expand Down Expand Up @@ -186,6 +185,7 @@
" \"\"\"Convert images to vectors.\"\"\"\n",
" return images.view(images.size(0), 784)\n",
"\n",
"\n",
"def vectors_to_images(vectors):\n",
" \"\"\"Convert vectors to images.\"\"\"\n",
" return vectors.view(vectors.size(0), 1, 28, 28)"
Expand Down Expand Up @@ -276,8 +276,8 @@
"source": [
"def compute_gradient_penalty(discriminator, real_samples, fake_samples):\n",
" \"\"\"Compute the gradient penalty for regularization. Intuitively, the\n",
" gradient penalty help stablize the magnitude of the gradients that the\n",
" discriminator provides to the generator, and thus help stablize the training\n",
" gradient penalty help stabilize the magnitude of the gradients that the\n",
" discriminator provides to the generator, and thus help stabilize the training\n",
" of the generator.\"\"\"\n",
" # Get random interpolations between real and fake samples\n",
" alpha = torch.rand(real_samples.size(0), 1).cuda()\n",
Expand Down Expand Up @@ -318,7 +318,7 @@
" if torch.cuda.is_available():\n",
" real_samples = real_samples.cuda()\n",
" latent = latent.cuda()\n",
" \n",
"\n",
" # === Train the discriminator ===\n",
" # Reset cached gradients to zero\n",
" d_optimizer.zero_grad()\n",
Expand All @@ -328,7 +328,7 @@
" d_loss_real = torch.mean(torch.nn.functional.relu(1. - prediction_real))\n",
" # Backpropagate the gradients\n",
" d_loss_real.backward()\n",
" \n",
"\n",
" # Generate fake samples with the generator\n",
" fake_samples = generator(latent)\n",
" # Get discriminator outputs for the fake samples\n",
Expand All @@ -346,7 +346,7 @@
"\n",
" # Update the weights\n",
" d_optimizer.step()\n",
" \n",
"\n",
" # === Train the generator ===\n",
" # Reset cached gradients to zero\n",
" g_optimizer.zero_grad()\n",
Expand Down Expand Up @@ -415,7 +415,7 @@
"history_samples = {}\n",
"\n",
"# Create a LiveLoss logger instance for monitoring\n",
"liveloss = PlotLosses(outputs=[MatplotlibPlot(cell_size=(6,2))])\n",
"liveloss = PlotLosses(outputs=[MatplotlibPlot(cell_size=(6, 2))])\n",
"\n",
"# Initialize step\n",
"step = 0"
Expand Down Expand Up @@ -480,7 +480,7 @@
" progress_bar.set_description_str(\n",
" f\"(d_loss={d_loss: 8.6f}, g_loss={g_loss: 8.6f})\"\n",
" )\n",
" \n",
"\n",
" if step % sample_interval == 0:\n",
" # Get generated samples\n",
" samples = vectors_to_images(generator(sample_latent))\n",
Expand All @@ -494,13 +494,13 @@
" clear_output(True)\n",
" if step > 0:\n",
" liveloss.send()\n",
" \n",
"\n",
" # Display generated samples\n",
" plt.figure(figsize=(15, 3))\n",
" plt.imshow(samples, cmap='Greys', vmin=0, vmax=1)\n",
" plt.axis('off')\n",
" plt.show()\n",
" \n",
"\n",
" step += 1\n",
" progress_bar.update(1)\n",
" if step >= n_steps:\n",
Expand Down
35 changes: 11 additions & 24 deletions musegan.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
"outputId": "cd1aacf3-c361-42ca-9a12-3f5d3475e3c5"
},
"source": [
"!pip3 install torch matplotlib tqdm livelossplot gdown \"pypianoroll>=1.0.2\""
"!pip install livelossplot pypianoroll"
],
"execution_count": 1,
"outputs": [
Expand Down Expand Up @@ -133,10 +133,8 @@
},
"source": [
"from IPython.display import clear_output\n",
"from ipywidgets import interact, IntSlider\n",
"\n",
"import os\n",
"import os.path\n",
"import random\n",
"from pathlib import Path\n",
"\n",
Expand All @@ -145,7 +143,7 @@
"import torch\n",
"import pypianoroll\n",
"from pypianoroll import Multitrack, Track\n",
"from tqdm import tqdm\n",
"from tqdm.notebook import tqdm\n",
"from livelossplot import PlotLosses\n",
"from livelossplot.outputs import MatplotlibPlot"
],
Expand All @@ -171,7 +169,7 @@
"n_tracks = 5 # number of tracks\n",
"n_pitches = 72 # number of pitches\n",
"lowest_pitch = 24 # MIDI note number of the lowest pitch\n",
"n_samples_per_song = 8 # number of samples to extract from each song in the datset\n",
"n_samples_per_song = 8 # number of samples to extract from each song in the dataset\n",
"n_measures = 4 # number of measures per sample\n",
"beat_resolution = 4 # temporal resolution of a beat (in timestep)\n",
"programs = [0, 0, 25, 33, 48] # program number for each track\n",
Expand Down Expand Up @@ -205,7 +203,7 @@
")\n",
"assert len(programs) == len(is_drums) and len(programs) == len(track_names), (\n",
" \"Lengths of programs, is_drums and track_names must be the same.\"\n",
") "
")"
],
"execution_count": 4,
"outputs": []
Expand Down Expand Up @@ -322,13 +320,13 @@
"outputId": "6ca783a5-0fe2-4d39-ec75-6f2380dae173"
},
"source": [
"song_dir = dataset_root / msd_id_to_dirs('TREVDFX128E07859E0') # 'TRQAOWZ128F93000A4', 'TREVDFX128E07859E0'\n",
"song_dir = dataset_root / msd_id_to_dirs('TREVDFX128E07859E0') # 'TRQAOWZ128F93000A4', 'TREVDFX128E07859E0'\n",
"multitrack = pypianoroll.load(song_dir / os.listdir(song_dir)[0])\n",
"multitrack.trim(end=12 * 96)\n",
"axs = multitrack.plot()\n",
"plt.gcf().set_size_inches((16, 8))\n",
"for ax in axs:\n",
" for x in range(96, 12 * 96, 96): \n",
" for x in range(96, 12 * 96, 96):\n",
" ax.axvline(x - 0.5, color='k', linestyle='-', linewidth=1)\n",
"plt.show()"
],
Expand Down Expand Up @@ -830,7 +828,7 @@
"history_samples = {}\n",
"\n",
"# Create a LiveLoss logger instance for monitoring\n",
"liveloss = PlotLosses(outputs=[MatplotlibPlot(cell_size=(6,2))])\n",
"liveloss = PlotLosses(outputs=[MatplotlibPlot(cell_size=(6, 2))])\n",
"\n",
"# Initialize step\n",
"step = 0"
Expand Down Expand Up @@ -886,11 +884,11 @@
" running_d_loss, running_g_loss = 0.0, 0.0\n",
" liveloss.update({'negative_critic_loss': -running_d_loss})\n",
" # liveloss.update({'d_loss': running_d_loss, 'g_loss': running_g_loss})\n",
" \n",
"\n",
" # Update losses to progress bar\n",
" progress_bar.set_description_str(\n",
" \"(d_loss={: 8.6f}, g_loss={: 8.6f})\".format(d_loss, g_loss))\n",
" \n",
"\n",
" if step % sample_interval == 0:\n",
" # Get generated samples\n",
" generator.eval()\n",
Expand All @@ -901,7 +899,7 @@
" clear_output(True)\n",
" if step > 0:\n",
" liveloss.send()\n",
" \n",
"\n",
" # Display generated samples\n",
" samples = samples.transpose(1, 0, 2, 3).reshape(n_tracks, -1, n_pitches)\n",
" tracks = []\n",
Expand Down Expand Up @@ -938,7 +936,7 @@
" else:\n",
" ax.axvline(x - 0.5, color='k', linestyle='-', linewidth=1)\n",
" plt.show()\n",
" \n",
"\n",
" step += 1\n",
" progress_bar.update(1)\n",
" if step >= n_steps:\n",
Expand Down Expand Up @@ -1131,17 +1129,6 @@
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "aeJpOeUXjdBu"
},
"source": [
""
],
"execution_count": 21,
"outputs": []
}
]
}