diff --git a/how-to-guides/12-lr-scheduler.ipynb b/how-to-guides/12-lr-scheduler.ipynb new file mode 100644 index 0000000..929f53b --- /dev/null +++ b/how-to-guides/12-lr-scheduler.ipynb @@ -0,0 +1,302 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "\n", + "# How to use LR-Schedulers" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This how-to guide demonstrates how we can use LR-Schedulers to adjust the learning rate of a model." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basic Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Install Dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "! pip install pytorch-ignite" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Import Dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.optim as optim\n", + "from torch.optim.lr_scheduler import ExponentialLR\n", + "\n", + "from ignite.engine import Engine, Events\n", + "from ignite.handlers import create_lr_scheduler_with_warmup" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create a `Dummy Trainer`" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def train_step(e, b):\n", + " print(trainer.state.epoch, trainer.state.iteration, \" | \", optimizer.param_groups[0][\"lr\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = Engine(train_step)\n", + "optimizer = optim.SGD([torch.tensor([0.1])], lr=0.1234)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initiate a `LRScheduler`" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "torch_lr_scheduler = ExponentialLR(optimizer=optimizer, gamma=0.5)\n", + "\n", + "data = [0] * 8\n", + "epoch_length = len(data)\n", + "warmup_duration = 5\n", + "scheduler = create_lr_scheduler_with_warmup(torch_lr_scheduler,\n", + " warmup_start_value=0.0,\n", + " warmup_duration=warmup_duration)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Trigger LR-Scheduler:\n", + "\n", + " - Step 1: Trigger scheduler on interation_started events before reaching warm-up.\n", + " - Step 2: Trigger scheduler on epoch_started events after the warm-up. \n", + "\n", + "Note: Epochs are 1-based, thus we do 1 + warmup_duration / epoch_length \n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "combined_events = Events.ITERATION_STARTED(event_filter=lambda _, __: trainer.state.iteration <= warmup_duration)\n", + "combined_events |= Events.EPOCH_STARTED(event_filter=lambda _, __: trainer.state.epoch > 1 + warmup_duration / epoch_length)\n", + "trainer.add_event_handler(combined_events, scheduler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Execute Trainer" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 1 | 0.0\n", + "1 2 | 0.03085\n", + "1 3 | 0.0617\n", + "1 4 | 0.09255\n", + "1 5 | 0.1234\n", + "1 6 | 0.1234\n", + "1 7 | 0.1234\n", + "1 8 | 0.1234\n", + "2 9 | 0.0617\n", + "2 10 | 0.0617\n", + "2 11 | 0.0617\n", + "2 12 | 0.0617\n", + "2 13 | 0.0617\n", + "2 14 | 0.0617\n", + "2 15 | 0.0617\n", + "2 16 | 0.0617\n", + "3 17 | 0.03085\n", + "3 18 | 0.03085\n", + "3 19 | 0.03085\n", + "3 20 | 0.03085\n", + "3 21 | 0.03085\n", + "3 22 | 0.03085\n", + "3 23 | 0.03085\n", + "3 24 | 0.03085\n", + "4 25 | 0.015425\n", + "4 26 | 0.015425\n", + "4 27 | 0.015425\n", + "4 28 | 0.015425\n", + "4 29 | 0.015425\n", + "4 30 | 0.015425\n", + "4 31 | 0.015425\n", + "4 32 | 0.015425\n", + "5 33 | 0.0077125\n", + "5 34 | 0.0077125\n", + "5 35 | 0.0077125\n", + "5 36 | 0.0077125\n", + "5 37 | 0.0077125\n", + "5 38 | 0.0077125\n", + "5 39 | 0.0077125\n", + "5 40 | 0.0077125\n", + "6 41 | 0.00385625\n", + "6 42 | 0.00385625\n", + "6 43 | 0.00385625\n", + "6 44 | 0.00385625\n", + "6 45 | 0.00385625\n", + "6 46 | 0.00385625\n", + "6 47 | 0.00385625\n", + "6 48 | 0.00385625\n", + "7 49 | 0.001928125\n", + "7 50 | 0.001928125\n", + "7 51 | 0.001928125\n", + "7 52 | 0.001928125\n", + "7 53 | 0.001928125\n", + "7 54 | 0.001928125\n", + "7 55 | 0.001928125\n", + "7 56 | 0.001928125\n", + "8 57 | 0.0009640625\n", + "8 58 | 0.0009640625\n", + "8 59 | 0.0009640625\n", + "8 60 | 0.0009640625\n", + "8 61 | 0.0009640625\n", + "8 62 | 0.0009640625\n", + "8 63 | 0.0009640625\n", + "8 64 | 0.0009640625\n", + "9 65 | 0.00048203125\n", + "9 66 | 0.00048203125\n", + "9 67 | 0.00048203125\n", + "9 68 | 0.00048203125\n", + "9 69 | 0.00048203125\n", + "9 70 | 0.00048203125\n", + "9 71 | 0.00048203125\n", + "9 72 | 0.00048203125\n", + "10 73 | 0.000241015625\n", + "10 74 | 0.000241015625\n", + "10 75 | 0.000241015625\n", + "10 76 | 0.000241015625\n", + "10 77 | 0.000241015625\n", + "10 78 | 0.000241015625\n", + "10 79 | 0.000241015625\n", + "10 80 | 0.000241015625\n" + ] + }, + { + "data": { + "text/plain": [ + "State:\n", + "\titeration: 80\n", + "\tepoch: 10\n", + "\tepoch_length: 8\n", + "\tmax_epochs: 10\n", + "\toutput: \n", + "\tbatch: 0\n", + "\tmetrics: \n", + "\tdataloader: \n", + "\tseed: \n", + "\ttimes: " + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.run(data, max_epochs=10)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}