diff --git a/.github/workflows/ci_docs.yml b/.github/workflows/ci_docs.yml
index fc2fa37a4..eed64e91a 100644
--- a/.github/workflows/ci_docs.yml
+++ b/.github/workflows/ci_docs.yml
@@ -36,7 +36,7 @@ jobs:
python-version: 3.8
cache: pip
- - name: Install Texlive & tree
+ - name: Install tree
run: |
sudo apt-get update --fix-missing
sudo apt-get install -y cmake tree pandoc
diff --git a/.notebooks/course_UvA-DL/01-introduction-to-pytorch.yaml b/.notebooks/course_UvA-DL/01-introduction-to-pytorch.yaml
index eded2dbc4..f0c84a92e 100644
--- a/.notebooks/course_UvA-DL/01-introduction-to-pytorch.yaml
+++ b/.notebooks/course_UvA-DL/01-introduction-to-pytorch.yaml
@@ -1,10 +1,10 @@
-title: 'Tutorial 1: Introduction to PyTorch'
+title: "Tutorial 1: Introduction to PyTorch"
author: Phillip Lippe
created: 2021-08-27
updated: 2023-03-14
license: CC BY-SA
build: 1
-description: 'This tutorial will give a short introduction to PyTorch basics, and
+description: "This tutorial will give a short introduction to PyTorch basics, and
get you setup for writing your own neural networks.
This notebook is part of a lecture series on Deep Learning at the University of
@@ -12,21 +12,21 @@ description: 'This tutorial will give a short introduction to PyTorch basics, an
The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.
- '
+ "
requirements:
-- matplotlib
-- lightning>=2.0.0
+ - matplotlib
+ - lightning>=2.0.0
accelerator:
-- CPU
-- GPU
+ - CPU
+ - GPU
environment:
-- torchmetrics==1.2.1
-- lightning==2.3.3
-- matplotlib==3.8.4
-- urllib3==2.2.2
-- pytorch-lightning==1.5.3
-- setuptools==69.0.3
-- ipython==8.16.1
-- numpy==1.26.4
-- torch==2.0.1+cu118
-published: '2024-07-20T00:27:36.724445'
+ - torchmetrics==1.2.1
+ - lightning==2.3.3
+ - matplotlib==3.8.4
+ - urllib3==2.2.2
+ - pytorch-lightning==1.5.3
+ - setuptools==69.0.3
+ - ipython==8.16.1
+ - numpy==1.26.4
+ - torch==2.0.1+cu118
+published: "2024-07-20T00:27:36.724445"
diff --git a/.notebooks/course_UvA-DL/02-activation-functions.yaml b/.notebooks/course_UvA-DL/02-activation-functions.yaml
index b9303f899..bf2d9b0bf 100644
--- a/.notebooks/course_UvA-DL/02-activation-functions.yaml
+++ b/.notebooks/course_UvA-DL/02-activation-functions.yaml
@@ -1,9 +1,9 @@
-title: 'Tutorial 2: Activation Functions'
+title: "Tutorial 2: Activation Functions"
author: Phillip Lippe
created: 2021-08-27
updated: 2023-03-14
license: CC BY-SA
-description: 'In this tutorial, we will take a closer look at (popular) activation
+description: "In this tutorial, we will take a closer look at (popular) activation
functions and investigate their effect on optimization properties in neural networks.
Activation functions are a crucial part of deep learning models as they add the
@@ -13,32 +13,32 @@ description: 'In this tutorial, we will take a closer look at (popular) activati
more beneficial than others.
The goal of this tutorial is to show the importance of choosing a good activation
- function (and how to do so), and what problems might occur if we don''t.
+ function (and how to do so), and what problems might occur if we don't.
This notebook is part of a lecture series on Deep Learning at the University of
Amsterdam.
The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.
- '
+ "
requirements:
-- torchvision
-- matplotlib
-- seaborn
-- lightning>=2.0.0
+ - torchvision
+ - matplotlib
+ - seaborn
+ - lightning>=2.0.0
accelerator:
-- CPU
-- GPU
+ - CPU
+ - GPU
environment:
-- ipython==8.16.1
-- numpy==1.26.4
-- torchmetrics==1.2.1
-- torch==2.0.1
-- pytorch-lightning==2.0.9.post0
-- torchvision==0.15.2
-- lightning==2.3.3
-- seaborn==0.13.2
-- setuptools==69.0.3
-- matplotlib==3.8.4
-- urllib3==2.2.2
-published: '2024-07-19T19:22:02.189820'
+ - ipython==8.16.1
+ - numpy==1.26.4
+ - torchmetrics==1.2.1
+ - torch==2.0.1
+ - pytorch-lightning==2.0.9.post0
+ - torchvision==0.15.2
+ - lightning==2.3.3
+ - seaborn==0.13.2
+ - setuptools==69.0.3
+ - matplotlib==3.8.4
+ - urllib3==2.2.2
+published: "2024-07-19T19:22:02.189820"
diff --git a/.notebooks/course_UvA-DL/03-initialization-and-optimization.yaml b/.notebooks/course_UvA-DL/03-initialization-and-optimization.yaml
index e0c6f35dd..71f487fc7 100644
--- a/.notebooks/course_UvA-DL/03-initialization-and-optimization.yaml
+++ b/.notebooks/course_UvA-DL/03-initialization-and-optimization.yaml
@@ -1,13 +1,13 @@
-title: 'Tutorial 3: Initialization and Optimization'
+title: "Tutorial 3: Initialization and Optimization"
author: Phillip Lippe
created: 2021-08-27
updated: 2023-03-14
license: CC BY-SA
tags:
-- Image
-- Initialization
-- Optimizers
-description: 'In this tutorial, we will review techniques for optimization and initialization
+ - Image
+ - Initialization
+ - Optimizers
+description: "In this tutorial, we will review techniques for optimization and initialization
of neural networks.
When increasing the depth of neural networks, there are various challenges we face.
@@ -23,25 +23,25 @@ description: 'In this tutorial, we will review techniques for optimization and i
The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.
- '
+ "
requirements:
-- torchvision
-- matplotlib
-- seaborn
-- lightning>=2.0.0
+ - torchvision
+ - matplotlib
+ - seaborn
+ - lightning>=2.0.0
accelerator:
-- CPU
-- GPU
+ - CPU
+ - GPU
environment:
-- ipython==8.16.1
-- urllib3==2.2.2
-- seaborn==0.13.2
-- numpy==1.26.4
-- torchmetrics==1.2.1
-- torch==2.0.1
-- pytorch-lightning==2.0.9.post0
-- lightning==2.3.3
-- matplotlib==3.8.4
-- setuptools==69.0.3
-- torchvision==0.15.2
-published: '2024-07-19T19:29:56.143678'
+ - ipython==8.16.1
+ - urllib3==2.2.2
+ - seaborn==0.13.2
+ - numpy==1.26.4
+ - torchmetrics==1.2.1
+ - torch==2.0.1
+ - pytorch-lightning==2.0.9.post0
+ - lightning==2.3.3
+ - matplotlib==3.8.4
+ - setuptools==69.0.3
+ - torchvision==0.15.2
+published: "2024-07-19T19:29:56.143678"
diff --git a/.notebooks/course_UvA-DL/04-inception-resnet-densenet.yaml b/.notebooks/course_UvA-DL/04-inception-resnet-densenet.yaml
index 5a8b3d092..141fa9a6d 100644
--- a/.notebooks/course_UvA-DL/04-inception-resnet-densenet.yaml
+++ b/.notebooks/course_UvA-DL/04-inception-resnet-densenet.yaml
@@ -1,11 +1,11 @@
-title: 'Tutorial 4: Inception, ResNet and DenseNet'
+title: "Tutorial 4: Inception, ResNet and DenseNet"
author: Phillip Lippe
created: 2021-08-27
updated: 2023-03-14
license: CC BY-SA
tags:
-- Image
-description: 'In this tutorial, we will implement and discuss variants of modern CNN
+ - Image
+description: "In this tutorial, we will implement and discuss variants of modern CNN
architectures.
There have been many different architectures been proposed over the past few years.
@@ -26,28 +26,28 @@ description: 'In this tutorial, we will implement and discuss variants of modern
The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.
- '
+ "
requirements:
-- torchvision
-- matplotlib
-- seaborn
-- tabulate
-- lightning>=2.0.0
-- tensorboard
+ - torchvision
+ - matplotlib
+ - seaborn
+ - tabulate
+ - lightning>=2.0.0
+ - tensorboard
accelerator:
-- GPU
+ - GPU
environment:
-- setuptools==69.0.3
-- lightning==2.3.3
-- seaborn==0.13.2
-- urllib3==2.2.2
-- tensorboard==2.17.0
-- torchmetrics==1.2.1
-- numpy==1.26.4
-- ipython==8.16.1
-- torchvision==0.15.2
-- tabulate==0.9.0
-- pytorch-lightning==2.0.9.post0
-- matplotlib==3.8.4
-- torch==2.0.1
-published: '2024-07-19T19:35:09.316261'
+ - setuptools==69.0.3
+ - lightning==2.3.3
+ - seaborn==0.13.2
+ - urllib3==2.2.2
+ - tensorboard==2.17.0
+ - torchmetrics==1.2.1
+ - numpy==1.26.4
+ - ipython==8.16.1
+ - torchvision==0.15.2
+ - tabulate==0.9.0
+ - pytorch-lightning==2.0.9.post0
+ - matplotlib==3.8.4
+ - torch==2.0.1
+published: "2024-07-19T19:35:09.316261"
diff --git a/.notebooks/course_UvA-DL/05-transformers-and-MH-attention.yaml b/.notebooks/course_UvA-DL/05-transformers-and-MH-attention.yaml
index 47ae53a9a..8589280a4 100644
--- a/.notebooks/course_UvA-DL/05-transformers-and-MH-attention.yaml
+++ b/.notebooks/course_UvA-DL/05-transformers-and-MH-attention.yaml
@@ -1,12 +1,12 @@
-title: 'Tutorial 5: Transformers and Multi-Head Attention'
+title: "Tutorial 5: Transformers and Multi-Head Attention"
author: Phillip Lippe
created: 2021-06-30
updated: 2023-03-14
license: CC BY-SA
build: 0
tags:
-- Text
-description: 'In this tutorial, we will discuss one of the most impactful architectures
+ - Text
+description: "In this tutorial, we will discuss one of the most impactful architectures
of the last 2 years: the Transformer model.
Since the paper Attention Is All You Need by Vaswani et al. had been published in
@@ -29,24 +29,24 @@ description: 'In this tutorial, we will discuss one of the most impactful archit
The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.
- '
+ "
requirements:
-- torchvision
-- matplotlib
-- seaborn
-- lightning>=2.0.0
+ - torchvision
+ - matplotlib
+ - seaborn
+ - lightning>=2.0.0
accelerator:
-- GPU
+ - GPU
environment:
-- pytorch-lightning==2.0.9.post0
-- lightning==2.3.3
-- setuptools==69.0.3
-- ipython==8.16.1
-- numpy==1.26.4
-- torchvision==0.15.2
-- seaborn==0.13.2
-- torchmetrics==1.2.1
-- matplotlib==3.8.4
-- torch==2.0.1
-- urllib3==2.2.2
-published: '2024-07-19T19:40:50.645247'
+ - pytorch-lightning==2.0.9.post0
+ - lightning==2.3.3
+ - setuptools==69.0.3
+ - ipython==8.16.1
+ - numpy==1.26.4
+ - torchvision==0.15.2
+ - seaborn==0.13.2
+ - torchmetrics==1.2.1
+ - matplotlib==3.8.4
+ - torch==2.0.1
+ - urllib3==2.2.2
+published: "2024-07-19T19:40:50.645247"
diff --git a/.notebooks/course_UvA-DL/06-graph-neural-networks.ipynb b/.notebooks/course_UvA-DL/06-graph-neural-networks.ipynb
new file mode 100644
index 000000000..c39b63c4e
--- /dev/null
+++ b/.notebooks/course_UvA-DL/06-graph-neural-networks.ipynb
@@ -0,0 +1,2702 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "4f979e51",
+ "metadata": {
+ "papermill": {
+ "duration": 0.014037,
+ "end_time": "2023-10-11T16:03:06.630612",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:06.616575",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "\n",
+ "# Tutorial 6: Basics of Graph Neural Networks\n",
+ "\n",
+ "* **Author:** Phillip Lippe\n",
+ "* **License:** CC BY-SA\n",
+ "* **Generated:** 2023-10-11T16:02:31.112587\n",
+ "\n",
+ "In this tutorial, we will discuss the application of neural networks on graphs.\n",
+ "Graph Neural Networks (GNNs) have recently gained increasing popularity in both applications and research,\n",
+ "including domains such as social networks, knowledge graphs, recommender systems, and bioinformatics.\n",
+ "While the theory and math behind GNNs might first seem complicated,\n",
+ "the implementation of those models is quite simple and helps in understanding the methodology.\n",
+ "Therefore, we will discuss the implementation of basic network layers of a GNN,\n",
+ "namely graph convolutions, and attention layers.\n",
+ "Finally, we will apply a GNN on semi-supervised node classification and molecule categorization.\n",
+ "This notebook is part of a lecture series on Deep Learning at the University of Amsterdam.\n",
+ "The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.\n",
+ "\n",
+ "\n",
+ "---\n",
+ "Open in [{height=\"20px\" width=\"117px\"}](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/course_UvA-DL/06-graph-neural-networks.ipynb)\n",
+ "\n",
+ "Give us a ⭐ [on Github](https://www.github.com/Lightning-AI/lightning/)\n",
+ "| Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/stable/)\n",
+ "| Join us [on Slack](https://www.pytorchlightning.ai/community)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5214d4fa",
+ "metadata": {
+ "papermill": {
+ "duration": 0.012019,
+ "end_time": "2023-10-11T16:03:06.661853",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:06.649834",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Setup\n",
+ "This notebook requires some packages besides pytorch-lightning."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "1c5351a8",
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:06.687521Z",
+ "iopub.status.busy": "2023-10-11T16:03:06.687127Z",
+ "iopub.status.idle": "2023-10-11T16:03:10.716383Z",
+ "shell.execute_reply": "2023-10-11T16:03:10.715453Z"
+ },
+ "id": "LfrJLKPFyhsK",
+ "lines_to_next_cell": 0,
+ "papermill": {
+ "duration": 4.044652,
+ "end_time": "2023-10-11T16:03:10.718369",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:06.673717",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\r\n",
+ "\u001b[0m"
+ ]
+ }
+ ],
+ "source": [
+ "! pip install --quiet \"torch-geometric\" \"ipython[notebook]>=8.0.0, <8.17.0\" \"lightning>=2.0.0\" \"torch-sparse\" \"torch-cluster\" \"torch-scatter\" \"torch-spline-conv\" \"pytorch-lightning>=1.4, <2.1.0\" \"torchmetrics>=0.7, <1.3\" \"setuptools>=68.0.0, <68.3.0\" \"matplotlib>=3.0.0, <3.9.0\" \"torch>=1.8.1, <2.1.0\" \"urllib3\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "648a5e0d",
+ "metadata": {
+ "papermill": {
+ "duration": 0.009288,
+ "end_time": "2023-10-11T16:03:10.737471",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:10.728183",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "
\n",
+ "We start by importing our standard libraries below."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "c7750212",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:10.757524Z",
+ "iopub.status.busy": "2023-10-11T16:03:10.756995Z",
+ "iopub.status.idle": "2023-10-11T16:03:15.530236Z",
+ "shell.execute_reply": "2023-10-11T16:03:15.529272Z"
+ },
+ "papermill": {
+ "duration": 4.790934,
+ "end_time": "2023-10-11T16:03:15.537517",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:10.746583",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Global seed set to 42\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Standard libraries\n",
+ "import os\n",
+ "\n",
+ "# For downloading pre-trained models\n",
+ "import urllib.request\n",
+ "from urllib.error import HTTPError\n",
+ "\n",
+ "# PyTorch Lightning\n",
+ "import lightning as L\n",
+ "\n",
+ "# PyTorch\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "import torch.optim as optim\n",
+ "\n",
+ "# PyTorch geometric\n",
+ "import torch_geometric\n",
+ "import torch_geometric.data as geom_data\n",
+ "import torch_geometric.nn as geom_nn\n",
+ "\n",
+ "# PL callbacks\n",
+ "from lightning.pytorch.callbacks import ModelCheckpoint\n",
+ "from torch import Tensor\n",
+ "\n",
+ "AVAIL_GPUS = min(1, torch.cuda.device_count())\n",
+ "BATCH_SIZE = 256 if AVAIL_GPUS else 64\n",
+ "# Path to the folder where the datasets are/should be downloaded\n",
+ "DATASET_PATH = os.environ.get(\"PATH_DATASETS\", \"data/\")\n",
+ "# Path to the folder where the pretrained models are saved\n",
+ "CHECKPOINT_PATH = os.environ.get(\"PATH_CHECKPOINT\", \"saved_models/GNNs/\")\n",
+ "\n",
+ "# Setting the seed\n",
+ "L.seed_everything(42)\n",
+ "\n",
+ "# Ensure that all operations are deterministic on GPU (if used) for reproducibility\n",
+ "torch.backends.cudnn.deterministic = True\n",
+ "torch.backends.cudnn.benchmark = False"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "32a3eca2",
+ "metadata": {
+ "papermill": {
+ "duration": 0.009097,
+ "end_time": "2023-10-11T16:03:15.557823",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:15.548726",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "We also have a few pre-trained models we download below."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "2a6a3f6a",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:15.582822Z",
+ "iopub.status.busy": "2023-10-11T16:03:15.581987Z",
+ "iopub.status.idle": "2023-10-11T16:03:16.134259Z",
+ "shell.execute_reply": "2023-10-11T16:03:16.133238Z"
+ },
+ "papermill": {
+ "duration": 0.564879,
+ "end_time": "2023-10-11T16:03:16.136013",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:15.571134",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/NodeLevelMLP.ckpt...\n",
+ "Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/NodeLevelGNN.ckpt...\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/GraphLevelGraphConv.ckpt...\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Github URL where saved models are stored for this tutorial\n",
+ "base_url = \"https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/\"\n",
+ "# Files to download\n",
+ "pretrained_files = [\"NodeLevelMLP.ckpt\", \"NodeLevelGNN.ckpt\", \"GraphLevelGraphConv.ckpt\"]\n",
+ "\n",
+ "# Create checkpoint path if it doesn't exist yet\n",
+ "os.makedirs(CHECKPOINT_PATH, exist_ok=True)\n",
+ "\n",
+ "# For each file, check whether it already exists. If not, try downloading it.\n",
+ "for file_name in pretrained_files:\n",
+ " file_path = os.path.join(CHECKPOINT_PATH, file_name)\n",
+ " if \"/\" in file_name:\n",
+ " os.makedirs(file_path.rsplit(\"/\", 1)[0], exist_ok=True)\n",
+ " if not os.path.isfile(file_path):\n",
+ " file_url = base_url + file_name\n",
+ " print(\"Downloading %s...\" % file_url)\n",
+ " try:\n",
+ " urllib.request.urlretrieve(file_url, file_path)\n",
+ " except HTTPError as e:\n",
+ " print(\n",
+ " \"Something went wrong. Please try to download the file from the GDrive folder,\"\n",
+ " \" or contact the author with the full output including the following error:\\n\",\n",
+ " e,\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "23531974",
+ "metadata": {
+ "papermill": {
+ "duration": 0.04981,
+ "end_time": "2023-10-11T16:03:16.195921",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:16.146111",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Graph Neural Networks"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "123ed1b4",
+ "metadata": {
+ "papermill": {
+ "duration": 0.009251,
+ "end_time": "2023-10-11T16:03:16.214598",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:16.205347",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "### Graph representation\n",
+ "\n",
+ "Before starting the discussion of specific neural network operations on graphs, we should consider how to represent a graph.\n",
+ "Mathematically, a graph $\\mathcal{G}$ is defined as a tuple of a set of nodes/vertices $V$, and a set of edges/links $E$: $\\mathcal{G}=(V,E)$.\n",
+ "Each edge is a pair of two vertices, and represents a connection between them.\n",
+ "For instance, let's look at the following graph:\n",
+ "\n",
+ "
\n",
+ "\n",
+ "The vertices are $V=\\{1,2,3,4\\}$, and edges $E=\\{(1,2), (2,3), (2,4), (3,4)\\}$.\n",
+ "Note that for simplicity, we assume the graph to be undirected and hence don't add mirrored pairs like $(2,1)$.\n",
+ "In application, vertices and edge can often have specific attributes, and edges can even be directed.\n",
+ "The question is how we could represent this diversity in an efficient way for matrix operations.\n",
+ "Usually, for the edges, we decide between two variants: an adjacency matrix, or a list of paired vertex indices.\n",
+ "\n",
+ "The **adjacency matrix** $A$ is a square matrix whose elements indicate whether pairs of vertices are adjacent,\n",
+ "i.e. connected, or not.\n",
+ "In the simplest case, $A_{ij}$ is 1 if there is a connection from node $i$ to $j$, and otherwise 0.\n",
+ "If we have edge attributes or different categories of edges in a graph, this information can be added to the matrix as well.\n",
+ "For an undirected graph, keep in mind that $A$ is a symmetric matrix ($A_{ij}=A_{ji}$).\n",
+ "For the example graph above, we have the following adjacency matrix:\n",
+ "\n",
+ "$$\n",
+ "A = \\begin{bmatrix}\n",
+ " 0 & 1 & 0 & 0\\\\\n",
+ " 1 & 0 & 1 & 1\\\\\n",
+ " 0 & 1 & 0 & 1\\\\\n",
+ " 0 & 1 & 1 & 0\n",
+ "\\end{bmatrix}\n",
+ "$$\n",
+ "\n",
+ "While expressing a graph as a list of edges is more efficient in terms of memory and (possibly) computation,\n",
+ "using an adjacency matrix is more intuitive and simpler to implement.\n",
+ "In our implementations below, we will rely on the adjacency matrix to keep the code simple.\n",
+ "However, common libraries use edge lists, which we will discuss later more.\n",
+ "Alternatively, we could also use the list of edges to define a sparse adjacency matrix with which we can work\n",
+ "as if it was a dense matrix, but allows more memory-efficient operations.\n",
+ "PyTorch supports this with the sub-package `torch.sparse`\n",
+ "([documentation](https://pytorch.org/docs/stable/sparse.html)) which is however still in a beta-stage\n",
+ "(API might change in future)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d28b1897",
+ "metadata": {
+ "lines_to_next_cell": 2,
+ "papermill": {
+ "duration": 0.012645,
+ "end_time": "2023-10-11T16:03:16.236437",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:16.223792",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "### Graph Convolutions\n",
+ "\n",
+ "Graph Convolutional Networks have been introduced by [Kipf et al. ](https://openreview.net/pdf?id=SJU4ayYgl)\n",
+ "in 2016 at the University of Amsterdam.\n",
+ "He also wrote a great [blog post](https://tkipf.github.io/graph-convolutional-networks/) about this topic,\n",
+ "which is recommended if you want to read about GCNs from a different perspective.\n",
+ "GCNs are similar to convolutions in images in the sense that the \"filter\" parameters are typically shared over all locations in the graph.\n",
+ "At the same time, GCNs rely on message passing methods, which means that vertices exchange information with the neighbors,\n",
+ "and send \"messages\" to each other.\n",
+ "Before looking at the math, we can try to visually understand how GCNs work.\n",
+ "The first step is that each node creates a feature vector that represents the message it wants to send to all its neighbors.\n",
+ "In the second step, the messages are sent to the neighbors, so that a node receives one message per adjacent node.\n",
+ "Below we have visualized the two steps for our example graph.\n",
+ "\n",
+ "
\n",
+ "\n",
+ "If we want to formulate that in more mathematical terms, we need to first decide how to combine\n",
+ "all the messages a node receives.\n",
+ "As the number of messages vary across nodes, we need an operation that works for any number.\n",
+ "Hence, the usual way to go is to sum or take the mean.\n",
+ "Given the previous features of nodes $H^{(l)}$, the GCN layer is defined as follows:\n",
+ "\n",
+ "$$H^{(l+1)} = \\sigma\\left(\\hat{D}^{-1/2}\\hat{A}\\hat{D}^{-1/2}H^{(l)}W^{(l)}\\right)$$\n",
+ "\n",
+ "$W^{(l)}$ is the weight parameters with which we transform the input features into messages ($H^{(l)}W^{(l)}$).\n",
+ "To the adjacency matrix $A$ we add the identity matrix so that each node sends its own message also to itself:\n",
+ "$\\hat{A}=A+I$.\n",
+ "Finally, to take the average instead of summing, we calculate the matrix $\\hat{D}$ which is a diagonal\n",
+ "matrix with $D_{ii}$ denoting the number of neighbors node $i$ has.\n",
+ "$\\sigma$ represents an arbitrary activation function, and not necessarily the sigmoid (usually a ReLU-based\n",
+ "activation function is used in GNNs).\n",
+ "\n",
+ "When implementing the GCN layer in PyTorch, we can take advantage of the flexible operations on tensors.\n",
+ "Instead of defining a matrix $\\hat{D}$, we can simply divide the summed messages by the number of neighbors afterward.\n",
+ "Additionally, we replace the weight matrix with a linear layer, which additionally allows us to add a bias.\n",
+ "Written as a PyTorch module, the GCN layer is defined as follows:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "6ce21fc3",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:16.257386Z",
+ "iopub.status.busy": "2023-10-11T16:03:16.256846Z",
+ "iopub.status.idle": "2023-10-11T16:03:16.268948Z",
+ "shell.execute_reply": "2023-10-11T16:03:16.267658Z"
+ },
+ "papermill": {
+ "duration": 0.025196,
+ "end_time": "2023-10-11T16:03:16.271397",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:16.246201",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "class GCNLayer(nn.Module):\n",
+ " def __init__(self, c_in, c_out):\n",
+ " super().__init__()\n",
+ " self.projection = nn.Linear(c_in, c_out)\n",
+ "\n",
+ " def forward(self, node_feats, adj_matrix):\n",
+ " \"\"\"Forward.\n",
+ "\n",
+ " Args:\n",
+ " node_feats: Tensor with node features of shape [batch_size, num_nodes, c_in]\n",
+ " adj_matrix: Batch of adjacency matrices of the graph. If there is an edge from i to j,\n",
+ " adj_matrix[b,i,j]=1 else 0. Supports directed edges by non-symmetric matrices.\n",
+ " Assumes to already have added the identity connections.\n",
+ " Shape: [batch_size, num_nodes, num_nodes]\n",
+ " \"\"\"\n",
+ " # Num neighbours = number of incoming edges\n",
+ " num_neighbours = adj_matrix.sum(dim=-1, keepdims=True)\n",
+ " node_feats = self.projection(node_feats)\n",
+ " node_feats = torch.bmm(adj_matrix, node_feats)\n",
+ " node_feats = node_feats / num_neighbours\n",
+ " return node_feats"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f465968d",
+ "metadata": {
+ "papermill": {
+ "duration": 0.009388,
+ "end_time": "2023-10-11T16:03:16.290193",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:16.280805",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "To further understand the GCN layer, we can apply it to our example graph above.\n",
+ "First, let's specify some node features and the adjacency matrix with added self-connections:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "ae773b51",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:16.310470Z",
+ "iopub.status.busy": "2023-10-11T16:03:16.309858Z",
+ "iopub.status.idle": "2023-10-11T16:03:16.324044Z",
+ "shell.execute_reply": "2023-10-11T16:03:16.323176Z"
+ },
+ "papermill": {
+ "duration": 0.026083,
+ "end_time": "2023-10-11T16:03:16.325522",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:16.299439",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Node features:\n",
+ " tensor([[[0., 1.],\n",
+ " [2., 3.],\n",
+ " [4., 5.],\n",
+ " [6., 7.]]])\n",
+ "\n",
+ "Adjacency matrix:\n",
+ " tensor([[[1., 1., 0., 0.],\n",
+ " [1., 1., 1., 1.],\n",
+ " [0., 1., 1., 1.],\n",
+ " [0., 1., 1., 1.]]])\n"
+ ]
+ }
+ ],
+ "source": [
+ "node_feats = torch.arange(8, dtype=torch.float32).view(1, 4, 2)\n",
+ "adj_matrix = Tensor([[[1, 1, 0, 0], [1, 1, 1, 1], [0, 1, 1, 1], [0, 1, 1, 1]]])\n",
+ "\n",
+ "print(\"Node features:\\n\", node_feats)\n",
+ "print(\"\\nAdjacency matrix:\\n\", adj_matrix)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "85a343c9",
+ "metadata": {
+ "papermill": {
+ "duration": 0.015853,
+ "end_time": "2023-10-11T16:03:16.350759",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:16.334906",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "Next, let's apply a GCN layer to it.\n",
+ "For simplicity, we initialize the linear weight matrix as an identity matrix so that the input features are equal to the messages.\n",
+ "This makes it easier for us to verify the message passing operation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "f3352c18",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:16.373984Z",
+ "iopub.status.busy": "2023-10-11T16:03:16.373472Z",
+ "iopub.status.idle": "2023-10-11T16:03:16.381312Z",
+ "shell.execute_reply": "2023-10-11T16:03:16.380405Z"
+ },
+ "papermill": {
+ "duration": 0.020795,
+ "end_time": "2023-10-11T16:03:16.382866",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:16.362071",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Adjacency matrix tensor([[[1., 1., 0., 0.],\n",
+ " [1., 1., 1., 1.],\n",
+ " [0., 1., 1., 1.],\n",
+ " [0., 1., 1., 1.]]])\n",
+ "Input features tensor([[[0., 1.],\n",
+ " [2., 3.],\n",
+ " [4., 5.],\n",
+ " [6., 7.]]])\n",
+ "Output features tensor([[[1., 2.],\n",
+ " [3., 4.],\n",
+ " [4., 5.],\n",
+ " [4., 5.]]])\n"
+ ]
+ }
+ ],
+ "source": [
+ "layer = GCNLayer(c_in=2, c_out=2)\n",
+ "layer.projection.weight.data = Tensor([[1.0, 0.0], [0.0, 1.0]])\n",
+ "layer.projection.bias.data = Tensor([0.0, 0.0])\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " out_feats = layer(node_feats, adj_matrix)\n",
+ "\n",
+ "print(\"Adjacency matrix\", adj_matrix)\n",
+ "print(\"Input features\", node_feats)\n",
+ "print(\"Output features\", out_feats)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3556a93a",
+ "metadata": {
+ "papermill": {
+ "duration": 0.009481,
+ "end_time": "2023-10-11T16:03:16.401826",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:16.392345",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "As we can see, the first node's output values are the average of itself and the second node.\n",
+ "Similarly, we can verify all other nodes.\n",
+ "However, in a GNN, we would also want to allow feature exchange between nodes beyond its neighbors.\n",
+ "This can be achieved by applying multiple GCN layers, which gives us the final layout of a GNN.\n",
+ "The GNN can be build up by a sequence of GCN layers and non-linearities such as ReLU.\n",
+ "For a visualization, see below (figure credit - [Thomas Kipf, 2016](https://tkipf.github.io/graph-convolutional-networks/)).\n",
+ "\n",
+ "
\n",
+ "\n",
+ "However, one issue we can see from looking at the example above is that the output features for nodes 3 and 4 are\n",
+ "the same because they have the same adjacent nodes (including itself).\n",
+ "Therefore, GCN layers can make the network forget node-specific information if we just take a mean over all messages.\n",
+ "Multiple possible improvements have been proposed.\n",
+ "While the simplest option might be using residual connections, the more common approach is to either weigh\n",
+ "the self-connections higher or define a separate weight matrix for the self-connections.\n",
+ "Alternatively, we can use a well-known concept: attention."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "43d1a129",
+ "metadata": {
+ "lines_to_next_cell": 2,
+ "papermill": {
+ "duration": 0.009353,
+ "end_time": "2023-10-11T16:03:16.420575",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:16.411222",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "### Graph Attention\n",
+ "\n",
+ "Attention describes a weighted average of multiple elements with the weights dynamically computed based on an input\n",
+ "query and elements' keys (if you don't know what attention is, it is recommended to at least go through\n",
+ "the very first section called [What is Attention?](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html#What-is-Attention?)).\n",
+ "This concept can be similarly applied to graphs, one of such is the Graph Attention Network\n",
+ "(called GAT, proposed by [Velickovic et al., 2017](https://arxiv.org/abs/1710.10903)).\n",
+ "Similarly to the GCN, the graph attention layer creates a message for each node using a linear layer/weight matrix.\n",
+ "For the attention part, it uses the message from the node itself as a query, and the messages to average as both\n",
+ "keys and values (note that this also includes the message to itself).\n",
+ "The score function $f_{attn}$ is implemented as a one-layer MLP which maps the query and key to a single value.\n",
+ "The MLP looks as follows (figure credit - [Velickovic et al. ](https://arxiv.org/abs/1710.10903)):\n",
+ "\n",
+ "
\n",
+ "\n",
+ "$h_i$ and $h_j$ are the original features from node $i$ and $j$ respectively, and represent the messages\n",
+ "of the layer with $\\mathbf{W}$ as weight matrix.\n",
+ "$\\mathbf{a}$ is the weight matrix of the MLP, which has the shape $[1,2\\times d_{\\text{message}}]$,\n",
+ "and $\\alpha_{ij}$ the final attention weight from node $i$ to $j$.\n",
+ "The calculation can be described as follows:\n",
+ "\n",
+ "$$\\alpha_{ij} = \\frac{\\exp\\left(\\text{LeakyReLU}\\left(\\mathbf{a}\\left[\\mathbf{W}h_i||\\mathbf{W}h_j\\right]\\right)\\right)}{\\sum_{k\\in\\mathcal{N}_i} \\exp\\left(\\text{LeakyReLU}\\left(\\mathbf{a}\\left[\\mathbf{W}h_i||\\mathbf{W}h_k\\right]\\right)\\right)}$$\n",
+ "\n",
+ "The operator $||$ represents the concatenation, and $\\mathcal{N}_i$ the indices of the neighbors of node $i$.\n",
+ "Note that in contrast to usual practice, we apply a non-linearity (here LeakyReLU) before the softmax over elements.\n",
+ "Although it seems like a minor change at first, it is crucial for the attention to depend on the original input.\n",
+ "Specifically, let's remove the non-linearity for a second, and try to simplify the expression:\n",
+ "\n",
+ "$$\n",
+ "\\begin{split}\n",
+ " \\alpha_{ij} & = \\frac{\\exp\\left(\\mathbf{a}\\left[\\mathbf{W}h_i||\\mathbf{W}h_j\\right]\\right)}{\\sum_{k\\in\\mathcal{N}_i} \\exp\\left(\\mathbf{a}\\left[\\mathbf{W}h_i||\\mathbf{W}h_k\\right]\\right)}\\\\[5pt]\n",
+ " & = \\frac{\\exp\\left(\\mathbf{a}_{:,:d/2}\\mathbf{W}h_i+\\mathbf{a}_{:,d/2:}\\mathbf{W}h_j\\right)}{\\sum_{k\\in\\mathcal{N}_i} \\exp\\left(\\mathbf{a}_{:,:d/2}\\mathbf{W}h_i+\\mathbf{a}_{:,d/2:}\\mathbf{W}h_k\\right)}\\\\[5pt]\n",
+ " & = \\frac{\\exp\\left(\\mathbf{a}_{:,:d/2}\\mathbf{W}h_i\\right)\\cdot\\exp\\left(\\mathbf{a}_{:,d/2:}\\mathbf{W}h_j\\right)}{\\sum_{k\\in\\mathcal{N}_i} \\exp\\left(\\mathbf{a}_{:,:d/2}\\mathbf{W}h_i\\right)\\cdot\\exp\\left(\\mathbf{a}_{:,d/2:}\\mathbf{W}h_k\\right)}\\\\[5pt]\n",
+ " & = \\frac{\\exp\\left(\\mathbf{a}_{:,d/2:}\\mathbf{W}h_j\\right)}{\\sum_{k\\in\\mathcal{N}_i} \\exp\\left(\\mathbf{a}_{:,d/2:}\\mathbf{W}h_k\\right)}\\\\\n",
+ "\\end{split}\n",
+ "$$\n",
+ "\n",
+ "We can see that without the non-linearity, the attention term with $h_i$ actually cancels itself out,\n",
+ "resulting in the attention being independent of the node itself.\n",
+ "Hence, we would have the same issue as the GCN of creating the same output features for nodes with the same neighbors.\n",
+ "This is why the LeakyReLU is crucial and adds some dependency on $h_i$ to the attention.\n",
+ "\n",
+ "Once we obtain all attention factors, we can calculate the output features for each node by performing\n",
+ "the weighted average:\n",
+ "\n",
+ "$$h_i'=\\sigma\\left(\\sum_{j\\in\\mathcal{N}_i}\\alpha_{ij}\\mathbf{W}h_j\\right)$$\n",
+ "\n",
+ "$\\sigma$ is yet another non-linearity, as in the GCN layer.\n",
+ "Visually, we can represent the full message passing in an attention layer as follows\n",
+ "(figure credit - [Velickovic et al. ](https://arxiv.org/abs/1710.10903)):\n",
+ "\n",
+ "
\n",
+ "\n",
+ "To increase the expressiveness of the graph attention network, [Velickovic et al. ](https://arxiv.org/abs/1710.10903)\n",
+ "proposed to extend it to multiple heads similar to the Multi-Head Attention block in Transformers.\n",
+ "This results in $N$ attention layers being applied in parallel.\n",
+ "In the image above, it is visualized as three different colors of arrows (green, blue, and purple)\n",
+ "that are afterward concatenated.\n",
+ "The average is only applied for the very final prediction layer in a network.\n",
+ "\n",
+ "After having discussed the graph attention layer in detail, we can implement it below:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "7c3c6d13",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:16.441063Z",
+ "iopub.status.busy": "2023-10-11T16:03:16.440489Z",
+ "iopub.status.idle": "2023-10-11T16:03:16.456587Z",
+ "shell.execute_reply": "2023-10-11T16:03:16.455511Z"
+ },
+ "papermill": {
+ "duration": 0.028039,
+ "end_time": "2023-10-11T16:03:16.458052",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:16.430013",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "class GATLayer(nn.Module):\n",
+ " def __init__(self, c_in, c_out, num_heads=1, concat_heads=True, alpha=0.2):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " c_in: Dimensionality of input features\n",
+ " c_out: Dimensionality of output features\n",
+ " num_heads: Number of heads, i.e. attention mechanisms to apply in parallel. The\n",
+ " output features are equally split up over the heads if concat_heads=True.\n",
+ " concat_heads: If True, the output of the different heads is concatenated instead of averaged.\n",
+ " alpha: Negative slope of the LeakyReLU activation.\n",
+ " \"\"\"\n",
+ " super().__init__()\n",
+ " self.num_heads = num_heads\n",
+ " self.concat_heads = concat_heads\n",
+ " if self.concat_heads:\n",
+ " assert c_out % num_heads == 0, \"Number of output features must be a multiple of the count of heads.\"\n",
+ " c_out = c_out // num_heads\n",
+ "\n",
+ " # Sub-modules and parameters needed in the layer\n",
+ " self.projection = nn.Linear(c_in, c_out * num_heads)\n",
+ " self.a = nn.Parameter(Tensor(num_heads, 2 * c_out)) # One per head\n",
+ " self.leakyrelu = nn.LeakyReLU(alpha)\n",
+ "\n",
+ " # Initialization from the original implementation\n",
+ " nn.init.xavier_uniform_(self.projection.weight.data, gain=1.414)\n",
+ " nn.init.xavier_uniform_(self.a.data, gain=1.414)\n",
+ "\n",
+ " def forward(self, node_feats, adj_matrix, print_attn_probs=False):\n",
+ " \"\"\"Forward.\n",
+ "\n",
+ " Args:\n",
+ " node_feats: Input features of the node. Shape: [batch_size, c_in]\n",
+ " adj_matrix: Adjacency matrix including self-connections. Shape: [batch_size, num_nodes, num_nodes]\n",
+ " print_attn_probs: If True, the attention weights are printed during the forward pass\n",
+ " (for debugging purposes)\n",
+ " \"\"\"\n",
+ " batch_size, num_nodes = node_feats.size(0), node_feats.size(1)\n",
+ "\n",
+ " # Apply linear layer and sort nodes by head\n",
+ " node_feats = self.projection(node_feats)\n",
+ " node_feats = node_feats.view(batch_size, num_nodes, self.num_heads, -1)\n",
+ "\n",
+ " # We need to calculate the attention logits for every edge in the adjacency matrix\n",
+ " # Doing this on all possible combinations of nodes is very expensive\n",
+ " # => Create a tensor of [W*h_i||W*h_j] with i and j being the indices of all edges\n",
+ " # Returns indices where the adjacency matrix is not 0 => edges\n",
+ " edges = adj_matrix.nonzero(as_tuple=False)\n",
+ " node_feats_flat = node_feats.view(batch_size * num_nodes, self.num_heads, -1)\n",
+ " edge_indices_row = edges[:, 0] * num_nodes + edges[:, 1]\n",
+ " edge_indices_col = edges[:, 0] * num_nodes + edges[:, 2]\n",
+ " a_input = torch.cat(\n",
+ " [\n",
+ " torch.index_select(input=node_feats_flat, index=edge_indices_row, dim=0),\n",
+ " torch.index_select(input=node_feats_flat, index=edge_indices_col, dim=0),\n",
+ " ],\n",
+ " dim=-1,\n",
+ " ) # Index select returns a tensor with node_feats_flat being indexed at the desired positions\n",
+ "\n",
+ " # Calculate attention MLP output (independent for each head)\n",
+ " attn_logits = torch.einsum(\"bhc,hc->bh\", a_input, self.a)\n",
+ " attn_logits = self.leakyrelu(attn_logits)\n",
+ "\n",
+ " # Map list of attention values back into a matrix\n",
+ " attn_matrix = attn_logits.new_zeros(adj_matrix.shape + (self.num_heads,)).fill_(-9e15)\n",
+ " attn_matrix[adj_matrix[..., None].repeat(1, 1, 1, self.num_heads) == 1] = attn_logits.reshape(-1)\n",
+ "\n",
+ " # Weighted average of attention\n",
+ " attn_probs = F.softmax(attn_matrix, dim=2)\n",
+ " if print_attn_probs:\n",
+ " print(\"Attention probs\\n\", attn_probs.permute(0, 3, 1, 2))\n",
+ " node_feats = torch.einsum(\"bijh,bjhc->bihc\", attn_probs, node_feats)\n",
+ "\n",
+ " # If heads should be concatenated, we can do this by reshaping. Otherwise, take mean\n",
+ " if self.concat_heads:\n",
+ " node_feats = node_feats.reshape(batch_size, num_nodes, -1)\n",
+ " else:\n",
+ " node_feats = node_feats.mean(dim=2)\n",
+ "\n",
+ " return node_feats"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bf5f5993",
+ "metadata": {
+ "papermill": {
+ "duration": 0.009437,
+ "end_time": "2023-10-11T16:03:16.477084",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:16.467647",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "Again, we can apply the graph attention layer on our example graph above to understand the dynamics better.\n",
+ "As before, the input layer is initialized as an identity matrix, but we set $\\mathbf{a}$\n",
+ "to be a vector of arbitrary numbers to obtain different attention values.\n",
+ "We use two heads to show the parallel, independent attention mechanisms working in the layer."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "4d348ba1",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:16.497519Z",
+ "iopub.status.busy": "2023-10-11T16:03:16.496931Z",
+ "iopub.status.idle": "2023-10-11T16:03:16.566018Z",
+ "shell.execute_reply": "2023-10-11T16:03:16.565240Z"
+ },
+ "papermill": {
+ "duration": 0.084686,
+ "end_time": "2023-10-11T16:03:16.571182",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:16.486496",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Attention probs\n",
+ " tensor([[[[0.3543, 0.6457, 0.0000, 0.0000],\n",
+ " [0.1096, 0.1450, 0.2642, 0.4813],\n",
+ " [0.0000, 0.1858, 0.2885, 0.5257],\n",
+ " [0.0000, 0.2391, 0.2696, 0.4913]],\n",
+ "\n",
+ " [[0.5100, 0.4900, 0.0000, 0.0000],\n",
+ " [0.2975, 0.2436, 0.2340, 0.2249],\n",
+ " [0.0000, 0.3838, 0.3142, 0.3019],\n",
+ " [0.0000, 0.4018, 0.3289, 0.2693]]]])\n",
+ "Adjacency matrix tensor([[[1., 1., 0., 0.],\n",
+ " [1., 1., 1., 1.],\n",
+ " [0., 1., 1., 1.],\n",
+ " [0., 1., 1., 1.]]])\n",
+ "Input features tensor([[[0., 1.],\n",
+ " [2., 3.],\n",
+ " [4., 5.],\n",
+ " [6., 7.]]])\n",
+ "Output features tensor([[[1.2913, 1.9800],\n",
+ " [4.2344, 3.7725],\n",
+ " [4.6798, 4.8362],\n",
+ " [4.5043, 4.7351]]])\n"
+ ]
+ }
+ ],
+ "source": [
+ "layer = GATLayer(2, 2, num_heads=2)\n",
+ "layer.projection.weight.data = Tensor([[1.0, 0.0], [0.0, 1.0]])\n",
+ "layer.projection.bias.data = Tensor([0.0, 0.0])\n",
+ "layer.a.data = Tensor([[-0.2, 0.3], [0.1, -0.1]])\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " out_feats = layer(node_feats, adj_matrix, print_attn_probs=True)\n",
+ "\n",
+ "print(\"Adjacency matrix\", adj_matrix)\n",
+ "print(\"Input features\", node_feats)\n",
+ "print(\"Output features\", out_feats)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2ab15650",
+ "metadata": {
+ "papermill": {
+ "duration": 0.015782,
+ "end_time": "2023-10-11T16:03:16.610501",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:16.594719",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "We recommend that you try to calculate the attention matrix at least for one head and one node for yourself.\n",
+ "The entries are 0 where there does not exist an edge between $i$ and $j$.\n",
+ "For the others, we see a diverse set of attention probabilities.\n",
+ "Moreover, the output features of node 3 and 4 are now different although they have the same neighbors."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6d1738ad",
+ "metadata": {
+ "papermill": {
+ "duration": 0.009688,
+ "end_time": "2023-10-11T16:03:16.636046",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:16.626358",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## PyTorch Geometric\n",
+ "\n",
+ "We had mentioned before that implementing graph networks with adjacency matrix is simple and straight-forward\n",
+ "but can be computationally expensive for large graphs.\n",
+ "Many real-world graphs can reach over 200k nodes, for which adjacency matrix-based implementations fail.\n",
+ "There are a lot of optimizations possible when implementing GNNs, and luckily, there exist packages that provide such layers.\n",
+ "The most popular packages for PyTorch are [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/)\n",
+ "and the [Deep Graph Library](https://www.dgl.ai/) (the latter being actually framework agnostic).\n",
+ "Which one to use depends on the project you are planning to do and personal taste.\n",
+ "In this tutorial, we will look at PyTorch Geometric as part of the PyTorch family.\n",
+ "\n",
+ "PyTorch Geometric provides us a set of common graph layers, including the GCN and GAT layer we implemented above.\n",
+ "Additionally, similar to PyTorch's torchvision, it provides the common graph datasets and transformations\n",
+ "on those to simplify training.\n",
+ "Compared to our implementation above, PyTorch Geometric uses a list of index pairs to represent the edges.\n",
+ "The details of this library will be explored further in our experiments.\n",
+ "\n",
+ "In our tasks below, we want to allow us to pick from a multitude of graph layers.\n",
+ "Thus, we define again below a dictionary to access those using a string:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "3ef60900",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:16.663161Z",
+ "iopub.status.busy": "2023-10-11T16:03:16.662658Z",
+ "iopub.status.idle": "2023-10-11T16:03:16.685582Z",
+ "shell.execute_reply": "2023-10-11T16:03:16.681240Z"
+ },
+ "papermill": {
+ "duration": 0.039783,
+ "end_time": "2023-10-11T16:03:16.689352",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:16.649569",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "gnn_layer_by_name = {\"GCN\": geom_nn.GCNConv, \"GAT\": geom_nn.GATConv, \"GraphConv\": geom_nn.GraphConv}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "13e41e62",
+ "metadata": {
+ "papermill": {
+ "duration": 0.009584,
+ "end_time": "2023-10-11T16:03:16.708583",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:16.698999",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "Additionally to GCN and GAT, we added the layer `geom_nn.GraphConv`\n",
+ "([documentation](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GraphConv)).\n",
+ "GraphConv is a GCN with a separate weight matrix for the self-connections.\n",
+ "Mathematically, this would be:\n",
+ "\n",
+ "$$\n",
+ "\\mathbf{x}_i^{(l+1)} = \\mathbf{W}^{(l + 1)}_1 \\mathbf{x}_i^{(l)} + \\mathbf{W}^{(\\ell + 1)}_2 \\sum_{j \\in \\mathcal{N}_i} \\mathbf{x}_j^{(l)}\n",
+ "$$\n",
+ "\n",
+ "In this formula, the neighbor's messages are added instead of averaged.\n",
+ "However, PyTorch Geometric provides the argument `aggr` to switch between summing, averaging, and max pooling."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "034c4dbc",
+ "metadata": {
+ "papermill": {
+ "duration": 0.009575,
+ "end_time": "2023-10-11T16:03:16.727815",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:16.718240",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Experiments on graph structures\n",
+ "\n",
+ "
\n",
+ "\n",
+ "Tasks on graph-structured data can be grouped into three groups: node-level, edge-level and graph-level.\n",
+ "The different levels describe on which level we want to perform classification/regression.\n",
+ "We will discuss all three types in more detail below."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f25835f5",
+ "metadata": {
+ "papermill": {
+ "duration": 0.009577,
+ "end_time": "2023-10-11T16:03:16.747148",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:16.737571",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "### Node-level tasks: Semi-supervised node classification\n",
+ "\n",
+ "Node-level tasks have the goal to classify nodes in a graph.\n",
+ "Usually, we have given a single, large graph with >1000 nodes of which a certain amount of nodes are labeled.\n",
+ "We learn to classify those labeled examples during training and try to generalize to the unlabeled nodes.\n",
+ "\n",
+ "A popular example that we will use in this tutorial is the Cora dataset, a citation network among papers.\n",
+ "The Cora consists of 2708 scientific publications with links between each other representing\n",
+ "the citation of one paper by another.\n",
+ "The task is to classify each publication into one of seven classes.\n",
+ "Each publication is represented by a bag-of-words vector.\n",
+ "This means that we have a vector of 1433 elements for each publication, where a 1 at feature $i$ indicates\n",
+ "that the $i$-th word of a pre-defined dictionary is in the article.\n",
+ "Binary bag-of-words representations are commonly used when we need very simple encodings,\n",
+ "and already have an intuition of what words to expect in a network.\n",
+ "There exist much better approaches, but we will leave this to the NLP courses to discuss.\n",
+ "\n",
+ "We will load the dataset below:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "64e4c45d",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:16.769275Z",
+ "iopub.status.busy": "2023-10-11T16:03:16.768197Z",
+ "iopub.status.idle": "2023-10-11T16:03:18.147101Z",
+ "shell.execute_reply": "2023-10-11T16:03:18.146012Z"
+ },
+ "papermill": {
+ "duration": 1.39751,
+ "end_time": "2023-10-11T16:03:18.154239",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:16.756729",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x\n",
+ "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx\n",
+ "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty\n",
+ "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph\n",
+ "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Processing...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Done!\n"
+ ]
+ }
+ ],
+ "source": [
+ "cora_dataset = torch_geometric.datasets.Planetoid(root=DATASET_PATH, name=\"Cora\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b46bad32",
+ "metadata": {
+ "papermill": {
+ "duration": 0.011189,
+ "end_time": "2023-10-11T16:03:18.180670",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:18.169481",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "Let's look at how PyTorch Geometric represents the graph data.\n",
+ "Note that although we have a single graph, PyTorch Geometric returns a dataset for compatibility to other datasets."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "065b8b71",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:18.210861Z",
+ "iopub.status.busy": "2023-10-11T16:03:18.210003Z",
+ "iopub.status.idle": "2023-10-11T16:03:18.219012Z",
+ "shell.execute_reply": "2023-10-11T16:03:18.218544Z"
+ },
+ "papermill": {
+ "duration": 0.033004,
+ "end_time": "2023-10-11T16:03:18.228178",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:18.195174",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "cora_dataset[0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "41f9f836",
+ "metadata": {
+ "lines_to_next_cell": 2,
+ "papermill": {
+ "duration": 0.010081,
+ "end_time": "2023-10-11T16:03:18.248995",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:18.238914",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "The graph is represented by a `Data` object\n",
+ "([documentation](https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#torch_geometric.data.Data))\n",
+ "which we can access as a standard Python namespace.\n",
+ "The edge index tensor is the list of edges in the graph and contains the mirrored version of each edge for undirected graphs.\n",
+ "The `train_mask`, `val_mask`, and `test_mask` are boolean masks that indicate which nodes we should use for training,\n",
+ "validation, and testing.\n",
+ "The `x` tensor is the feature tensor of our 2708 publications, and `y` the labels for all nodes.\n",
+ "\n",
+ "After having seen the data, we can implement a simple graph neural network.\n",
+ "The GNN applies a sequence of graph layers (GCN, GAT, or GraphConv), ReLU as activation function,\n",
+ "and dropout for regularization.\n",
+ "See below for the specific implementation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "bd92f2e4",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:18.274053Z",
+ "iopub.status.busy": "2023-10-11T16:03:18.273754Z",
+ "iopub.status.idle": "2023-10-11T16:03:18.281026Z",
+ "shell.execute_reply": "2023-10-11T16:03:18.280532Z"
+ },
+ "lines_to_next_cell": 2,
+ "papermill": {
+ "duration": 0.023402,
+ "end_time": "2023-10-11T16:03:18.282495",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:18.259093",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "class GNNModel(nn.Module):\n",
+ " def __init__(\n",
+ " self,\n",
+ " c_in,\n",
+ " c_hidden,\n",
+ " c_out,\n",
+ " num_layers=2,\n",
+ " layer_name=\"GCN\",\n",
+ " dp_rate=0.1,\n",
+ " **kwargs,\n",
+ " ):\n",
+ " \"\"\"GNNModel.\n",
+ "\n",
+ " Args:\n",
+ " c_in: Dimension of input features\n",
+ " c_hidden: Dimension of hidden features\n",
+ " c_out: Dimension of the output features. Usually number of classes in classification\n",
+ " num_layers: Number of \"hidden\" graph layers\n",
+ " layer_name: String of the graph layer to use\n",
+ " dp_rate: Dropout rate to apply throughout the network\n",
+ " kwargs: Additional arguments for the graph layer (e.g. number of heads for GAT)\n",
+ " \"\"\"\n",
+ " super().__init__()\n",
+ " gnn_layer = gnn_layer_by_name[layer_name]\n",
+ "\n",
+ " layers = []\n",
+ " in_channels, out_channels = c_in, c_hidden\n",
+ " for l_idx in range(num_layers - 1):\n",
+ " layers += [\n",
+ " gnn_layer(in_channels=in_channels, out_channels=out_channels, **kwargs),\n",
+ " nn.ReLU(inplace=True),\n",
+ " nn.Dropout(dp_rate),\n",
+ " ]\n",
+ " in_channels = c_hidden\n",
+ " layers += [gnn_layer(in_channels=in_channels, out_channels=c_out, **kwargs)]\n",
+ " self.layers = nn.ModuleList(layers)\n",
+ "\n",
+ " def forward(self, x, edge_index):\n",
+ " \"\"\"Forward.\n",
+ "\n",
+ " Args:\n",
+ " x: Input features per node\n",
+ " edge_index: List of vertex index pairs representing the edges in the graph (PyTorch geometric notation)\n",
+ " \"\"\"\n",
+ " for layer in self.layers:\n",
+ " # For graph layers, we need to add the \"edge_index\" tensor as additional input\n",
+ " # All PyTorch Geometric graph layer inherit the class \"MessagePassing\", hence\n",
+ " # we can simply check the class type.\n",
+ " if isinstance(layer, geom_nn.MessagePassing):\n",
+ " x = layer(x, edge_index)\n",
+ " else:\n",
+ " x = layer(x)\n",
+ " return x"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cdf52d9d",
+ "metadata": {
+ "lines_to_next_cell": 2,
+ "papermill": {
+ "duration": 0.010266,
+ "end_time": "2023-10-11T16:03:18.302816",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:18.292550",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "Good practice in node-level tasks is to create an MLP baseline that is applied to each node independently.\n",
+ "This way we can verify whether adding the graph information to the model indeed improves the prediction, or not.\n",
+ "It might also be that the features per node are already expressive enough to clearly point towards a specific class.\n",
+ "To check this, we implement a simple MLP below."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "4877e955",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:18.324957Z",
+ "iopub.status.busy": "2023-10-11T16:03:18.324275Z",
+ "iopub.status.idle": "2023-10-11T16:03:18.330174Z",
+ "shell.execute_reply": "2023-10-11T16:03:18.329347Z"
+ },
+ "lines_to_next_cell": 2,
+ "papermill": {
+ "duration": 0.018324,
+ "end_time": "2023-10-11T16:03:18.331570",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:18.313246",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "class MLPModel(nn.Module):\n",
+ " def __init__(self, c_in, c_hidden, c_out, num_layers=2, dp_rate=0.1):\n",
+ " \"\"\"MLPModel.\n",
+ "\n",
+ " Args:\n",
+ " c_in: Dimension of input features\n",
+ " c_hidden: Dimension of hidden features\n",
+ " c_out: Dimension of the output features. Usually number of classes in classification\n",
+ " num_layers: Number of hidden layers\n",
+ " dp_rate: Dropout rate to apply throughout the network\n",
+ " \"\"\"\n",
+ " super().__init__()\n",
+ " layers = []\n",
+ " in_channels, out_channels = c_in, c_hidden\n",
+ " for l_idx in range(num_layers - 1):\n",
+ " layers += [nn.Linear(in_channels, out_channels), nn.ReLU(inplace=True), nn.Dropout(dp_rate)]\n",
+ " in_channels = c_hidden\n",
+ " layers += [nn.Linear(in_channels, c_out)]\n",
+ " self.layers = nn.Sequential(*layers)\n",
+ "\n",
+ " def forward(self, x, *args, **kwargs):\n",
+ " \"\"\"Forward.\n",
+ "\n",
+ " Args:\n",
+ " x: Input features per node\n",
+ " \"\"\"\n",
+ " return self.layers(x)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "447b52a0",
+ "metadata": {
+ "lines_to_next_cell": 2,
+ "papermill": {
+ "duration": 0.010195,
+ "end_time": "2023-10-11T16:03:18.352006",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:18.341811",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "Finally, we can merge the models into a PyTorch Lightning module which handles the training,\n",
+ "validation, and testing for us."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "d1281945",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:18.377636Z",
+ "iopub.status.busy": "2023-10-11T16:03:18.377121Z",
+ "iopub.status.idle": "2023-10-11T16:03:18.393331Z",
+ "shell.execute_reply": "2023-10-11T16:03:18.392651Z"
+ },
+ "lines_to_next_cell": 2,
+ "papermill": {
+ "duration": 0.02932,
+ "end_time": "2023-10-11T16:03:18.395619",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:18.366299",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "class NodeLevelGNN(L.LightningModule):\n",
+ " def __init__(self, model_name, **model_kwargs):\n",
+ " super().__init__()\n",
+ " # Saving hyperparameters\n",
+ " self.save_hyperparameters()\n",
+ "\n",
+ " if model_name == \"MLP\":\n",
+ " self.model = MLPModel(**model_kwargs)\n",
+ " else:\n",
+ " self.model = GNNModel(**model_kwargs)\n",
+ " self.loss_module = nn.CrossEntropyLoss()\n",
+ "\n",
+ " def forward(self, data, mode=\"train\"):\n",
+ " x, edge_index = data.x, data.edge_index\n",
+ " x = self.model(x, edge_index)\n",
+ "\n",
+ " # Only calculate the loss on the nodes corresponding to the mask\n",
+ " if mode == \"train\":\n",
+ " mask = data.train_mask\n",
+ " elif mode == \"val\":\n",
+ " mask = data.val_mask\n",
+ " elif mode == \"test\":\n",
+ " mask = data.test_mask\n",
+ " else:\n",
+ " assert False, \"Unknown forward mode: %s\" % mode\n",
+ "\n",
+ " loss = self.loss_module(x[mask], data.y[mask])\n",
+ " acc = (x[mask].argmax(dim=-1) == data.y[mask]).sum().float() / mask.sum()\n",
+ " return loss, acc\n",
+ "\n",
+ " def configure_optimizers(self):\n",
+ " # We use SGD here, but Adam works as well\n",
+ " optimizer = optim.SGD(self.parameters(), lr=0.1, momentum=0.9, weight_decay=2e-3)\n",
+ " return optimizer\n",
+ "\n",
+ " def training_step(self, batch, batch_idx):\n",
+ " loss, acc = self.forward(batch, mode=\"train\")\n",
+ " self.log(\"train_loss\", loss)\n",
+ " self.log(\"train_acc\", acc)\n",
+ " return loss\n",
+ "\n",
+ " def validation_step(self, batch, batch_idx):\n",
+ " _, acc = self.forward(batch, mode=\"val\")\n",
+ " self.log(\"val_acc\", acc)\n",
+ "\n",
+ " def test_step(self, batch, batch_idx):\n",
+ " _, acc = self.forward(batch, mode=\"test\")\n",
+ " self.log(\"test_acc\", acc)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2aa96907",
+ "metadata": {
+ "lines_to_next_cell": 2,
+ "papermill": {
+ "duration": 0.010188,
+ "end_time": "2023-10-11T16:03:18.415853",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:18.405665",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "Additionally to the Lightning module, we define a training function below.\n",
+ "As we have a single graph, we use a batch size of 1 for the data loader and share the same data loader for the train,\n",
+ "validation, and test set (the mask is picked inside the Lightning module).\n",
+ "Besides, we set the argument `enable_progress_bar` to False as it usually shows the progress per epoch,\n",
+ "but an epoch only consists of a single step.\n",
+ "If you have downloaded the pre-trained models in the beginning of the tutorial, we load those instead of training from scratch.\n",
+ "Finally, we test the model and return the results."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "47ae5b35",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:18.438295Z",
+ "iopub.status.busy": "2023-10-11T16:03:18.437639Z",
+ "iopub.status.idle": "2023-10-11T16:03:18.445846Z",
+ "shell.execute_reply": "2023-10-11T16:03:18.445108Z"
+ },
+ "lines_to_next_cell": 2,
+ "papermill": {
+ "duration": 0.020995,
+ "end_time": "2023-10-11T16:03:18.447153",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:18.426158",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "def train_node_classifier(model_name, dataset, **model_kwargs):\n",
+ " L.seed_everything(42)\n",
+ " node_data_loader = geom_data.DataLoader(dataset, batch_size=1)\n",
+ "\n",
+ " # Create a PyTorch Lightning trainer\n",
+ " root_dir = os.path.join(CHECKPOINT_PATH, \"NodeLevel\" + model_name)\n",
+ " os.makedirs(root_dir, exist_ok=True)\n",
+ " trainer = L.Trainer(\n",
+ " default_root_dir=root_dir,\n",
+ " callbacks=[ModelCheckpoint(save_weights_only=True, mode=\"max\", monitor=\"val_acc\")],\n",
+ " accelerator=\"auto\",\n",
+ " devices=AVAIL_GPUS,\n",
+ " max_epochs=200,\n",
+ " enable_progress_bar=False,\n",
+ " ) # 0 because epoch size is 1\n",
+ " trainer.logger._default_hp_metric = None # Optional logging argument that we don't need\n",
+ "\n",
+ " # Check whether pretrained model exists. If yes, load it and skip training\n",
+ " pretrained_filename = os.path.join(CHECKPOINT_PATH, \"NodeLevel%s.ckpt\" % model_name)\n",
+ " if os.path.isfile(pretrained_filename):\n",
+ " print(\"Found pretrained model, loading...\")\n",
+ " model = NodeLevelGNN.load_from_checkpoint(pretrained_filename)\n",
+ " else:\n",
+ " L.seed_everything()\n",
+ " model = NodeLevelGNN(\n",
+ " model_name=model_name, c_in=dataset.num_node_features, c_out=dataset.num_classes, **model_kwargs\n",
+ " )\n",
+ " trainer.fit(model, node_data_loader, node_data_loader)\n",
+ " model = NodeLevelGNN.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)\n",
+ "\n",
+ " # Test best model on the test set\n",
+ " test_result = trainer.test(model, dataloaders=node_data_loader, verbose=False)\n",
+ " batch = next(iter(node_data_loader))\n",
+ " batch = batch.to(model.device)\n",
+ " _, train_acc = model.forward(batch, mode=\"train\")\n",
+ " _, val_acc = model.forward(batch, mode=\"val\")\n",
+ " result = {\"train\": train_acc, \"val\": val_acc, \"test\": test_result[0][\"test_acc\"]}\n",
+ " return model, result"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "63d4b255",
+ "metadata": {
+ "lines_to_next_cell": 2,
+ "papermill": {
+ "duration": 0.010246,
+ "end_time": "2023-10-11T16:03:18.467747",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:18.457501",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "Now, we can train our models. First, let's train the simple MLP:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "a871d384",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:18.489049Z",
+ "iopub.status.busy": "2023-10-11T16:03:18.488856Z",
+ "iopub.status.idle": "2023-10-11T16:03:18.492689Z",
+ "shell.execute_reply": "2023-10-11T16:03:18.492181Z"
+ },
+ "papermill": {
+ "duration": 0.015812,
+ "end_time": "2023-10-11T16:03:18.493806",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:18.477994",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Small function for printing the test scores\n",
+ "def print_results(result_dict):\n",
+ " if \"train\" in result_dict:\n",
+ " print(\"Train accuracy: %4.2f%%\" % (100.0 * result_dict[\"train\"]))\n",
+ " if \"val\" in result_dict:\n",
+ " print(\"Val accuracy: %4.2f%%\" % (100.0 * result_dict[\"val\"]))\n",
+ " print(\"Test accuracy: %4.2f%%\" % (100.0 * result_dict[\"test\"]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "6d78fad1",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:18.515575Z",
+ "iopub.status.busy": "2023-10-11T16:03:18.515071Z",
+ "iopub.status.idle": "2023-10-11T16:03:19.423216Z",
+ "shell.execute_reply": "2023-10-11T16:03:19.422172Z"
+ },
+ "papermill": {
+ "duration": 0.920887,
+ "end_time": "2023-10-11T16:03:19.425008",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:18.504121",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Global seed set to 42\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/torch_geometric/deprecation.py:22: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n",
+ " warnings.warn(out)\n",
+ "GPU available: True (cuda), used: True\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "TPU available: False, using: 0 TPU cores\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "IPU available: False, using: 0 IPUs\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "HPU available: False, using: 0 HPUs\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n",
+ " warning_cache.warn(\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Found pretrained model, loading...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Lightning automatically upgraded your loaded checkpoint from v1.0.2 to v2.0.9.post0. To apply the upgrade to your files permanently, run `python -m lightning.pytorch.utilities.upgrade_checkpoint --file saved_models/GNNs/NodeLevelMLP.ckpt`\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:442: PossibleUserWarning: The dataloader, test_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
+ " rank_zero_warn(\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/utilities/data.py:76: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 2708. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n",
+ " warning_cache.warn(\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Train accuracy: 97.14%\n",
+ "Val accuracy: 54.60%\n",
+ "Test accuracy: 60.60%\n"
+ ]
+ }
+ ],
+ "source": [
+ "node_mlp_model, node_mlp_result = train_node_classifier(\n",
+ " model_name=\"MLP\", dataset=cora_dataset, c_hidden=16, num_layers=2, dp_rate=0.1\n",
+ ")\n",
+ "\n",
+ "print_results(node_mlp_result)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e4fda810",
+ "metadata": {
+ "papermill": {
+ "duration": 0.011495,
+ "end_time": "2023-10-11T16:03:19.448936",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:19.437441",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "Although the MLP can overfit on the training dataset because of the high-dimensional input features,\n",
+ "it does not perform too well on the test set.\n",
+ "Let's see if we can beat this score with our graph networks:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "0e0fd1c4",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:19.474172Z",
+ "iopub.status.busy": "2023-10-11T16:03:19.473856Z",
+ "iopub.status.idle": "2023-10-11T16:03:20.736499Z",
+ "shell.execute_reply": "2023-10-11T16:03:20.731572Z"
+ },
+ "papermill": {
+ "duration": 1.280765,
+ "end_time": "2023-10-11T16:03:20.740917",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:19.460152",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Global seed set to 42\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/torch_geometric/deprecation.py:22: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n",
+ " warnings.warn(out)\n",
+ "GPU available: True (cuda), used: True\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "TPU available: False, using: 0 TPU cores\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "IPU available: False, using: 0 IPUs\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "HPU available: False, using: 0 HPUs\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Lightning automatically upgraded your loaded checkpoint from v1.0.2 to v2.0.9.post0. To apply the upgrade to your files permanently, run `python -m lightning.pytorch.utilities.upgrade_checkpoint --file saved_models/GNNs/NodeLevelGNN.ckpt`\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Found pretrained model, loading...\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Train accuracy: 100.00%\n",
+ "Val accuracy: 78.60%\n",
+ "Test accuracy: 82.40%\n"
+ ]
+ }
+ ],
+ "source": [
+ "node_gnn_model, node_gnn_result = train_node_classifier(\n",
+ " model_name=\"GNN\", layer_name=\"GCN\", dataset=cora_dataset, c_hidden=16, num_layers=2, dp_rate=0.1\n",
+ ")\n",
+ "print_results(node_gnn_result)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fa912878",
+ "metadata": {
+ "papermill": {
+ "duration": 0.011993,
+ "end_time": "2023-10-11T16:03:20.777315",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:20.765322",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "As we would have hoped for, the GNN model outperforms the MLP by quite a margin.\n",
+ "This shows that using the graph information indeed improves our predictions and lets us generalizes better.\n",
+ "\n",
+ "The hyperparameters in the model have been chosen to create a relatively small network.\n",
+ "This is because the first layer with an input dimension of 1433 can be relatively expensive to perform for large graphs.\n",
+ "In general, GNNs can become relatively expensive for very big graphs.\n",
+ "This is why such GNNs either have a small hidden size or use a special batching strategy\n",
+ "where we sample a connected subgraph of the big, original graph."
+ ]
+ },
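+ {
+ "cell_type": "markdown",
+ "id": "a9f3c1e2",
+ "metadata": {},
+ "source": [
+ "As a minimal sketch of such a subgraph-sampling strategy (not used in this tutorial, and assuming a recent PyTorch Geometric version), the `NeighborLoader` samples a fixed number of neighbors per layer around mini-batches of seed nodes:\n",
+ "\n",
+ "```python\n",
+ "from torch_geometric.loader import NeighborLoader\n",
+ "\n",
+ "# Sample 10 neighbors per node for each of the 2 GNN layers,\n",
+ "# starting from mini-batches of 128 training nodes\n",
+ "loader = NeighborLoader(\n",
+ "    cora_dataset[0],\n",
+ "    num_neighbors=[10, 10],\n",
+ "    batch_size=128,\n",
+ "    input_nodes=cora_dataset[0].train_mask,\n",
+ ")\n",
+ "```\n",
+ "\n",
+ "For a small graph like Cora, full-graph training as above is perfectly fine; sampling only becomes necessary once the graph no longer fits on the device."
+ ]
+ },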
+ {
+ "cell_type": "markdown",
+ "id": "5dcd5632",
+ "metadata": {
+ "papermill": {
+ "duration": 0.014035,
+ "end_time": "2023-10-11T16:03:20.803784",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:20.789749",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "### Edge-level tasks: Link prediction\n",
+ "\n",
+ "In some applications, we might have to predict on an edge-level instead of node-level.\n",
+ "The most common edge-level task in GNN is link prediction.\n",
+ "Link prediction means that given a graph, we want to predict whether there will be/should be an edge between two nodes or not.\n",
+ "For example, in a social network, this is used by Facebook and co to propose new friends to you.\n",
+ "Again, graph level information can be crucial to perform this task.\n",
+ "The output prediction is usually done by performing a similarity metric on the pair of node features,\n",
+ "which should be 1 if there should be a link, and otherwise close to 0.\n",
+ "To keep the tutorial short, we will not implement this task ourselves.\n",
+ "Nevertheless, there are many good resources out there if you are interested in looking closer at this task.\n",
+ "Tutorials and papers for this topic include:\n",
+ "\n",
+ "* [PyTorch Geometric example](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/link_pred.py)\n",
+ "* [Graph Neural Networks: A Review of Methods and Applications](https://arxiv.org/pdf/1812.08434.pdf), Zhou et al.\n",
+ "2019\n",
+ "* [Link Prediction Based on Graph Neural Networks](https://papers.nips.cc/paper/2018/file/53f0d7c537d99b3824f0f99d62ea2428-Paper.pdf), Zhang and Chen, 2018."
+ ]
+ },
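+ {
+ "cell_type": "markdown",
+ "id": "b7d4e8f1",
+ "metadata": {},
+ "source": [
+ "To make the similarity-based scoring concrete, here is a minimal sketch (the helper `link_logits` is hypothetical and not part of this tutorial): given node embeddings `z` produced by a GNN encoder, the score for a candidate edge is the dot product of the two endpoint embeddings:\n",
+ "\n",
+ "```python\n",
+ "def link_logits(z, edge_pairs):\n",
+ "    # z: [num_nodes, d] node embeddings from a GNN encoder\n",
+ "    # edge_pairs: [2, num_pairs] candidate edges (source/target node indices)\n",
+ "    src, dst = edge_pairs\n",
+ "    return (z[src] * z[dst]).sum(dim=-1)  # higher logit => more likely a link\n",
+ "\n",
+ "\n",
+ "# Train with torch.nn.functional.binary_cross_entropy_with_logits on\n",
+ "# existing edges (label 1) and randomly sampled non-edges (label 0).\n",
+ "```"
+ ]
+ },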
+ {
+ "cell_type": "markdown",
+ "id": "3212535a",
+ "metadata": {
+ "papermill": {
+ "duration": 0.011616,
+ "end_time": "2023-10-11T16:03:20.826996",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:20.815380",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "### Graph-level tasks: Graph classification\n",
+ "\n",
+ "Finally, in this part of the tutorial, we will have a closer look at how to apply GNNs to the task of graph classification.\n",
+ "The goal is to classify an entire graph instead of single nodes or edges.\n",
+ "Therefore, we are also given a dataset of multiple graphs that we need to classify based on some structural graph properties.\n",
+ "The most common task for graph classification is molecular property prediction, in which molecules are represented as graphs.\n",
+ "Each atom is linked to a node, and edges in the graph are the bonds between atoms.\n",
+ "For example, look at the figure below.\n",
+ "\n",
+ "
\n",
+ "\n",
+ "On the left, we have an arbitrary, small molecule with different atoms, whereas the right part of the image shows the graph representation.\n",
+ "The atom types are abstracted as node features (e.g. a one-hot vector), and the different bond types are used as edge features.\n",
+ "For simplicity, we will neglect the edge attributes in this tutorial, but you can include by using methods like the\n",
+ "[Relational Graph Convolution](https://arxiv.org/abs/1703.06103) that uses a different weight matrix for each edge type.\n",
+ "\n",
+ "The dataset we will use below is called the MUTAG dataset.\n",
+ "It is a common small benchmark for graph classification algorithms, and contain 188 graphs with 18 nodes\n",
+ "and 20 edges on average for each graph.\n",
+ "The graph nodes have 7 different labels/atom types, and the binary graph labels represent \"their mutagenic effect\n",
+ "on a specific gram negative bacterium\" (the specific meaning of the labels are not too important here).\n",
+ "The dataset is part of a large collection of different graph classification datasets, known as the\n",
+ "[TUDatasets](https://chrsmrrs.github.io/datasets/), which is directly accessible\n",
+ "via `torch_geometric.datasets.TUDataset` ([documentation](https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html#torch_geometric.datasets.TUDataset)) in PyTorch Geometric.\n",
+ "We can load the dataset below."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "916022ac",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:20.856516Z",
+ "iopub.status.busy": "2023-10-11T16:03:20.856076Z",
+ "iopub.status.idle": "2023-10-11T16:03:21.889799Z",
+ "shell.execute_reply": "2023-10-11T16:03:21.889240Z"
+ },
+ "papermill": {
+ "duration": 1.053683,
+ "end_time": "2023-10-11T16:03:21.893346",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:20.839663",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Downloading https://www.chrsmrrs.com/graphkerneldatasets/MUTAG.zip\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Extracting /__w/15/s/.datasets/MUTAG/MUTAG.zip\n",
+ "Processing...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Done!\n"
+ ]
+ }
+ ],
+ "source": [
+ "tu_dataset = torch_geometric.datasets.TUDataset(root=DATASET_PATH, name=\"MUTAG\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cc5fd7be",
+ "metadata": {
+ "papermill": {
+ "duration": 0.012581,
+ "end_time": "2023-10-11T16:03:21.923932",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:21.911351",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "Let's look at some statistics for the dataset:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "f1857455",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:21.949045Z",
+ "iopub.status.busy": "2023-10-11T16:03:21.948703Z",
+ "iopub.status.idle": "2023-10-11T16:03:21.957830Z",
+ "shell.execute_reply": "2023-10-11T16:03:21.957240Z"
+ },
+ "papermill": {
+ "duration": 0.023647,
+ "end_time": "2023-10-11T16:03:21.959503",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:21.935856",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Data object: Data(x=[3371, 7], edge_index=[2, 7442], edge_attr=[7442, 4], y=[188])\n",
+ "Length: 188\n",
+ "Average label: 0.66\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/torch_geometric/data/in_memory_dataset.py:157: UserWarning: It is not recommended to directly access the internal storage format `data` of an 'InMemoryDataset'. If you are absolutely certain what you are doing, access the internal storage via `InMemoryDataset._data` instead to suppress this warning. Alternatively, you can access stacked individual attributes of every graph via `dataset.{attr_name}`.\n",
+ " warnings.warn(msg)\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"Data object:\", tu_dataset.data)\n",
+ "print(\"Length:\", len(tu_dataset))\n",
+ "print(\"Average label: %4.2f\" % (tu_dataset.data.y.float().mean().item()))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "68a2b8c5",
+ "metadata": {
+ "papermill": {
+ "duration": 0.012255,
+ "end_time": "2023-10-11T16:03:21.988097",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:21.975842",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "The first line shows how the dataset stores different graphs.\n",
+ "The nodes, edges, and labels of each graph are concatenated to one tensor, and the dataset stores the indices\n",
+ "where to split the tensors correspondingly.\n",
+ "The length of the dataset is the number of graphs we have, and the \"average label\"\n",
+ "denotes the percentage of the graph with label 1.\n",
+ "As long as the percentage is in the range of 0.5, we have a relatively balanced dataset.\n",
+ "It happens quite often that graph datasets are very imbalanced, hence checking the class balance\n",
+ "is always a good thing to do.\n",
+ "\n",
+ "Next, we will split our dataset into a training and test part.\n",
+ "Note that we do not use a validation set this time because of the small size of the dataset.\n",
+ "Therefore, our model might overfit slightly on the validation set due to the noise of the evaluation,\n",
+ "but we still get an estimate of the performance on untrained data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "29d0a9c6",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:22.014982Z",
+ "iopub.status.busy": "2023-10-11T16:03:22.014390Z",
+ "iopub.status.idle": "2023-10-11T16:03:22.018493Z",
+ "shell.execute_reply": "2023-10-11T16:03:22.017953Z"
+ },
+ "papermill": {
+ "duration": 0.021347,
+ "end_time": "2023-10-11T16:03:22.022150",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:22.000803",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "torch.manual_seed(42)\n",
+ "tu_dataset.shuffle()\n",
+ "train_dataset = tu_dataset[:150]\n",
+ "test_dataset = tu_dataset[150:]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0558ab51",
+ "metadata": {
+ "papermill": {
+ "duration": 0.012813,
+ "end_time": "2023-10-11T16:03:22.047078",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:22.034265",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "When using a data loader, we encounter a problem with batching $N$ graphs.\n",
+ "Each graph in the batch can have a different number of nodes and edges, and hence we would require a lot of padding to obtain a single tensor.\n",
+ "Torch geometric uses a different, more efficient approach: we can view the $N$ graphs in a batch as a single large graph with concatenated node and edge list.\n",
+ "As there is no edge between the $N$ graphs, running GNN layers on the large graph gives us the same output as running the GNN on each graph separately.\n",
+ "Visually, this batching strategy is visualized below (figure credit - PyTorch Geometric team,\n",
+ "[tutorial here](https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb)).\n",
+ "\n",
+ "
\n",
+ "\n",
+ "The adjacency matrix is zero for any nodes that come from two different graphs, and otherwise according to the adjacency matrix of the individual graph.\n",
+ "Luckily, this strategy is already implemented in torch geometric, and hence we can use the corresponding data loader:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "137c9f19",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:22.077104Z",
+ "iopub.status.busy": "2023-10-11T16:03:22.072997Z",
+ "iopub.status.idle": "2023-10-11T16:03:22.081741Z",
+ "shell.execute_reply": "2023-10-11T16:03:22.081006Z"
+ },
+ "papermill": {
+ "duration": 0.02352,
+ "end_time": "2023-10-11T16:03:22.083057",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:22.059537",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "graph_train_loader = geom_data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
+ "graph_val_loader = geom_data.DataLoader(test_dataset, batch_size=BATCH_SIZE) # Additional loader for a larger datasets\n",
+ "graph_test_loader = geom_data.DataLoader(test_dataset, batch_size=BATCH_SIZE)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cfbc9eee",
+ "metadata": {
+ "papermill": {
+ "duration": 0.012036,
+ "end_time": "2023-10-11T16:03:22.106505",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:22.094469",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "Let's load a batch below to see the batching in action:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "30662c4c",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:22.132189Z",
+ "iopub.status.busy": "2023-10-11T16:03:22.132000Z",
+ "iopub.status.idle": "2023-10-11T16:03:22.144650Z",
+ "shell.execute_reply": "2023-10-11T16:03:22.143918Z"
+ },
+ "papermill": {
+ "duration": 0.026959,
+ "end_time": "2023-10-11T16:03:22.145991",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:22.119032",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Batch: DataBatch(edge_index=[2, 1512], x=[687, 7], edge_attr=[1512, 4], y=[38], batch=[687], ptr=[39])\n",
+ "Labels: tensor([1, 1, 1, 0, 0, 0, 1, 1, 1, 0])\n",
+ "Batch indices: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2])\n"
+ ]
+ }
+ ],
+ "source": [
+ "batch = next(iter(graph_test_loader))\n",
+ "print(\"Batch:\", batch)\n",
+ "print(\"Labels:\", batch.y[:10])\n",
+ "print(\"Batch indices:\", batch.batch[:40])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "21440b10",
+ "metadata": {
+ "lines_to_next_cell": 2,
+ "papermill": {
+ "duration": 0.012337,
+ "end_time": "2023-10-11T16:03:22.171681",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:22.159344",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "We have 38 graphs stacked together for the test dataset.\n",
+ "The batch indices, stored in `batch`, show that the first 12 nodes belong to the first graph,\n",
+ "the next 22 to the second graph, and so on.\n",
+ "These indices are important for performing the final prediction.\n",
+ "To perform a prediction over a whole graph, we usually perform a pooling operation over all nodes after running the GNN model.\n",
+ "In this case, we will use the average pooling.\n",
+ "Hence, we need to know which nodes should be included in which average pool.\n",
+ "Using this pooling, we can already create our graph network below.\n",
+ "Specifically, we reuse our class `GNNModel` from before,\n",
+ "and simply add an average pool and single linear layer for the graph prediction task."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "a0b2ff13",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:22.200230Z",
+ "iopub.status.busy": "2023-10-11T16:03:22.199343Z",
+ "iopub.status.idle": "2023-10-11T16:03:22.206433Z",
+ "shell.execute_reply": "2023-10-11T16:03:22.205598Z"
+ },
+ "lines_to_next_cell": 2,
+ "papermill": {
+ "duration": 0.028572,
+ "end_time": "2023-10-11T16:03:22.211969",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:22.183397",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "class GraphGNNModel(nn.Module):\n",
+ " def __init__(self, c_in, c_hidden, c_out, dp_rate_linear=0.5, **kwargs):\n",
+ " \"\"\"GraphGNNModel.\n",
+ "\n",
+ " Args:\n",
+ " c_in: Dimension of input features\n",
+ " c_hidden: Dimension of hidden features\n",
+ " c_out: Dimension of output features (usually number of classes)\n",
+ " dp_rate_linear: Dropout rate before the linear layer (usually much higher than inside the GNN)\n",
+ " kwargs: Additional arguments for the GNNModel object\n",
+ " \"\"\"\n",
+ " super().__init__()\n",
+ " self.GNN = GNNModel(c_in=c_in, c_hidden=c_hidden, c_out=c_hidden, **kwargs) # Not our prediction output yet!\n",
+ " self.head = nn.Sequential(nn.Dropout(dp_rate_linear), nn.Linear(c_hidden, c_out))\n",
+ "\n",
+ " def forward(self, x, edge_index, batch_idx):\n",
+ " \"\"\"Forward.\n",
+ "\n",
+ " Args:\n",
+ " x: Input features per node\n",
+ " edge_index: List of vertex index pairs representing the edges in the graph (PyTorch geometric notation)\n",
+ " batch_idx: Index of batch element for each node\n",
+ " \"\"\"\n",
+ " x = self.GNN(x, edge_index)\n",
+ " x = geom_nn.global_mean_pool(x, batch_idx) # Average pooling\n",
+ " x = self.head(x)\n",
+ " return x"
+ ]
+ },
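+ {
+ "cell_type": "markdown",
+ "id": "c2e7a5d9",
+ "metadata": {},
+ "source": [
+ "To illustrate what the average pooling does with the batch vector, consider this tiny sketch (toy tensors, assuming the `torch` and `geom_nn` imports from this tutorial):\n",
+ "\n",
+ "```python\n",
+ "x = torch.tensor([[1.0], [3.0], [5.0]])  # three nodes with one feature each\n",
+ "batch = torch.tensor([0, 0, 1])  # nodes 0-1 belong to graph 0, node 2 to graph 1\n",
+ "geom_nn.global_mean_pool(x, batch)  # tensor([[2.], [5.]]) - one row per graph\n",
+ "```"
+ ]
+ },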
+ {
+ "cell_type": "markdown",
+ "id": "b553c870",
+ "metadata": {
+ "lines_to_next_cell": 2,
+ "papermill": {
+ "duration": 0.013432,
+ "end_time": "2023-10-11T16:03:22.242667",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:22.229235",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "Finally, we can create a PyTorch Lightning module to handle the training.\n",
+ "It is similar to the modules we have seen before and does nothing surprising in terms of training.\n",
+ "As we have a binary classification task, we use the Binary Cross Entropy loss."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "033fde2f",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:22.269885Z",
+ "iopub.status.busy": "2023-10-11T16:03:22.269100Z",
+ "iopub.status.idle": "2023-10-11T16:03:22.284187Z",
+ "shell.execute_reply": "2023-10-11T16:03:22.283350Z"
+ },
+ "lines_to_next_cell": 2,
+ "papermill": {
+ "duration": 0.030592,
+ "end_time": "2023-10-11T16:03:22.285650",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:22.255058",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "class GraphLevelGNN(L.LightningModule):\n",
+ " def __init__(self, **model_kwargs):\n",
+ " super().__init__()\n",
+ " # Saving hyperparameters\n",
+ " self.save_hyperparameters()\n",
+ "\n",
+ " self.model = GraphGNNModel(**model_kwargs)\n",
+ " self.loss_module = nn.BCEWithLogitsLoss() if self.hparams.c_out == 1 else nn.CrossEntropyLoss()\n",
+ "\n",
+ " def forward(self, data, mode=\"train\"):\n",
+ " x, edge_index, batch_idx = data.x, data.edge_index, data.batch\n",
+ " x = self.model(x, edge_index, batch_idx)\n",
+ " x = x.squeeze(dim=-1)\n",
+ "\n",
+ " if self.hparams.c_out == 1:\n",
+ " preds = (x > 0).float()\n",
+ " data.y = data.y.float()\n",
+ " else:\n",
+ " preds = x.argmax(dim=-1)\n",
+ " loss = self.loss_module(x, data.y)\n",
+ " acc = (preds == data.y).sum().float() / preds.shape[0]\n",
+ " return loss, acc\n",
+ "\n",
+ " def configure_optimizers(self):\n",
+ " # High lr because of small dataset and small model\n",
+ " optimizer = optim.AdamW(self.parameters(), lr=1e-2, weight_decay=0.0)\n",
+ " return optimizer\n",
+ "\n",
+ " def training_step(self, batch, batch_idx):\n",
+ " loss, acc = self.forward(batch, mode=\"train\")\n",
+ " self.log(\"train_loss\", loss)\n",
+ " self.log(\"train_acc\", acc)\n",
+ " return loss\n",
+ "\n",
+ " def validation_step(self, batch, batch_idx):\n",
+ " _, acc = self.forward(batch, mode=\"val\")\n",
+ " self.log(\"val_acc\", acc)\n",
+ "\n",
+ " def test_step(self, batch, batch_idx):\n",
+ " _, acc = self.forward(batch, mode=\"test\")\n",
+ " self.log(\"test_acc\", acc)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e1c854a2",
+ "metadata": {
+ "lines_to_next_cell": 2,
+ "papermill": {
+ "duration": 0.012575,
+ "end_time": "2023-10-11T16:03:22.311353",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:22.298778",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "Below we train the model on our dataset. It resembles the typical training functions we have seen so far."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "d031d1d4",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:22.337771Z",
+ "iopub.status.busy": "2023-10-11T16:03:22.337052Z",
+ "iopub.status.idle": "2023-10-11T16:03:22.353709Z",
+ "shell.execute_reply": "2023-10-11T16:03:22.352823Z"
+ },
+ "papermill": {
+ "duration": 0.031695,
+ "end_time": "2023-10-11T16:03:22.355074",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:22.323379",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "def train_graph_classifier(model_name, **model_kwargs):\n",
+ " L.seed_everything(42)\n",
+ "\n",
+ " # Create a PyTorch Lightning trainer with the generation callback\n",
+ " root_dir = os.path.join(CHECKPOINT_PATH, \"GraphLevel\" + model_name)\n",
+ " os.makedirs(root_dir, exist_ok=True)\n",
+ " trainer = L.Trainer(\n",
+ " default_root_dir=root_dir,\n",
+ " callbacks=[ModelCheckpoint(save_weights_only=True, mode=\"max\", monitor=\"val_acc\")],\n",
+ " accelerator=\"cuda\",\n",
+ " devices=AVAIL_GPUS,\n",
+ " max_epochs=500,\n",
+ " enable_progress_bar=False,\n",
+ " )\n",
+ " trainer.logger._default_hp_metric = None\n",
+ "\n",
+ " # Check whether pretrained model exists. If yes, load it and skip training\n",
+ " pretrained_filename = os.path.join(CHECKPOINT_PATH, \"GraphLevel%s.ckpt\" % model_name)\n",
+ " if os.path.isfile(pretrained_filename):\n",
+ " print(\"Found pretrained model, loading...\")\n",
+ " model = GraphLevelGNN.load_from_checkpoint(pretrained_filename)\n",
+ " else:\n",
+ " L.seed_everything(42)\n",
+ " model = GraphLevelGNN(\n",
+ " c_in=tu_dataset.num_node_features,\n",
+ " c_out=1 if tu_dataset.num_classes == 2 else tu_dataset.num_classes,\n",
+ " **model_kwargs,\n",
+ " )\n",
+ " trainer.fit(model, graph_train_loader, graph_val_loader)\n",
+ " model = GraphLevelGNN.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)\n",
+ "\n",
+ " # Test best model on validation and test set\n",
+ " train_result = trainer.test(model, dataloaders=graph_train_loader, verbose=False)\n",
+ " test_result = trainer.test(model, dataloaders=graph_test_loader, verbose=False)\n",
+ " result = {\"test\": test_result[0][\"test_acc\"], \"train\": train_result[0][\"test_acc\"]}\n",
+ " return model, result"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "97cb8ad3",
+ "metadata": {
+ "papermill": {
+ "duration": 0.012021,
+ "end_time": "2023-10-11T16:03:22.379407",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:22.367386",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "Finally, let's perform the training and testing.\n",
+ "Feel free to experiment with different GNN layers, hyperparameters, etc."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "b139207e",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:22.405188Z",
+ "iopub.status.busy": "2023-10-11T16:03:22.404648Z",
+ "iopub.status.idle": "2023-10-11T16:03:22.515903Z",
+ "shell.execute_reply": "2023-10-11T16:03:22.510992Z"
+ },
+ "papermill": {
+ "duration": 0.125803,
+ "end_time": "2023-10-11T16:03:22.517426",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:22.391623",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Global seed set to 42\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "GPU available: True (cuda), used: True\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "TPU available: False, using: 0 TPU cores\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "IPU available: False, using: 0 IPUs\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "HPU available: False, using: 0 HPUs\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Lightning automatically upgraded your loaded checkpoint from v1.0.2 to v2.0.9.post0. To apply the upgrade to your files permanently, run `python -m lightning.pytorch.utilities.upgrade_checkpoint --file saved_models/GNNs/GraphLevelGraphConv.ckpt`\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:490: PossibleUserWarning: Your `test_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.\n",
+ " rank_zero_warn(\n",
+ "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/utilities/data.py:76: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 2. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n",
+ " warning_cache.warn(\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Found pretrained model, loading...\n"
+ ]
+ }
+ ],
+ "source": [
+ "model, result = train_graph_classifier(\n",
+ " model_name=\"GraphConv\", c_hidden=256, layer_name=\"GraphConv\", num_layers=3, dp_rate_linear=0.5, dp_rate=0.0\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "1f6c10e3",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-11T16:03:22.551971Z",
+ "iopub.status.busy": "2023-10-11T16:03:22.551207Z",
+ "iopub.status.idle": "2023-10-11T16:03:22.556279Z",
+ "shell.execute_reply": "2023-10-11T16:03:22.555416Z"
+ },
+ "papermill": {
+ "duration": 0.026658,
+ "end_time": "2023-10-11T16:03:22.557748",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:22.531090",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Train performance: 92.67%\n",
+ "Test performance: 92.11%\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"Train performance: %4.2f%%\" % (100.0 * result[\"train\"]))\n",
+ "print(\"Test performance: %4.2f%%\" % (100.0 * result[\"test\"]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ce1459df",
+ "metadata": {
+ "papermill": {
+ "duration": 0.016168,
+ "end_time": "2023-10-11T16:03:22.587068",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:22.570900",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "The test performance shows that we obtain quite good scores on an unseen part of the dataset.\n",
+ "It should be noted that as we have been using the test set for validation as well, we might have overfitted slightly to this set.\n",
+ "Nevertheless, the experiment shows us that GNNs can be indeed powerful to predict the properties of graphs and/or molecules."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9a2e0201",
+ "metadata": {
+ "papermill": {
+ "duration": 0.016881,
+ "end_time": "2023-10-11T16:03:22.617035",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:22.600154",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Conclusion\n",
+ "\n",
+ "In this tutorial, we have seen the application of neural networks to graph structures.\n",
+ "We looked at how a graph can be represented (adjacency matrix or edge list),\n",
+ "and discussed the implementation of common graph layers: GCN and GAT.\n",
+ "The implementations showed the practical side of the layers, which is often easier than the theory.\n",
+ "Finally, we experimented with different tasks, on node-, edge- and graph-level.\n",
+ "Overall, we have seen that including graph information in the predictions can be crucial for achieving high performance.\n",
+ "There are a lot of applications that benefit from GNNs,\n",
+ "and the importance of these networks will likely increase over the next years."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1ef601bc",
+ "metadata": {
+ "papermill": {
+ "duration": 0.013155,
+ "end_time": "2023-10-11T16:03:22.643763",
+ "exception": false,
+ "start_time": "2023-10-11T16:03:22.630608",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Congratulations - Time to Join the Community!\n",
+ "\n",
+ "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning\n",
+ "movement, you can do so in the following ways!\n",
+ "\n",
+ "### Star [Lightning](https://github.com/Lightning-AI/lightning) on GitHub\n",
+ "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool\n",
+ "tools we're building.\n",
+ "\n",
+ "### Join our [Slack](https://www.pytorchlightning.ai/community)!\n",
+ "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself\n",
+ "and share your interests in `#general` channel\n",
+ "\n",
+ "\n",
+ "### Contributions !\n",
+ "The best way to contribute to our community is to become a code contributor! At any time you can go to\n",
+ "[Lightning](https://github.com/Lightning-AI/lightning) or [Bolt](https://github.com/Lightning-AI/lightning-bolts)\n",
+ "GitHub Issues page and filter for \"good first issue\".\n",
+ "\n",
+ "* [Lightning good first issue](https://github.com/Lightning-AI/lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
+ "* [Bolt good first issue](https://github.com/Lightning-AI/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
+ "* You can also contribute your own notebooks with useful examples !\n",
+ "\n",
+ "### Great thanks from the entire Pytorch Lightning Team for your interest !\n",
+ "\n",
+ "[{height=\"60px\" width=\"240px\"}](https://pytorchlightning.ai)"
+ ]
+ }
+ ],
+ "metadata": {
+ "jupytext": {
+ "cell_metadata_filter": "colab_type,colab,id,-all",
+ "formats": "ipynb,py:percent",
+ "main_language": "python"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ },
+ "papermill": {
+ "default_parameters": {},
+ "duration": 18.450614,
+ "end_time": "2023-10-11T16:03:23.982888",
+ "environment_variables": {},
+ "exception": null,
+ "input_path": "course_UvA-DL/06-graph-neural-networks/GNN_overview.ipynb",
+ "output_path": ".notebooks/course_UvA-DL/06-graph-neural-networks.ipynb",
+ "parameters": {},
+ "start_time": "2023-10-11T16:03:05.532274",
+ "version": "2.4.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/.notebooks/course_UvA-DL/06-graph-neural-networks.jpg b/.notebooks/course_UvA-DL/06-graph-neural-networks.jpg
new file mode 100644
index 000000000..0cda6bd18
Binary files /dev/null and b/.notebooks/course_UvA-DL/06-graph-neural-networks.jpg differ
diff --git a/.notebooks/course_UvA-DL/06-graph-neural-networks.yaml b/.notebooks/course_UvA-DL/06-graph-neural-networks.yaml
new file mode 100644
index 000000000..49e37dd3d
--- /dev/null
+++ b/.notebooks/course_UvA-DL/06-graph-neural-networks.yaml
@@ -0,0 +1,61 @@
+title: "Tutorial 6: Basics of Graph Neural Networks"
+author: Phillip Lippe
+created: 2021-06-07
+updated: 2023-03-14
+license: CC BY-SA
+build: 0
+tags:
+ - Graph
+description: "In this tutorial, we will discuss the application of neural networks
+ on graphs.
+
+ Graph Neural Networks (GNNs) have recently gained increasing popularity in both
+ applications and research,
+
+ including domains such as social networks, knowledge graphs, recommender systems,
+ and bioinformatics.
+
+ While the theory and math behind GNNs might first seem complicated,
+
+ the implementation of those models is quite simple and helps in understanding the
+ methodology.
+
+ Therefore, we will discuss the implementation of basic network layers of a GNN,
+
+ namely graph convolutions, and attention layers.
+
+ Finally, we will apply a GNN on semi-supervised node classification and molecule
+ categorization.
+
+ This notebook is part of a lecture series on Deep Learning at the University of
+ Amsterdam.
+
+ The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.
+
+ "
+requirements:
+ - torch-scatter
+ - torch-sparse
+ - torch-cluster
+ - torch-spline-conv
+ - torch-geometric
+ - lightning>=2.0.0
+pip__find-link:
+ - https://pytorch-geometric.com/whl/torch-%(TORCH_MAJOR_DOT_MINOR)s.0+%(DEVICE)s.html
+accelerator:
+ - GPU
+environment:
+ - torch-spline-conv==1.2.2+pt20cu118
+ - matplotlib==3.8.0
+ - ipython==8.16.1
+ - lightning==2.0.9.post0
+ - torch-geometric==2.3.1
+ - torch-cluster==1.6.2+pt20cu118
+ - torchmetrics==1.2.0
+ - setuptools==68.2.2
+ - torch-sparse==0.6.18+pt20cu118
+ - urllib3==2.0.6
+ - pytorch-lightning==2.0.9.post0
+ - torch==2.0.1+cu118
+ - torch-scatter==2.1.2+pt20cu118
+published: "2023-10-11T16:03:26.026950"
diff --git a/.notebooks/course_UvA-DL/07-deep-energy-based-generative-models.yaml b/.notebooks/course_UvA-DL/07-deep-energy-based-generative-models.yaml
index e627ffc61..ef69d08b0 100644
--- a/.notebooks/course_UvA-DL/07-deep-energy-based-generative-models.yaml
+++ b/.notebooks/course_UvA-DL/07-deep-energy-based-generative-models.yaml
@@ -1,12 +1,12 @@
-title: 'Tutorial 7: Deep Energy-Based Generative Models'
+title: "Tutorial 7: Deep Energy-Based Generative Models"
author: Phillip Lippe
created: 2021-07-12
updated: 2023-03-14
license: CC BY-SA
build: 1
tags:
-- Image
-description: 'In this tutorial, we will look at energy-based deep learning models,
+ - Image
+description: "In this tutorial, we will look at energy-based deep learning models,
and focus on their application as generative models.
Energy models have been a popular tool before the huge deep learning hype around
@@ -36,24 +36,24 @@ description: 'In this tutorial, we will look at energy-based deep learning model
The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.
- '
+ "
requirements:
-- torchvision
-- matplotlib
-- tensorboard
-- pytorch-lightning>=2.0.0
+ - torchvision
+ - matplotlib
+ - tensorboard
+ - pytorch-lightning>=2.0.0
accelerator:
-- CPU
-- GPU
+ - CPU
+ - GPU
environment:
-- urllib3==2.2.2
-- torchvision==0.15.2
-- pytorch-lightning==2.0.9.post0
-- tensorboard==2.17.0
-- ipython==8.16.1
-- setuptools==69.0.3
-- torchmetrics==1.2.1
-- numpy==1.26.4
-- torch==2.0.1
-- matplotlib==3.8.4
-published: '2024-07-20T00:33:30.041519'
+ - urllib3==2.2.2
+ - torchvision==0.15.2
+ - pytorch-lightning==2.0.9.post0
+ - tensorboard==2.17.0
+ - ipython==8.16.1
+ - setuptools==69.0.3
+ - torchmetrics==1.2.1
+ - numpy==1.26.4
+ - torch==2.0.1
+ - matplotlib==3.8.4
+published: "2024-07-20T00:33:30.041519"
diff --git a/.notebooks/course_UvA-DL/08-deep-autoencoders.yaml b/.notebooks/course_UvA-DL/08-deep-autoencoders.yaml
index 4a0326de8..8d0284151 100644
--- a/.notebooks/course_UvA-DL/08-deep-autoencoders.yaml
+++ b/.notebooks/course_UvA-DL/08-deep-autoencoders.yaml
@@ -1,11 +1,11 @@
-title: 'Tutorial 8: Deep Autoencoders'
+title: "Tutorial 8: Deep Autoencoders"
author: Phillip Lippe
created: 2021-07-12
updated: 2023-03-14
license: CC BY-SA
build: 0
tags:
-- Image
+ - Image
description: 'In this tutorial, we will take a closer look at autoencoders (AE).
Autoencoders are trained on encoding input data such as images into a smaller feature
@@ -37,25 +37,25 @@ description: 'In this tutorial, we will take a closer look at autoencoders (AE).
'
requirements:
-- torchvision
-- matplotlib
-- seaborn
-- lightning>=2.0.0
-- tensorboard
+ - torchvision
+ - matplotlib
+ - seaborn
+ - lightning>=2.0.0
+ - tensorboard
accelerator:
-- CPU
-- GPU
+ - CPU
+ - GPU
environment:
-- torchvision==0.15.2
-- tensorboard==2.17.0
-- torch==2.0.1
-- setuptools==69.0.3
-- urllib3==2.2.2
-- numpy==1.26.4
-- ipython==8.16.1
-- pytorch-lightning==2.0.9.post0
-- torchmetrics==1.2.1
-- seaborn==0.13.2
-- lightning==2.3.3
-- matplotlib==3.8.4
-published: '2024-07-19T19:52:06.169431'
+ - torchvision==0.15.2
+ - tensorboard==2.17.0
+ - torch==2.0.1
+ - setuptools==69.0.3
+ - urllib3==2.2.2
+ - numpy==1.26.4
+ - ipython==8.16.1
+ - pytorch-lightning==2.0.9.post0
+ - torchmetrics==1.2.1
+ - seaborn==0.13.2
+ - lightning==2.3.3
+ - matplotlib==3.8.4
+published: "2024-07-19T19:52:06.169431"
diff --git a/.notebooks/course_UvA-DL/09-normalizing-flows.yaml b/.notebooks/course_UvA-DL/09-normalizing-flows.yaml
index 64bfee1a3..e8bc57075 100644
--- a/.notebooks/course_UvA-DL/09-normalizing-flows.yaml
+++ b/.notebooks/course_UvA-DL/09-normalizing-flows.yaml
@@ -1,12 +1,12 @@
-title: 'Tutorial 9: Normalizing Flows for Image Modeling'
+title: "Tutorial 9: Normalizing Flows for Image Modeling"
author: Phillip Lippe
created: 2021-06-07
updated: 2023-03-14
license: CC BY-SA
build: 0
tags:
-- Image
-description: 'In this tutorial, we will take a closer look at complex, deep normalizing
+ - Image
+description: "In this tutorial, we will take a closer look at complex, deep normalizing
flows.
The most popular, current application of deep normalizing flows is to model datasets
@@ -39,27 +39,27 @@ description: 'In this tutorial, we will take a closer look at complex, deep norm
The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.
- '
+ "
requirements:
-- torchvision
-- matplotlib
-- seaborn
-- tabulate
-- lightning>=2.0.0
+ - torchvision
+ - matplotlib
+ - seaborn
+ - tabulate
+ - lightning>=2.0.0
accelerator:
-- CPU
-- GPU
+ - CPU
+ - GPU
environment:
-- matplotlib==3.8.4
-- torchmetrics==1.2.1
-- torchvision==0.15.2
-- numpy==1.26.4
-- pytorch-lightning==2.0.9.post0
-- setuptools==69.0.3
-- urllib3==2.2.2
-- tabulate==0.9.0
-- lightning==2.3.3
-- seaborn==0.13.2
-- torch==2.0.1
-- ipython==8.16.1
-published: '2024-07-19T19:56:56.208867'
+ - matplotlib==3.8.4
+ - torchmetrics==1.2.1
+ - torchvision==0.15.2
+ - numpy==1.26.4
+ - pytorch-lightning==2.0.9.post0
+ - setuptools==69.0.3
+ - urllib3==2.2.2
+ - tabulate==0.9.0
+ - lightning==2.3.3
+ - seaborn==0.13.2
+ - torch==2.0.1
+ - ipython==8.16.1
+published: "2024-07-19T19:56:56.208867"
diff --git a/.notebooks/course_UvA-DL/10-autoregressive-image-modeling.yaml b/.notebooks/course_UvA-DL/10-autoregressive-image-modeling.yaml
index 33ccc7baf..ea1205ad9 100644
--- a/.notebooks/course_UvA-DL/10-autoregressive-image-modeling.yaml
+++ b/.notebooks/course_UvA-DL/10-autoregressive-image-modeling.yaml
@@ -1,12 +1,12 @@
-title: 'Tutorial 10: Autoregressive Image Modeling'
+title: "Tutorial 10: Autoregressive Image Modeling"
author: Phillip Lippe
created: 2021-07-12
updated: 2023-03-14
license: CC BY-SA
build: 0
tags:
-- Image
-description: 'In this tutorial, we implement an autoregressive likelihood model for
+ - Image
+description: "In this tutorial, we implement an autoregressive likelihood model for
the task of image modeling.
Autoregressive models are naturally strong generative models that constitute one
@@ -24,24 +24,24 @@ description: 'In this tutorial, we implement an autoregressive likelihood model
The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.
- '
+ "
requirements:
-- torchvision
-- matplotlib
-- seaborn
-- lightning>=2.0.0
+ - torchvision
+ - matplotlib
+ - seaborn
+ - lightning>=2.0.0
accelerator:
-- GPU
+ - GPU
environment:
-- torchvision==0.15.2
-- torch==2.0.1
-- seaborn==0.13.2
-- torchmetrics==1.2.1
-- numpy==1.26.4
-- lightning==2.3.3
-- pytorch-lightning==2.0.9.post0
-- setuptools==69.0.3
-- matplotlib==3.8.4
-- ipython==8.16.1
-- urllib3==2.2.2
-published: '2024-07-19T20:11:08.643733'
+ - torchvision==0.15.2
+ - torch==2.0.1
+ - seaborn==0.13.2
+ - torchmetrics==1.2.1
+ - numpy==1.26.4
+ - lightning==2.3.3
+ - pytorch-lightning==2.0.9.post0
+ - setuptools==69.0.3
+ - matplotlib==3.8.4
+ - ipython==8.16.1
+ - urllib3==2.2.2
+published: "2024-07-19T20:11:08.643733"
diff --git a/.notebooks/course_UvA-DL/11-vision-transformer.yaml b/.notebooks/course_UvA-DL/11-vision-transformer.yaml
index 4c18e0179..848b01a38 100644
--- a/.notebooks/course_UvA-DL/11-vision-transformer.yaml
+++ b/.notebooks/course_UvA-DL/11-vision-transformer.yaml
@@ -1,9 +1,9 @@
-title: 'Tutorial 11: Vision Transformers'
+title: "Tutorial 11: Vision Transformers"
author: Phillip Lippe
created: 2021-08-21
updated: 2023-03-14
license: CC BY-SA
-description: 'In this tutorial, we will take a closer look at a recent new trend:
+description: "In this tutorial, we will take a closer look at a recent new trend:
Transformers for Computer Vision.
Since [Alexey Dosovitskiy et al.](https://openreview.net/pdf?id=YicbFdNTTy) successfully
@@ -25,29 +25,29 @@ description: 'In this tutorial, we will take a closer look at a recent new trend
The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.
- '
+ "
tags:
-- Image
+ - Image
requirements:
-- torchvision
-- matplotlib
-- seaborn
-- lightning>=2.0.0
-- tensorboard
+ - torchvision
+ - matplotlib
+ - seaborn
+ - lightning>=2.0.0
+ - tensorboard
accelerator:
-- CPU
-- GPU
+ - CPU
+ - GPU
environment:
-- torch==2.0.1
-- torchmetrics==1.2.1
-- torchvision==0.15.2
-- pytorch-lightning==2.0.9.post0
-- matplotlib==3.8.4
-- ipython==8.16.1
-- seaborn==0.13.2
-- lightning==2.3.3
-- tensorboard==2.17.0
-- numpy==1.26.4
-- urllib3==2.2.2
-- setuptools==69.0.3
-published: '2024-07-19T20:16:19.153871'
+ - torch==2.0.1
+ - torchmetrics==1.2.1
+ - torchvision==0.15.2
+ - pytorch-lightning==2.0.9.post0
+ - matplotlib==3.8.4
+ - ipython==8.16.1
+ - seaborn==0.13.2
+ - lightning==2.3.3
+ - tensorboard==2.17.0
+ - numpy==1.26.4
+ - urllib3==2.2.2
+ - setuptools==69.0.3
+published: "2024-07-19T20:16:19.153871"
diff --git a/.notebooks/course_UvA-DL/12-meta-learning.yaml b/.notebooks/course_UvA-DL/12-meta-learning.yaml
index 1de785017..e76648b15 100644
--- a/.notebooks/course_UvA-DL/12-meta-learning.yaml
+++ b/.notebooks/course_UvA-DL/12-meta-learning.yaml
@@ -1,12 +1,12 @@
-title: 'Tutorial 12: Meta-Learning - Learning to Learn'
+title: "Tutorial 12: Meta-Learning - Learning to Learn"
author: Phillip Lippe
created: 2021-08-21
updated: 2023-03-14
license: CC BY-SA
tags:
-- Few-shot-learning
-- MAML
-- ProtoNet
+ - Few-shot-learning
+ - MAML
+ - ProtoNet
description: 'In this tutorial, we will discuss algorithms that learn models which
can quickly adapt to new classes and/or tasks with few samples.
@@ -39,27 +39,27 @@ description: 'In this tutorial, we will discuss algorithms that learn models whi
'
requirements:
-- torchvision
-- matplotlib
-- seaborn
-- lightning>=2.0.0
-- tensorboard
-- scipy
+ - torchvision
+ - matplotlib
+ - seaborn
+ - lightning>=2.0.0
+ - tensorboard
+ - scipy
accelerator:
-- CPU
-- GPU
+ - CPU
+ - GPU
environment:
-- matplotlib==3.8.4
-- numpy==1.26.4
-- ipython==8.16.1
-- torch==2.0.1
-- torchmetrics==1.2.1
-- torchvision==0.15.2
-- lightning==2.3.3
-- tensorboard==2.17.0
-- seaborn==0.13.2
-- urllib3==2.2.2
-- pytorch-lightning==2.0.9.post0
-- setuptools==69.0.3
-- scipy==1.14.0
-published: '2024-07-19T20:22:41.980859'
+ - matplotlib==3.8.4
+ - numpy==1.26.4
+ - ipython==8.16.1
+ - torch==2.0.1
+ - torchmetrics==1.2.1
+ - torchvision==0.15.2
+ - lightning==2.3.3
+ - tensorboard==2.17.0
+ - seaborn==0.13.2
+ - urllib3==2.2.2
+ - pytorch-lightning==2.0.9.post0
+ - setuptools==69.0.3
+ - scipy==1.14.0
+published: "2024-07-19T20:22:41.980859"
diff --git a/.notebooks/course_UvA-DL/13-contrastive-learning.yaml b/.notebooks/course_UvA-DL/13-contrastive-learning.yaml
index d59c45df9..75f5b16df 100644
--- a/.notebooks/course_UvA-DL/13-contrastive-learning.yaml
+++ b/.notebooks/course_UvA-DL/13-contrastive-learning.yaml
@@ -1,13 +1,13 @@
-title: 'Tutorial 13: Self-Supervised Contrastive Learning with SimCLR'
+title: "Tutorial 13: Self-Supervised Contrastive Learning with SimCLR"
author: Phillip Lippe
created: 2021-08-30
updated: 2023-03-14
license: CC BY-SA
tags:
-- Image
-- Self-Supervised
-- Contrastive-Learning
-description: 'In this tutorial, we will take a closer look at self-supervised contrastive
+ - Image
+ - Self-Supervised
+ - Contrastive-Learning
+description: "In this tutorial, we will take a closer look at self-supervised contrastive
learning.
Self-supervised learning, or also sometimes called unsupervised learning, describes
@@ -29,27 +29,27 @@ description: 'In this tutorial, we will take a closer look at self-supervised co
The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.
- '
+ "
requirements:
-- torchvision
-- matplotlib
-- seaborn
-- lightning>=2.0.0
-- tensorboard
+ - torchvision
+ - matplotlib
+ - seaborn
+ - lightning>=2.0.0
+ - tensorboard
accelerator:
-- CPU
-- GPU
+ - CPU
+ - GPU
environment:
-- pytorch-lightning==2.0.9.post0
-- ipython==8.16.1
-- urllib3==2.2.2
-- seaborn==0.13.2
-- torch==2.0.1
-- tensorboard==2.17.0
-- matplotlib==3.8.4
-- torchvision==0.15.2
-- lightning==2.3.3
-- numpy==1.26.4
-- torchmetrics==1.2.1
-- setuptools==69.0.3
-published: '2024-07-19T20:31:47.750096'
+ - pytorch-lightning==2.0.9.post0
+ - ipython==8.16.1
+ - urllib3==2.2.2
+ - seaborn==0.13.2
+ - torch==2.0.1
+ - tensorboard==2.17.0
+ - matplotlib==3.8.4
+ - torchvision==0.15.2
+ - lightning==2.3.3
+ - numpy==1.26.4
+ - torchmetrics==1.2.1
+ - setuptools==69.0.3
+published: "2024-07-19T20:31:47.750096"
diff --git a/.notebooks/course_UvA-DL/activation-functions.jpg b/.notebooks/course_UvA-DL/activation-functions.jpg
new file mode 100644
index 000000000..1b21f50dc
Binary files /dev/null and b/.notebooks/course_UvA-DL/activation-functions.jpg differ
diff --git a/.notebooks/course_UvA-DL/autoregressive-image-modeling.jpg b/.notebooks/course_UvA-DL/autoregressive-image-modeling.jpg
new file mode 100644
index 000000000..1ad5d6125
Binary files /dev/null and b/.notebooks/course_UvA-DL/autoregressive-image-modeling.jpg differ
diff --git a/.notebooks/course_UvA-DL/contrastive-learning.jpg b/.notebooks/course_UvA-DL/contrastive-learning.jpg
new file mode 100644
index 000000000..6e05cce2f
Binary files /dev/null and b/.notebooks/course_UvA-DL/contrastive-learning.jpg differ
diff --git a/.notebooks/course_UvA-DL/deep-autoencoders.jpg b/.notebooks/course_UvA-DL/deep-autoencoders.jpg
new file mode 100644
index 000000000..1b07169f1
Binary files /dev/null and b/.notebooks/course_UvA-DL/deep-autoencoders.jpg differ
diff --git a/.notebooks/course_UvA-DL/deep-energy-based-generative-models.jpg b/.notebooks/course_UvA-DL/deep-energy-based-generative-models.jpg
new file mode 100644
index 000000000..32cd9486c
Binary files /dev/null and b/.notebooks/course_UvA-DL/deep-energy-based-generative-models.jpg differ
diff --git a/.notebooks/course_UvA-DL/graph-neural-networks.jpg b/.notebooks/course_UvA-DL/graph-neural-networks.jpg
new file mode 100644
index 000000000..0cda6bd18
Binary files /dev/null and b/.notebooks/course_UvA-DL/graph-neural-networks.jpg differ
diff --git a/.notebooks/course_UvA-DL/inception-resnet-densenet.jpg b/.notebooks/course_UvA-DL/inception-resnet-densenet.jpg
new file mode 100644
index 000000000..a7e02050d
Binary files /dev/null and b/.notebooks/course_UvA-DL/inception-resnet-densenet.jpg differ
diff --git a/.notebooks/course_UvA-DL/initialization-and-optimization.jpg b/.notebooks/course_UvA-DL/initialization-and-optimization.jpg
new file mode 100644
index 000000000..e8d42d4c0
Binary files /dev/null and b/.notebooks/course_UvA-DL/initialization-and-optimization.jpg differ
diff --git a/.notebooks/course_UvA-DL/introduction-to-pytorch.jpg b/.notebooks/course_UvA-DL/introduction-to-pytorch.jpg
new file mode 100644
index 000000000..a56ca66f2
Binary files /dev/null and b/.notebooks/course_UvA-DL/introduction-to-pytorch.jpg differ
diff --git a/.notebooks/course_UvA-DL/meta-learning.jpg b/.notebooks/course_UvA-DL/meta-learning.jpg
new file mode 100644
index 000000000..4f8f6d9d9
Binary files /dev/null and b/.notebooks/course_UvA-DL/meta-learning.jpg differ
diff --git a/.notebooks/course_UvA-DL/normalizing-flows.jpg b/.notebooks/course_UvA-DL/normalizing-flows.jpg
new file mode 100644
index 000000000..9654f8acc
Binary files /dev/null and b/.notebooks/course_UvA-DL/normalizing-flows.jpg differ
diff --git a/.notebooks/course_UvA-DL/transformers-and-MH-attention.jpg b/.notebooks/course_UvA-DL/transformers-and-MH-attention.jpg
new file mode 100644
index 000000000..e644f9a01
Binary files /dev/null and b/.notebooks/course_UvA-DL/transformers-and-MH-attention.jpg differ
diff --git a/.notebooks/course_UvA-DL/vision-transformer.jpg b/.notebooks/course_UvA-DL/vision-transformer.jpg
new file mode 100644
index 000000000..c129c4bd1
Binary files /dev/null and b/.notebooks/course_UvA-DL/vision-transformer.jpg differ
diff --git a/.notebooks/flash_tutorials/electricity_forecasting.ipynb b/.notebooks/flash_tutorials/electricity_forecasting.ipynb
new file mode 100644
index 000000000..02d9cd680
--- /dev/null
+++ b/.notebooks/flash_tutorials/electricity_forecasting.ipynb
@@ -0,0 +1,5785 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "1f55a881",
+ "metadata": {
+ "papermill": {
+ "duration": 0.013565,
+ "end_time": "2023-01-05T10:40:45.957934",
+ "exception": false,
+ "start_time": "2023-01-05T10:40:45.944369",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "\n",
+ "# Electricity Price Forecasting with N-BEATS\n",
+ "\n",
+ "* **Author:** Ethan Harris (ethan@pytorchlightning.ai)\n",
+ "* **License:** CC BY-SA\n",
+ "* **Generated:** 2023-01-05T11:40:10.253020\n",
+ "\n",
+ "This tutorial covers using Lightning Flash and it's integration with PyTorch Forecasting to train an autoregressive\n",
+ "model (N-BEATS) on hourly electricity pricing data. We show how the built-in interpretability tools from PyTorch\n",
+ "Forecasting can be used with Flash to plot the trend and daily seasonality in our data discovered by the model. We\n",
+ "also cover how features from PyTorch Lightning such as the learning rate finder can be used easily with Flash. As a\n",
+ "bonus, we show hat we can resample daily observations from the data to discover weekly trends instead.\n",
+ "\n",
+ "\n",
+ "---\n",
+ "Open in [{height=\"20px\" width=\"117px\"}](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/flash_tutorials/electricity_forecasting.ipynb)\n",
+ "\n",
+ "Give us a ⭐ [on Github](https://www.github.com/Lightning-AI/lightning/)\n",
+ "| Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/stable/)\n",
+ "| Join us [on Slack](https://www.pytorchlightning.ai/community)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "787c12c2",
+ "metadata": {
+ "papermill": {
+ "duration": 0.0052,
+ "end_time": "2023-01-05T10:40:45.968640",
+ "exception": false,
+ "start_time": "2023-01-05T10:40:45.963440",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Setup\n",
+ "This notebook requires some packages besides pytorch-lightning."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "8ee74430",
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "execution": {
+ "iopub.execute_input": "2023-01-05T10:40:45.980642Z",
+ "iopub.status.busy": "2023-01-05T10:40:45.980287Z",
+ "iopub.status.idle": "2023-01-05T10:40:47.313468Z",
+ "shell.execute_reply": "2023-01-05T10:40:47.312113Z"
+ },
+ "id": "LfrJLKPFyhsK",
+ "lines_to_next_cell": 0,
+ "papermill": {
+ "duration": 1.342556,
+ "end_time": "2023-01-05T10:40:47.316382",
+ "exception": false,
+ "start_time": "2023-01-05T10:40:45.973826",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[31mERROR: Cannot install pytorch-lightning<1.9 and >=1.4 and pytorch-lightning==1.3.6 because these package versions have conflicting dependencies.\u001b[0m\u001b[31m\r\n",
+ "\u001b[0m\u001b[31mERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts\u001b[0m\u001b[31m\r\n",
+ "\u001b[0m"
+ ]
+ }
+ ],
+ "source": [
+ "! pip install --quiet \"ipython[notebook]>=8.0.0, <8.9.0\" \"pytorch-lightning==1.3.6\" \"numpy<1.24\" \"setuptools==65.6.3\" \"pandas==1.1.5\" \"lightning-flash[tabular]>=0.6.0\" \"torch>=1.8.1, <1.14.0\" \"torchmetrics>=0.7, <0.12\" \"pytorch-lightning>=1.4, <1.9\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "98053ed4",
+ "metadata": {
+ "papermill": {
+ "duration": 0.005336,
+ "end_time": "2023-01-05T10:40:47.331611",
+ "exception": false,
+ "start_time": "2023-01-05T10:40:47.326275",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "In this tutorial we'll look at using [Lightning Flash](https://github.com/Lightning-AI/lightning-flash) and it's\n",
+ "integration with [PyTorch Forecasting](https://github.com/jdb78/pytorch-forecasting) for autoregressive modelling of\n",
+ "electricity prices using [the N-BEATS model](https://arxiv.org/abs/1905.10437).\n",
+ "We'll start by using N-BEATS to uncover daily patterns (seasonality) from hourly observations and then show how we can\n",
+ "resample daily averages to uncover weekly patterns too.\n",
+ "\n",
+ "Along the way, we'll see how the built-in tools from PyTorch Lightning, like the learning rate finder, can be used\n",
+ "seamlessly with Flash to help make the process of putting a model together as smooth as possible."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "21b86524",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-05T10:40:47.344308Z",
+ "iopub.status.busy": "2023-01-05T10:40:47.343934Z",
+ "iopub.status.idle": "2023-01-05T10:40:51.959866Z",
+ "shell.execute_reply": "2023-01-05T10:40:51.958637Z"
+ },
+ "papermill": {
+ "duration": 4.625739,
+ "end_time": "2023-01-05T10:40:51.962620",
+ "exception": false,
+ "start_time": "2023-01-05T10:40:47.336881",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "import os\n",
+ "from typing import Any, Dict\n",
+ "\n",
+ "import flash\n",
+ "import matplotlib.pyplot as plt\n",
+ "import pandas as pd\n",
+ "import torch\n",
+ "from flash.core.data.utils import download_data\n",
+ "from flash.core.integrations.pytorch_forecasting import convert_predictions\n",
+ "from flash.tabular.forecasting import TabularForecaster, TabularForecastingData\n",
+ "\n",
+ "DATASET_PATH = os.environ.get(\"PATH_DATASETS\", \"data/\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d39a5589",
+ "metadata": {
+ "papermill": {
+ "duration": 0.005406,
+ "end_time": "2023-01-05T10:40:51.978659",
+ "exception": false,
+ "start_time": "2023-01-05T10:40:51.973253",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Loading the data\n",
+ "\n",
+ "We'll use the Spanish hourly energy demand generation and weather data set from Kaggle:\n",
+ "https://www.kaggle.com/nicholasjhana/energy-consumption-generation-prices-and-weather\n",
+ "\n",
+ "First, download the data:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "81b96803",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-05T10:40:51.991441Z",
+ "iopub.status.busy": "2023-01-05T10:40:51.990890Z",
+ "iopub.status.idle": "2023-01-05T10:40:52.532616Z",
+ "shell.execute_reply": "2023-01-05T10:40:52.531447Z"
+ },
+ "papermill": {
+ "duration": 0.551053,
+ "end_time": "2023-01-05T10:40:52.535125",
+ "exception": false,
+ "start_time": "2023-01-05T10:40:51.984072",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "d6c3d70d7f6448e4bbfb33186dda4220",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "/__w/8/s/.datasets/kaggle_electricity.zip: 0%| | 0/3903 [00:00, ?KB/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "download_data(\"https://pl-flash-data.s3.amazonaws.com/kaggle_electricity.zip\", DATASET_PATH)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c6ba3192",
+ "metadata": {
+ "papermill": {
+ "duration": 0.005497,
+ "end_time": "2023-01-05T10:40:52.551125",
+ "exception": false,
+ "start_time": "2023-01-05T10:40:52.545628",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Data loading\n",
+ "\n",
+ "To load the data, we start by loading the CSV file into a pandas DataFrame:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "2c2da5a8",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-05T10:40:52.564537Z",
+ "iopub.status.busy": "2023-01-05T10:40:52.563451Z",
+ "iopub.status.idle": "2023-01-05T10:40:56.032575Z",
+ "shell.execute_reply": "2023-01-05T10:40:56.031350Z"
+ },
+ "papermill": {
+ "duration": 3.478778,
+ "end_time": "2023-01-05T10:40:56.035331",
+ "exception": false,
+ "start_time": "2023-01-05T10:40:52.556553",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "df_energy_hourly = pd.read_csv(f\"{DATASET_PATH}/energy_dataset.csv\", parse_dates=[\"time\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "52d60a6e",
+ "metadata": {
+ "papermill": {
+ "duration": 0.005525,
+ "end_time": "2023-01-05T10:40:56.051276",
+ "exception": false,
+ "start_time": "2023-01-05T10:40:56.045751",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "Before we can load the data into Flash, there are a few preprocessing steps we need to take.\n",
+ "The first preprocessing step is to set the `time` field as the index (formatted as a datetime).\n",
+ "The second step is to resample the data to the desired frequency in case it is different from the desired observation\n",
+ "frequency.\n",
+ "Since we are performing autoregressive modelling, we can remove all columns except for `\"price actual\"`.\n",
+ "\n",
+ "For the third preprocessing step, we need to create a \"time_idx\" column.\n",
+ "The \"time_idx\" column should contain integers corresponding to the observation index (e.g. in our case the difference\n",
+ "between two \"time_idx\" values is the number of hours between the observations).\n",
+ "To do this we convert the datetime to an index by taking the nanoseconds value and dividing by the number of\n",
+ "nanoseconds in a single unit of our chosen frequency.\n",
+ "We then subtract the minimum value so it starts at zero (although it would still work without this step).\n",
+ "\n",
+ "The Flash `TabularForecastingData` (which uses the `TimeSeriesDataSet` from PyTorch Forecasting internally) also\n",
+ "supports loading data from multiple time series (e.g. you may have electricity data from multiple countries).\n",
+ "To indicate that our data is all from the same series, we add a `constant` column with a constant value of zero.\n",
+ "\n",
+ "Here's the full preprocessing function:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "dbbb72cb",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-05T10:40:56.064169Z",
+ "iopub.status.busy": "2023-01-05T10:40:56.063784Z",
+ "iopub.status.idle": "2023-01-05T10:40:56.148263Z",
+ "shell.execute_reply": "2023-01-05T10:40:56.147131Z"
+ },
+ "papermill": {
+ "duration": 0.093787,
+ "end_time": "2023-01-05T10:40:56.150699",
+ "exception": false,
+ "start_time": "2023-01-05T10:40:56.056912",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "\n",
+ "def preprocess(df: pd.DataFrame, frequency: str = \"1H\") -> pd.DataFrame:\n",
+ " df[\"time\"] = pd.to_datetime(df[\"time\"], utc=True, infer_datetime_format=True)\n",
+ " df.set_index(\"time\", inplace=True)\n",
+ "\n",
+ " df = df.resample(frequency).mean()\n",
+ "\n",
+ " df = df.filter([\"price actual\"])\n",
+ "\n",
+ " df[\"time_idx\"] = (df.index.view(int) / pd.Timedelta(frequency).value).astype(int)\n",
+ " df[\"time_idx\"] -= df[\"time_idx\"].min()\n",
+ "\n",
+ " df[\"constant\"] = 0\n",
+ "\n",
+ " return df\n",
+ "\n",
+ "\n",
+ "df_energy_hourly = preprocess(df_energy_hourly)"
+ ]
+ },
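+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a quick sanity check of the \"time_idx\" arithmetic above (a minimal illustrative snippet, not part of the original pipeline): `pd.Timedelta` exposes its length in nanoseconds via `.value`, so dividing the nanosecond timestamps by it yields consecutive hourly indices.\n",
+ "\n",
+ "```python\n",
+ "import pandas as pd\n",
+ "\n",
+ "# One hour expressed in nanoseconds: 60 min * 60 s * 1e9 ns.\n",
+ "print(pd.Timedelta(\"1H\").value)  # 3600000000000\n",
+ "```"
+ ]
+ },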
+ {
+ "cell_type": "markdown",
+ "id": "70783cc4",
+ "metadata": {
+ "papermill": {
+ "duration": 0.005551,
+ "end_time": "2023-01-05T10:40:56.166922",
+ "exception": false,
+ "start_time": "2023-01-05T10:40:56.161371",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Creating the Flash DataModule\n",
+ "\n",
+ "Now, we can create a `TabularForecastingData`.\n",
+ "The role of the `TabularForecastingData` is to split up our time series into windows which include a region to encode\n",
+ "(of size `max_encoder_length`) and a region to predict (of size `max_prediction_length`) which will be used to compute\n",
+ "the loss.\n",
+ "The size of the prediction window should be chosen depending on the kinds of trends we would like our model to\n",
+ "uncover.\n",
+ "In our case, we are interested in how electricity prices change throughout the day, so a one day prediction window\n",
+ "(`max_prediction_length = 24`) makes sense here.\n",
+ "The size of the encoding window can vary, however, in the [N-BEATS paper](https://arxiv.org/abs/1905.10437) the\n",
+ "authors suggest using an encoder length of between two and ten times the prediction length.\n",
+ "We therefore choose two days (`max_encoder_length = 48`) as the encoder length."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "bc7eb4e6",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-05T10:40:56.179578Z",
+ "iopub.status.busy": "2023-01-05T10:40:56.179185Z",
+ "iopub.status.idle": "2023-01-05T10:40:56.355752Z",
+ "shell.execute_reply": "2023-01-05T10:40:56.354726Z"
+ },
+ "papermill": {
+ "duration": 0.185856,
+ "end_time": "2023-01-05T10:40:56.358263",
+ "exception": false,
+ "start_time": "2023-01-05T10:40:56.172407",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.9/dist-packages/IPython/core/interactiveshell.py:3382: FutureWarning: Please pass an instantiated object of the `InputTransform` class. Passing the Class and keyword arguments separately has been deprecated since v0.8.0 and will be removed in v0.9.0.\n",
+ " if await self.run_code(code, result, async_=asy):\n"
+ ]
+ }
+ ],
+ "source": [
+ "max_prediction_length = 24\n",
+ "max_encoder_length = 24 * 2\n",
+ "\n",
+ "training_cutoff = df_energy_hourly[\"time_idx\"].max() - max_prediction_length\n",
+ "\n",
+ "datamodule = TabularForecastingData.from_data_frame(\n",
+ " time_idx=\"time_idx\",\n",
+ " target=\"price actual\",\n",
+ " group_ids=[\"constant\"],\n",
+ " max_encoder_length=max_encoder_length,\n",
+ " max_prediction_length=max_prediction_length,\n",
+ " time_varying_unknown_reals=[\"price actual\"],\n",
+ " train_data_frame=df_energy_hourly[df_energy_hourly[\"time_idx\"] <= training_cutoff],\n",
+ " val_data_frame=df_energy_hourly,\n",
+ " batch_size=256,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0ac498a0",
+ "metadata": {
+ "papermill": {
+ "duration": 0.005748,
+ "end_time": "2023-01-05T10:40:56.373036",
+ "exception": false,
+ "start_time": "2023-01-05T10:40:56.367288",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Creating the Flash Task\n",
+ "\n",
+ "Now, we're ready to create a `TabularForecaster`.\n",
+ "The N-BEATS model has two primary hyper-parameters:`\"widths\"`, and `\"backcast_loss_ratio\"`.\n",
+ "In the [PyTorch Forecasting Documentation](https://pytorch-forecasting.readthedocs.io/en/latest/api/pytorch_forecasting.models.nbeats.NBeats.html),\n",
+ "the authors recommend using `\"widths\"` of `[32, 512]`.\n",
+ "In order to prevent overfitting with smaller datasets, a good rule of thumb is to limit the number of parameters of\n",
+ "your model.\n",
+ "For this reason, we use `\"widths\"` of `[16, 256]`.\n",
+ "\n",
+ "To understand the `\"backcast_loss_ratio\"`, let's take a look at this diagram of the model taken from\n",
+ "[the arXiv paper](https://arxiv.org/abs/1905.10437):\n",
+ "\n",
+ "\n",
+ "\n",
+ "Each 'block' within the N-BEATS architecture includes a forecast output and a backcast which can each yield their own\n",
+ "loss.\n",
+ "The `\"backcast_loss_ratio\"` is the ratio of the backcast loss to the forecast loss.\n",
+ "A value of `1.0` means that the loss function is simply the sum of the forecast and backcast losses."
+ ]
+ },
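+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To make that concrete, here is a minimal sketch of how such a combined loss could be formed (illustrative only; PyTorch Forecasting computes and combines the two losses internally):\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "\n",
+ "forecast_loss = torch.tensor(0.8)  # loss on the prediction window\n",
+ "backcast_loss = torch.tensor(0.5)  # loss on reconstructing the encoder window\n",
+ "backcast_loss_ratio = 1.0\n",
+ "\n",
+ "# With a ratio of 1.0, the total is simply the sum of the two losses.\n",
+ "total_loss = forecast_loss + backcast_loss_ratio * backcast_loss\n",
+ "```"
+ ]
+ },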
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "d08071c5",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-05T10:40:56.386426Z",
+ "iopub.status.busy": "2023-01-05T10:40:56.386050Z",
+ "iopub.status.idle": "2023-01-05T10:40:56.428474Z",
+ "shell.execute_reply": "2023-01-05T10:40:56.427753Z"
+ },
+ "papermill": {
+ "duration": 0.051859,
+ "end_time": "2023-01-05T10:40:56.430675",
+ "exception": false,
+ "start_time": "2023-01-05T10:40:56.378816",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.9/dist-packages/numpy/lib/nanfunctions.py:1217: RuntimeWarning: All-NaN slice encountered\n",
+ " r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,\n",
+ "Using 'n_beats' provided by jdb78/PyTorch-Forecasting (https://github.com/jdb78/pytorch-forecasting).\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.9/dist-packages/pytorch_forecasting/models/nbeats/sub_modules.py:154: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:201.)\n",
+ " coefficients = torch.tensor([backcast_linspace ** i for i in range(thetas_dim)], dtype=torch.float32)\n"
+ ]
+ }
+ ],
+ "source": [
+ "model = TabularForecaster(\n",
+ " datamodule.parameters, backbone=\"n_beats\", backbone_kwargs={\"widths\": [16, 256], \"backcast_loss_ratio\": 1.0}\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "418a6bfa",
+ "metadata": {
+ "papermill": {
+ "duration": 0.005769,
+ "end_time": "2023-01-05T10:40:56.447284",
+ "exception": false,
+ "start_time": "2023-01-05T10:40:56.441515",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Finding the learning rate\n",
+ "\n",
+ "Tabular models can be particularly sensitive to the choice of learning rate.\n",
+ "Helpfully, PyTorch Lightning provides a built-in learning rate finder that suggests a suitable learning rate\n",
+ "automatically.\n",
+ "To use it, we first create our Trainer.\n",
+ "We apply gradient clipping (a common technique for tabular tasks) with ``gradient_clip_val=0.01`` in order to help\n",
+ "prevent our model from over-fitting.\n",
+ "Here's how to find the learning rate:"
+ ]
+ },
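+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A minimal sketch of a typical invocation, assuming the PyTorch Lightning 1.x tuner API pinned above; `model` and `datamodule` come from the previous cells, and the exact arguments (``max_epochs``, ``gpus``) are illustrative:\n",
+ "\n",
+ "```python\n",
+ "trainer = flash.Trainer(max_epochs=3, gpus=1, gradient_clip_val=0.01)\n",
+ "\n",
+ "# Run the learning-rate range test and read off the suggestion.\n",
+ "res = trainer.tuner.lr_find(model, datamodule=datamodule)\n",
+ "print(f\"Suggested learning rate: {res.suggestion()}\")\n",
+ "```"
+ ]
+ },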
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "9c832c85",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-05T10:40:56.463601Z",
+ "iopub.status.busy": "2023-01-05T10:40:56.463195Z",
+ "iopub.status.idle": "2023-01-05T10:41:12.532450Z",
+ "shell.execute_reply": "2023-01-05T10:41:12.531584Z"
+ },
+ "papermill": {
+ "duration": 16.07838,
+ "end_time": "2023-01-05T10:41:12.533963",
+ "exception": false,
+ "start_time": "2023-01-05T10:40:56.455583",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "GPU available: True, used: True\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "TPU available: False, using: 0 TPU cores\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ " | Name | Type | Params\n",
+ "------------------------------------------------------------\n",
+ "0 | train_metrics | ModuleDict | 0 \n",
+ "1 | val_metrics | ModuleDict | 0 \n",
+ "2 | test_metrics | ModuleDict | 0 \n",
+ "3 | adapter | PyTorchForecastingAdapter | 454 K \n",
+ "------------------------------------------------------------\n",
+ "454 K Trainable params\n",
+ "0 Non-trainable params\n",
+ "454 K Total params\n",
+ "1.820 Total estimated model params size (MB)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/utilities/distributed.py:69: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
+ " warnings.warn(*args, **kwargs)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/utilities/distributed.py:69: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
+ " warnings.warn(*args, **kwargs)\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "5d6b82d030ed40c389c229b466a00304",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Finding best initial lr: 0%| | 0/100 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "LR finder stopped early after 86 steps due to diverging loss.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Restored states from the checkpoint file at /__w/8/s/lr_find_temp_model.ckpt\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Suggested learning rate: 0.0007943282347242816\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjIAAAG4CAYAAABfDw16AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/av/WaAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA2aUlEQVR4nO3de3xU9Z3/8ffM5MolExIhIUsC2CIgiCg3o7gimxpZ13KJl7psBeWhro1YYGlrVqXe411qDVAsP0Qrq2KV6raF1QioNCDEFW9rAIsQCwkqTQLRTOZyfn9k5iRDAoSQ5JyTvJ6Px3kkcy7f+QyHkA+f8znn6zIMwxAAAIADua0OAAAAoK1IZAAAgGORyAAAAMcikQEAAI5FIgMAAByLRAYAADgWiQwAAHAsEhkAAOBYJDIAAMCxSGQAAIBjWZrIBINB3XnnnRo8eLASExP1ve99T/fee6+azppgGIYWLVqk/v37KzExUTk5Odq1a5eFUQMAALuIsfLNH3roIS1dulSrVq3SiBEjtH37dl133XXyer269dZbJUkPP/ywnnzySa1atUqDBw/WnXfeqdzcXH366adKSEg44XuEQiHt379fvXv3lsvl6uiPBAAA2oFhGDp8+LAyMjLkdh+n7mJY6LLLLjOuv/76qHUzZswwZs6caRiGYYRCISM9Pd145JFHzO1VVVVGfHy88V//9V+teo/y8nJDEgsLCwsLC4sDl/Ly8uP+nre0InP++edr+fLl2rlzp8444wzt2LFD7777rh5//HFJ0p49e1RRUaGcnBzzGK/XqwkTJqikpEQ/+tGPmo3p8/nk8/nM10b4MlV5ebmSkpI6+BMBAID2UFNTo8zMTPXu3fu4+1mayNx2222qqanRsGHD5PF4FAwGdf/992vmzJmSpIqKCklSWlpa1HFpaWnmtqMVFhbq7rvvbrY+KSmJRAYAAIc5UVuIpc2+L730kp5//nmtXr1a77//vlatWqVHH31Uq1atavOYBQUFqq6uNpfy8vJ2jBgAANiJpRWZn/3sZ7rtttvMS0RnnXWW9u7dq8LCQs2aNUvp6emSpMrKSvXv3988rrKyUqNHj25xzPj4eMXHx3d47AAAwHqWVmS+/fbbZp3IHo9HoVBIkjR48GClp6eruLjY3F5TU6OtW7cqOzu7U2MFAAD2Y2lF5vLLL9f999+vrKwsjRgxQv/7v/+rxx9/XNdff72khuti8+bN03333achQ4aYt19nZGRo2rRpVoYOAABswNJE5te//rXuvPNO/eQnP9HBgweVkZGhm266SYsWLTL3+fnPf67a2lrdeOONqqqq0sSJE7Vu3bpWPUMGAAB0bS7DaPIY3S6opqZGXq9X1dXV3LUEAIBDtPb3N3MtAQAAxyKRAQAAjkUiAwAAHItEBgAAOBaJDAAAcCwSGQAA4FgkMgAAoE0Wv7lT1/6/97Ths4OWxUAiAwAA2uTjv9Xo7Z1fqaKmzrIYSGQAAECbBMNzI8a4XZbFQCIDAADaJBBqmBwgxkMiAwAAHCYYTmQ8buvSCRIZAADQJmZFhktLAADAaQLBhh4ZD4kMAABwmiAVGQAA4FSNzb70yAAAAIehIgMAABwrYN61RCIDAAAchooMAABwLD93LQEAAKdqrMjQ7AsAAByGKQoAAIBj0SMDAAAciyf7AgAAxwrQIwMAAJzKfI4MPTIAAMBpIj0ysVxaAgAATmIYhpnI0CMDAAAcJZLESPTIAAAAhwk0SWTokQEAAI4SiKrIkMgAAAAHCQabVGRIZAAAgJMEQiHzeyoyAADAUZreseRykcgAAAAHCdjg1mvJ4kRm0KBBcrlczZb8/HxJUl1dnfLz85WamqpevXopLy9PlZWVVoYMAABkjwkjJYsTmW3btunAgQPm8sYbb0iSrrzySknS/Pnz9frrr2vNmjXatGmT9u/frxkzZlgZMgAAkOS3wYSRkhRj5Zv37ds36vWDDz6o733ve7roootUXV2tFStWaPXq1Zo8ebIkaeXKlRo+fLi2bNmi8847z4qQAQCAqMg0U19fr9/97ne6/vrr5XK5VFpaKr/fr5ycHHOfYcOGKSsrSyUlJcccx+fzqaamJmoBAADty5z52mNtKmGbRGbt2rWqqqrS7NmzJUkVFRWKi4tTcnJy1H5paWmqqKg45jiFhYXyer3mkpmZ2YFRAwDQPVGROcqKFSs0ZcoUZWRknNI4BQUFqq6uNpfy8vJ2ihAAAETY5a4lS3tkIvbu3as333xTr7zyirkuPT1d9fX1qqqqiqrKVFZWKj09/ZhjxcfHKz4+viPDBQCg2wuGH4hHRUYNTbz9+vXTZZddZq4bM2aMYmNjVVxcbK4rKyvTvn37lJ2dbUWYAAAgzB+kIiNJCoVCWrlypWbNmqWYmMZwvF6v5syZowULFiglJUVJSUmaO3eusrOzuWMJAACLNfbIWFsTsTyRefPNN7Vv3z5df/31zbY98cQTcrvdysvLk8/nU25urpYsWWJBlAAAoKnGu5a6eUXmkksukWEYLW5LSEhQUVGRioqKOjkqAABwPPTIAAAAxwrYpEeGRAYAAJy0gE16ZEhkAADASbPLc2RIZAAAwEkze2QsbvYlkQEAACct0iNDsy8AAHCcoHlpiR4ZAADgMAEmjQQAAE4VCDb0yHjokQEAAE5DRQYAADhWkNuvAQCAU0UqMrE0+wIAAKcxKzL0yAAAAKehRwYAADiWedcSiQwAAHCaIBUZAADgVAGe7AsAAJwqUpGJpdkXAAA4TSBEjwwAAHAoemQAAIBj+YP0yAAAAIeiIgMAABzLfCAezb4AAMBpguFmXyoyAADAcQL0yAAAAKdiriUAAOBYjU/2JZEBAAAOY/bI0OwLAACcJtIjE0OPDAAAcJogl5YAAIBT0ewLAAAcy5w0kh4ZAADgNI09MiQyAADAYeiRAQAAjhVJZGI93fyupb/97W/6t3/7N6WmpioxMVFnnXWWtm/fbm43DEOLFi1S//79lZiYqJycHO3atcvCiAEAAA/Ek/T3v/9dF1xwgWJjY/XnP/9Zn376qR577DH16dPH3Ofhhx/Wk08+qWXLlmnr1q3q2bOncnNzVVdXZ2HkAAB0b0Gb3LUUY+WbP/TQQ8rMzNTKlSvNdYMHDza/NwxDixcv1h133KGpU6dKkp599lmlpaVp7dq1+tGPftTpMQMAAMkfDN+11J0rMq+99prGjh2rK6+8Uv369dM555yjp59+2ty+Z88eVVRUKCcnx1zn9Xo1YcIElZSUtDimz+dTTU1N1AIAANpXY0WmG/fI/PWvf9XSpUs1ZMgQrV+/XjfffLNuvfVWrVq1SpJUUVEhSUpLS4s6Li0tzdx2tMLCQnm9XnPJzMzs2A8BAEA3RI+MpFAopHPPPVcPPPCAzjnnHN1444264YYbtGzZsjaPWVBQoOrqanMpLy9vx4gBAIDU9K6lbpzI9O/fX2eeeWbUuuHDh2vfvn2SpPT0dElSZWVl1D6VlZXmtqPFx8crKSkpagEAAO3LfLJvd67IXHDBBSorK4tat3P
nTg0cOFBSQ+Nvenq6iouLze01NTXaunWrsrOzOzVWAADQyC49MpbetTR//nydf/75euCBB3TVVVfpvffe0/Lly7V8+XJJksvl0rx583TfffdpyJAhGjx4sO68805lZGRo2rRpVoYOAEC3ZRiG/EF79MhYmsiMGzdOr776qgoKCnTPPfdo8ODBWrx4sWbOnGnu8/Of/1y1tbW68cYbVVVVpYkTJ2rdunVKSEiwMHIAALqvcDFGkvXPkXEZhmGceDfnqqmpkdfrVXV1Nf0yAAC0A18gqKF3rJMkfXjXJUpKiG3392jt72/LpygAAADOEmxSkontzs+RAQAAzhNokshY3SNDIgMAAE5KINiYyFjdI0MiAwAATkrkGTIul+QmkQEAAE5il5mvJRIZAABwkgI2eYaMRCIDAABOkjnPksV3LEkkMgAA4CSZM19bPGGkRCIDAABOUqTZlx4ZAADgOPTIAAAAx7LLzNcSiQwAADhJkR6ZGHpkAACA00QqMlxaAgAAjkOzLwAAcKzGZl/r0wjrIwAAAI7CFAUAAMCxAvTIAAAApwqGe2RiuWsJAAA4DRUZAADgWDwQDwAAOJafKQoAAIBTBXmODAAAcCp6ZAAAgGNFemRiPdanEdZHAAAAHCVAjwwAAHAqnuwLAAAcyx9u9qUiAwAAHCcYvrQUw5N9AQCA03DXEgAAcCye7AsAABwrQLMvAABwqkAw3OxLjwwAAHAaKjIAAMCxgmazr/VphKUR3HXXXXK5XFHLsGHDzO11dXXKz89XamqqevXqpby8PFVWVloYMQAAoCLTxIgRI3TgwAFzeffdd81t8+fP1+uvv641a9Zo06ZN2r9/v2bMmGFhtAAAwJz92gY9MjGWBxATo/T09Gbrq6urtWLFCq1evVqTJ0+WJK1cuVLDhw/Xli1bdN5553V2qAAAQFRkouzatUsZGRk6/fTTNXPmTO3bt0+SVFpaKr/fr5ycHHPfYcOGKSsrSyUlJcccz+fzqaamJmoBAADtp3HSSMvTCGsTmQkTJuiZZ57RunXrtHTpUu3Zs0cXXnihDh8+rIqKCsXFxSk5OTnqmLS0NFVUVBxzzMLCQnm9XnPJzMzs4E8BAED3YqdJIy29tDRlyhTz+1GjRmnChAkaOHCgXnrpJSUmJrZpzIKCAi1YsMB8XVNTQzIDAEA7CjBpZMuSk5N1xhlnaPfu3UpPT1d9fb2qqqqi9qmsrGyxpyYiPj5eSUlJUQsAAGg/dqrI2CqROXLkiD7//HP1799fY8aMUWxsrIqLi83tZWVl2rdvn7Kzsy2MEgCA7s1s9vVYn0ZYemlp4cKFuvzyyzVw4EDt379fv/zlL+XxeHTNNdfI6/Vqzpw5WrBggVJSUpSUlKS5c+cqOzubO5YAALCQnSoyliYyX375pa655hp988036tu3ryZOnKgtW7aob9++kqQnnnhCbrdbeXl58vl8ys3N1ZIlS6wMGQCAbs8ftE+PjKWJzAsvvHDc7QkJCSoqKlJRUVEnRQQAAE7EThUZ6y9uAQAARwmYcy2RyAAAAIcxKzI2mKKARAYAAJyUyJN9Y7r7k30BAIDz0CMDAAAcy8+TfQEAgFPRIwMAAByL2a8BAIBj0SMDAAAcK8ClJQAA4FSBcLMvFRkAAOA4QXpkAACAUwXokQEAAE4VZK4lAADgVGaPDM2+AADASUIhQ+GCDHMtAQAAZ4n0x0hcWgIAAA4TbJLI0OwLAAAcJdIfI1GRAQAADkNFBgAAOBY9MgAAwLGaThjpcpHIAAAAB/EHG3pk7FCNkUhkAADASQjaaHoCiUQGAACchICNpieQSGQAAMBJMCsyHnukEPaIAgAAOEIgyKUlAADgUPTIAAAAx/KHn+zrscHM1xKJDAAAOAmNFRl7pBBtiqK8vFxffvml+fq9997TvHnztHz58nYLDAAA2E+kR8bRdy3967/+qzZs2CBJqqio0A9+8AO99957uv3223XPPfe0a4AAAMA+ukSPzMcff6zx48dLkl566SWNHDlSf/nLX/T888/rmWeeac/4AACAjURmv45xco+M3+9XfHy8JOnNN9/UD3/4Q0nSsGHDdODAgfaLDgAA2ErjpSUH98iMGDFCy5Yt0zvvvKM33nhDl156qSRp//79Sk1NbdcAAQCAfQS6wqWlhx56SL/5zW80adIkXXPNNTr77LMlSa+99pp5yelkPfjgg3K5XJo3b565rq6uTvn5+UpNTVWvXr2Ul5enysrKNo0PAABOXdBmUxTEtOWgSZMm6euvv1ZNTY369Oljrr/xxhvVo0ePkx5v27Zt+s1vfqNRo0ZFrZ8/f77++Mc/as2aNfJ6vbrllls0Y8YMbd68uS1hAwCAU2T2yNgkkWlTRea7776Tz+czk5i9e/dq8eLFKisrU79+/U5qrCNHjmjmzJl6+umno5Ki6upqrVixQo8//rgmT56sMWPGaOXKlfrLX/6iLVu2tCVsAABwiuxWkWlTIjN16lQ9++yzkqSqqipNmDBBjz32mKZNm6alS5ee1Fj5+fm67LLLlJOTE7W+tLRUfr8/av2wYcOUlZWlkpKSY47n8/lUU1MTtQAAgPYR6ZGJdfKkke+//74uvPBCSdLLL7+stLQ07d27V88++6yefPLJVo/zwgsv6P3331dhYWGzbRUVFYqLi1NycnLU+rS0NFVUVBxzzMLCQnm9XnPJzMxsdTwAAOD4usQD8b799lv17t1bkvQ///M/mjFjhtxut8477zzt3bu3VWOUl5frpz/9qZ5//nklJCS0JYwWFRQUqLq62lzKy8vbbWwAALq7YFfokfn+97+vtWvXqry8XOvXr9cll1wiSTp48KCSkpJaNUZpaakOHjyoc889VzExMYqJidGmTZv05JNPKiYmRmlpaaqvr1dVVVXUcZWVlUpPTz/muPHx8UpKSopaAABA+wh0hR6ZRYsWaeHChRo0aJDGjx+v7OxsSQ3VmXPOOadVY/zTP/2TPvroI33wwQfmMnbsWM2cOdP8PjY2VsXFxeYxZWVl2rdvn/l+AACgc9ltioI23X59xRVXaOLEiTpw4ID5DBmpITmZPn16q8bo3bu3Ro4cGbWuZ8+eSk1NNdfPmTNHCxYsUEpKipKSkjR37lxlZ2frvPPOa0vYAADgFDVWZOzR7NumREaS0tPTlZ6ebs6CPWDAgDY/DO9YnnjiCbndbuXl5cnn8yk3N1dLlixp1/cAAACtFzTvWrJHRaZN6VQoFNI999wjr9ergQMHauDAgUpOTta9996rULgJqC02btyoxYsXm68TEhJUVFSkQ4cOqba2Vq+88spx+2MAAEDH8gcbfs/bpUemTRWZ22+/XStWrNCDDz6oCy64QJL07rvv6q677lJdXZ3uv//+dg0SAADYQ5fokVm1apV++9vfmrNeS9KoUaP0D//wD/rJT35CIgMAQBdltx6ZNkVx6NAhDRs2rNn6YcOG6dChQ6ccFAAAsCezIuPkHpmzzz5bTz31VLP1Tz31VLOJHwEAQNcRebKvoy8tPfzww7rsssv05ptvms90KSkpUXl5uf70pz
+1a4AAAMA+usSTfS+66CLt3LlT06dPV1VVlaqqqjRjxgx98skneu6559o7RgAAYBN+m/XItPk5MhkZGc2aenfs2KEVK1Zo+fLlpxwYAACwn2CwC/TIAACA7qlLzLUEAAC6py7RIwMAALqngJMfiDdjxozjbq+qqjqVWAAAgM1Fbr/2eOxRCzmpRMbr9Z5w+7XXXntKAQEAAPtydEVm5cqVHRUHAABwgEiPDM2+AADAcexWkSGRAQAArRbk9msAAOBUkYpMrE2afe0RBQAAcIRAkB4ZAADgUEF6ZAAAgFMxRQEAAHCsxoqMPVIIe0QBAAAcgYoMAABwrKB51xKJDAAAcBg/dy0BAACnokcGAAA4Fj0yAADAscyKDD0yAADAaXiyLwAAcCzzriV6ZAAAgNP4Iz0yXFoCAABOw1xLAADAkQzDMBMZemQAAICjRJIYiYoMAABwmECTRIaKDAAAcJSmiUysxx4phKVRLF26VKNGjVJSUpKSkpKUnZ2tP//5z+b2uro65efnKzU1Vb169VJeXp4qKystjBgAgO4rGKQiE2XAgAF68MEHVVpaqu3bt2vy5MmaOnWqPvnkE0nS/Pnz9frrr2vNmjXatGmT9u/frxkzZlgZMgAA3VYgFDK/97jskci4DMMwTrxb50lJSdEjjzyiK664Qn379tXq1at1xRVXSJI+++wzDR8+XCUlJTrvvPNaNV5NTY28Xq+qq6uVlJTUkaEDANClHayp0/gHiuV2SX8tvKxD36u1v7/tcYFLUjAY1AsvvKDa2lplZ2ertLRUfr9fOTk55j7Dhg1TVlaWSkpKjjmOz+dTTU1N1AIAAE5dwGYzX0s2SGQ++ugj9erVS/Hx8fr3f/93vfrqqzrzzDNVUVGhuLg4JScnR+2flpamioqKY45XWFgor9drLpmZmR38CQAA6B7sNmGkZINEZujQofrggw+0detW3XzzzZo1a5Y+/fTTNo9XUFCg6upqcykvL2/HaAEA6L78NpswUpJirA4gLi5O3//+9yVJY8aM0bZt2/SrX/1KV199terr61VVVRVVlamsrFR6evoxx4uPj1d8fHxHhw0AQLdjt+kJJBtUZI4WCoXk8/k0ZswYxcbGqri42NxWVlamffv2KTs728IIAQDongLm9AT2SR8srcgUFBRoypQpysrK0uHDh7V69Wpt3LhR69evl9fr1Zw5c7RgwQKlpKQoKSlJc+fOVXZ2dqvvWAIAAO3HjhUZSxOZgwcP6tprr9WBAwfk9Xo1atQorV+/Xj/4wQ8kSU888YTcbrfy8vLk8/mUm5urJUuWWBkyAADdVsBmE0ZKFicyK1asOO72hIQEFRUVqaioqJMiAgAAxxIMPxAvlruWAACA0/iD9qvIkMgAAIBWCfJAPAAA4FR27JEhkQEAAK0S6ZHhyb4AAMBxAvTIAAAAp4r0yMTSIwMAAJzGT48MAABwKnpkAACAY9EjAwAAHMuOcy2RyAAAgFbhOTIAAMCxAsFIj4x90gf7RAIAAGwtwKUlAADgVEEuLQEAAKeiIgMAAByrsSJjn/TBPpEAAABboyIDAAAcq/GuJRIZAADgMDwQDwAAOFaAHhkAAOBUVGQAAIBjBcKzX/McGQAA4DhUZAAAgGP5g+FEhrmWAACA01CRAQAAjhVgriUAAOBUwRAPxAMAAA4VCFKRAQAADkWPDAAAcCy/mcjYJ32wTyQAAMDW6JEBAACORY8MAABwLHpkjlJYWKhx48apd+/e6tevn6ZNm6aysrKoferq6pSfn6/U1FT16tVLeXl5qqystChiAAC6L2a/PsqmTZuUn5+vLVu26I033pDf79cll1yi2tpac5/58+fr9ddf15o1a7Rp0ybt379fM2bMsDBqAAC6p4ANe2RirHzzdevWRb1+5pln1K9fP5WWluof//EfVV1drRUrVmj16tWaPHmyJGnlypUaPny4tmzZovPOO8+KsAEA6JYiPTJcWjqG6upqSVJKSookqbS0VH6/Xzk5OeY+w4YNU1ZWlkpKSiyJEQCA7ipowykKLK3INBUKhTRv3jxdcMEFGjlypCSpoqJCcXFxSk5Ojto3LS1NFRUVLY7j8/nk8/nM1zU1NR0WMwAA3UmQ58gcW35+vj7++GO98MILpzROYWGhvF6vuWRmZrZThAAAdG9MGnkMt9xyi/77v/9bGzZs0IABA8z16enpqq+vV1VVVdT+lZWVSk9Pb3GsgoICVVdXm0t5eXlHhg4AQLfB7ddHMQxDt9xyi1599VW99dZbGjx4cNT2MWPGKDY2VsXFxea6srIy7du3T9nZ2S2OGR8fr6SkpKgFAACcOn+Qu5ai5Ofna/Xq1frDH/6g3r17m30vXq9XiYmJ8nq9mjNnjhYsWKCUlBQlJSVp7ty5ys7O5o4lAAA6mR17ZCxNZJYuXSpJmjRpUtT6lStXavbs2ZKkJ554Qm63W3l5efL5fMrNzdWSJUs6OVIAAGDHHhlLExnDME64T0JCgoqKilRUVNQJEQEAgGOhRwYAADhW5Mm+dqrIkMgAAIBWMSsyNmr2JZEBAAAnZBiG/EH7NfvaJxIAAGBboSZtrfTIAAAAR4n0x0iSh0tLAADASYJNSjJUZAAAgKMEmiQy3LUEAAAcJRhsWpGxT/pgn0gAAIBt+cM9Mi4XFRkAAOAwdnyqr0QiAwAAWiEQtN88SxKJDAAAaAU7znwtkcgAAIBWsOPM1xKJDAAAaAV6ZAAAgGP5gw13LdlpwkiJRAYAALQCPTIAAMCx6JEBAACORY8MAABwrMjs11RkAACA4/BAPAAA4FiRS0uxHnulDvaKBgAA2BLNvgAAwLGC4R4Zmn0BAIDjUJEBAACOZd5+zZN9AQCA0/jNu5bslTrYKxoAAGBLkR6ZWC4tAQAAp6FHBgAAOBY9MgAAwLEC9MgAAACnYtJIAADgWH4eiAcAAJwqGKRHBgAAOBR3LbXg7bff1uWXX66MjAy5XC6tXbs2arthGFq0aJH69++vxMRE5eTkaNeuXdYECwBAN9bYI2OvGoil0dTW1urss89WUVFRi9sffvhhPfnkk1q2bJm2bt2qnj17Kjc3V3V1dZ0cKQAA3ZtdKzIxVr75lClTNGXKlBa3GYahxYsX64477tDUqVMlSc8++6zS0tK0du1a/ehHP+rMUAEA6NaY/fok7dmzRxUVFcrJyTHXeb1eTZgwQSUlJRZGBgBA9+O3abOvpRWZ46moqJAkpaWlRa1PS0szt7XE5/PJ5/OZr2tqajomQAAAupFgiAfidYrCwkJ5vV5zyczMtDokAAAcL8AD8U5Oenq6JKmysjJqfWVlpbmtJQUFBaqurjaX8vLyDo0TAIDuINIjY7dmX9smMoMHD1Z6erqKi4vNdTU1Ndq6dauys7OPeVx8fLySkpKiFgAAcGrsWpGxtEfmyJEj2r17t/l6z549+uCDD5SSkqKsrCzNmzdP9913n4YMGaLBgwfrzjvvVEZGhqZNm2Zd0AAAdEONk0aSyJi2b9+uiy++2Hy9YMECSdKsWbP0zDPP6Oc//7lqa2t14403qqqqShMnT
tS6deuUkJBgVcgAAHRLkWbfWI+9LuZYmshMmjRJhmEcc7vL5dI999yje+65pxOjAgAARwvQIwMAAJwqaNMeGRIZAABwQnadooBEBgAAnJBZkbHZk31JZAAAwAn5g5EeGXulDvaKBgAA2JJ51xKXlgAAgNPQIwMAAByLHhkAAOBYjU/2tVfqYK9oAACALfEcGQAA4Fh+nuwLAACcqnGuJRIZAADgMPTIAAAAx6JHBgAAOBbPkQEAAI701WGfan0BSVRkAACAg5Qf+lZXLPuLvvMH1d+boMyUHlaHFCXG6gAAAIA9lVUc1o9XbNXBwz4N6JOo382ZoIRYj9VhRSGRAQAAzZTuPaTrVm5TTV1AQ9N669k545WWlGB1WM2QyAAAgCgbyw7q339Xqjp/SOdmJev/zR6n5B5xVofVIhIZAO3LMKRvvpGOHJF69ZJSUyWXvZoDAbRsf9V3eux/duqV//1ShiFddEZfLf23c9Ujzr7pgn0jA+AsVVXSqlXSr38tff554/rvfU+aO1eaNUtKTrYqOgDHUf2tX0s27dbKzV+oPtAwFcFVYwfovmlnKS7G3vcFuQzDMKwOoiPV1NTI6/WqurpaSUlJVocDdE3r10t5edK33za8bvrPSqQa06OH9PvfS7m5nR8fgBbV+YN6rmSvntqwW9Xf+SVJEwanqOCfh2t0ZrKlsbX29zcVmTa667VP9NL2csW4XYqLcSvWE1lczb6Pi3Erxu2Sx92wzuN2mdsTYz1KiPM0fI31KM7jVjBkqD4YUn0gJH+wYTGMxt8HLpdLrvDXGLdLbrdLHpdLMeGxY8Ljx3hcinW75XE3rHe7XXK7JI+r4XuXJLfLJZdL4aVxnMhniHGHx3W5GraHx3C7Gtc3vKdbniavI/u5uKTQ9a1fL112WUPy0tL/iyLrvvuuYb8//pFkBrBYnT+o/3pvn5Zu/FwHD/skSWek9dJtU4bp4qH9HPVvN4lMG9X5g/q2Pmh1GLYXSWpiwsmWxxP+Ppz8RJKvSMIV63EpJpxARSWGMW7FhV/HeNyKDSeGTZO3SIIYed2QhLmjvjbdN8YTjuGo10dva7q+aSIaGctJP/DtrqqqoRJjGFJ4ZtxjCoUkt7th/y+/5DITYIGWEph/SE7UT/9piPLGDLDdU3tbg0SmjW6bMkz5F39f9eGKiT9gyB8KyR8IyR805A+GVB8MKRA0VB8MKhA0FAgZCgRD4a8NVZc6f1B1/qC+8wf1XX3DMbGRX+IxjZWRyF8tQ5JhGA2/NwwpZBgKhEIKhqRgqGHsYHh8f+S9QoZCIUMho2Fb5GvDWOHxpPB6mTHWB0LhsQ2FDJnHhkKGgoahUEgKhEIKHefiZDAcT33HnxLLRJI1T5MKWeRrbDiJiiRkkWpXXLhSF0nW4mI8DV89Tap7MS7Fh7+Pi2ncP85M6sLjHv262dhuxXs84XUN8bSbVasaLie19gp1KNSw/7PPSrfe2n5xADgmwzD0yf4arfu4Qi9tL49KYPIv/r6uGDPA9n0wx0Mi00bJPeJseytaZzPCiVEkiQoahoLB8NdQ49KwvTGRi7z2Bxu2+8OJXyC8LhCKThAD4Utu/oARnbSFGo9v+NpwbCRxbLpPoMm2yOtgqOE9mm6PHBNJCCOvWxL5fE7hcslMpppeCo0kOubro5LpqP1j3Ipzu3Trg4+pjyGdzP/hDEn1jy/WX6+YpdgYzzHfx9Pdq13AKQiGDJXu/bvWfVyh9Z9U6G9V35nbukoCE0Eig1PmCvfLxNjrYY/trmnCFgg1JGuRBMgfDDWpUIUrW6GjErNwUhSp1vnDfVD1wXD1y9xmmNsi+/sC4X3DCV1kXWS/wFFVwMh6X/hr9OeQOdap6PNtte6qKD/p41yGofi9e3TNQ39SVeKxG/iaJlxNK1nm5UWPW3Ge6IpXbNPtZq/Y0QladF+bWeGKcSnO44mqnsW4G8eL7odrHK/pJcfIJUigsxiGoSO+gD7/qlaf7K/Wp/tr9Mn+Gn1WUaM6f+PPeGKsR5OG9tWlI9M1ZWT/LpHARJDIAK3k1ITNMIzGy52BpolU47pIcuQPhOQPNSZWTfdrTLwaqmKJf9t3SnFlxgRlJMaalzD9weiqVnslXJ3N5VK4Ryvcf3VUL1bTPq3W9GM17+9q3u/lcUdf1jR7w9wuecJJnfuo/Zs25keSMLdb4fcPN/S3sK7pMW6Xy1zndsm88aDpe7i7YWIX+c9N0ypxdNU4/J+TYON/UJr+bNaHWxR8gYZezFpfwPxaWx/Q32v9+vu39TpUW6+qb/2qD7b8M5KUEKOcM9N06Yh0XTikrxLjHPaPVyuRyABdnMvlUlxMQzWhZ3w7Dvx1n1M6/PX/nNLwsLwwwzDMhKo+EJ1ARf7hj1S46gNNt0X6wVq+FBkIGuGk7KhkLlyxChyV1DWtakUuT/qbvG76vi1dUjQMhffnZoCISFLjcjUmQy6XopIdjyv6TsfGuzRlJkyN30suucy7LSN3UjZ9Hbkr0+1qeNHsLk1Fv0+EEe49NAzJUEN/oJp8bxiGgobk8wflC4Qav5p/d47fN9hRUnvG6cyMJI3I8Ia/JmlQas9uUSEkkQHQNqmpDQ+7++tfW9/sKzX85jj9dCkl5ajVLvOSjVPaz0JNe7RCoajLjZH+q0CTPrCofq3wpcfW9o9FHxvu/2rSjxZqYbs/fAk00vcVNBpuCmjauxbpa4vEcnRvW9N9gqHGZv+mNwGcSDBkKCjn9JF1hKYVt0gTfuSyZlyMJ9yUf1S/WIxbPeM86hkfo55xMeoR71HPuBgl94hVSs849ekRpz4949SnR6ytn7zb0brvJwdwalyuhif2zp9/8sfeemuXmLbA7XYpLvw/3kR1zbJ9a0QnN5GER1Hrjl7f9A7IyPqQYZh3Zkoy76xUuBoSCjVWRSJ3WkYqKGYVxYgeJxReZ6hx34jI+sj3kYqP1Pi8rsZKkKRwFcjjcikh1qP4WLfiY9yKj/EoPqahCb7ZnYpuGtc7Gk/2BdB2VVXSgAEND7s70XNkpIbnyCQm8hwZACfU2t/fXadtGUDnS05umHbA5WpIUo7H7W7Y75VXSGIAtBsSGQCnJje3YdqBxESZnZRNRdYlJkp/+pN0ySXWxAmgS3JEIlNUVKRBgwYpISFBEyZM0HvvvWd1SACays1tuFy0eHFDI29Tp5/esP5vfyOJAdDubN8j8+KLL+raa6/VsmXLNGHCBC1evFhr1qxRWVmZ+vXrd8Lj6ZEBOplhSIcOSYcPS717N9ydRKMjgJPU2t/ftk9kJkyYoHHjxumpp56SJIVCIWVmZmru3Lm67bbbTng8iQwAAM7TJZp96+vrVVpaqpycHHOd2+1WTk6OSkpKLIwMAADYga2fI/P1118rGAwqLS0tan1aWpo+++yzFo/x+Xzy+Xzm65qamg6NEQAAWMfWFZm2KCwslNfrNZfMzEyrQwIAAB3E1onMaaedJo/Ho8rKyqj1lZWVSk9Pb/GYgoICVVdXm0t5+cnPzgsAAJzB1olMXFycxowZo+LiYnNdKBRScXGxsrOzWzwmPj5eSUlJ
UQsAAOiabN0jI0kLFizQrFmzNHbsWI0fP16LFy9WbW2trrvuOqtDAwAAFrN9InP11Vfrq6++0qJFi1RRUaHRo0dr3bp1zRqAAQBA92P758icKp4jAwCA87T297ftKzKnKpKncRs2AADOEfm9faJ6S5dPZA4fPixJ3IYNAIADHT58WF6v95jbu/ylpVAopP3796t3795yHTXfy7hx47Rt27Zmx7RmfU1NjTIzM1VeXm7ZJatjxdlZ47T2uBPtd7ztnKO2j3Myx3TWOTp6HeeofX6GTrQP56jt43COTqy9zs/RYxmGocOHDysjI0Nu97Fvsu7yFRm3260BAwa0uM3j8bR40k9mvZW3eB8rzs4ap7XHnWi/423nHLV9nJM5prPO0bH25Ryd+n6co44Zh3N0Yu11floa63iVmAhbP0emo+Xn57fLequ0VzxtHae1x51ov+Nt5xy1fZyTOaazzpHdzo/kjHPUmv04Rx0zDufoxNoznraM1eUvLXUU7oayP86R/XGO7I9zZH/d/Rx164rMqYiPj9cvf/lLxcfHWx0KjoFzZH+cI/vjHNlfdz9HVGQAAIBjUZEBAACORSIDAAAci0QGAAA4FokMAABwLBIZAADgWCQynWTQoEEaNWqURo8erYsvvtjqcHAM3377rQYOHKiFCxdaHQqOUlVVpbFjx2r06NEaOXKknn76aatDQhPl5eWaNGmSzjzzTI0aNUpr1qyxOiS0YPr06erTp4+uuOIKq0NpN9x+3UkGDRqkjz/+WL169bI6FBzH7bffrt27dyszM1OPPvqo1eGgiWAwKJ/Ppx49eqi2tlYjR47U9u3blZqaanVokHTgwAFVVlZq9OjRqqio0JgxY7Rz50717NnT6tDQxMaNG3X48GGtWrVKL7/8stXhtAsqMkDYrl279Nlnn2nKlClWh4IWeDwe9ejRQ5Lk8/lkGIb4f5h99O/fX6NHj5Ykpaen67TTTtOhQ4esDQrNTJo0Sb1797Y6jHZFIiPp7bff1uWXX66MjAy5XC6tXbu22T5FRUUaNGiQEhISNGHCBL333nsn9R4ul0sXXXSRxo0bp+eff76dIu8+OuMcLVy4UIWFhe0UcffTGeeoqqpKZ599tgYMGKCf/exnOu2009op+q6vM85PRGlpqYLBoDIzM08x6u6lM89RV0IiI6m2tlZnn322ioqKWtz+4osvasGCBfrlL3+p999/X2effbZyc3N18OBBc5/Idfujl/3790uS3n33XZWWluq1117TAw88oA8//LBTPltX0dHn6A9/+IPOOOMMnXHGGZ31kbqczvg5Sk5O1o4dO7Rnzx6tXr1alZWVnfLZuoLOOD+SdOjQIV177bVavnx5h3+mrqazzlGXYyCKJOPVV1+NWjd+/HgjPz/ffB0MBo2MjAyjsLCwTe+xcOFCY+XKlacQZffWEefotttuMwYMGGAMHDjQSE1NNZKSkoy77767PcPuVjrj5+jmm2821qxZcyphdlsddX7q6uqMCy+80Hj22WfbK9RuqyN/hjZs2GDk5eW1R5i2QEXmBOrr61VaWqqcnBxzndvtVk5OjkpKSlo1Rm1trQ4fPixJOnLkiN566y2NGDGiQ+LtjtrjHBUWFqq8vFxffPGFHn30Ud1www1atGhRR4Xc7bTHOaqsrDR/jqqrq/X2229r6NChHRJvd9Me58cwDM2ePVuTJ0/Wj3/8444Ktdtqj3PUVcVYHYDdff311woGg0pLS4tan5aWps8++6xVY1RWVmr69OmSGu68uOGGGzRu3Lh2j7W7ao9zhI7VHudo7969uvHGG80m37lz5+qss87qiHC7nfY4P5s3b9aLL76oUaNGmb0dzz33HOeonbTXv3M5OTnasWOHamtrNWDAAK1Zs0bZ2dntHW6nIpHpBKeffrp27NhhdRhopdmzZ1sdAlowfvx4ffDBB1aHgWOYOHGiQqGQ1WHgBN58802rQ2h3XFo6gdNOO00ej6dZU2FlZaXS09MtigpNcY7sj3Nkb5wf++McHRuJzAnExcVpzJgxKi4uNteFQiEVFxc7vhzXVXCO7I9zZG+cH/vjHB0bl5bU0IC7e/du8/WePXv0wQcfKCUlRVlZWVqwYIFmzZqlsWPHavz48Vq8eLFqa2t13XXXWRh198I5sj/Okb1xfuyPc9RGFt81ZQsbNmwwJDVbZs2aZe7z61//2sjKyjLi4uKM8ePHG1u2bLEu4G6Ic2R/nCN74/zYH+eobZhrCQAAOBY9MgAAwLFIZAAAgGORyAAAAMcikQEAAI5FIgMAAByLRAYAADgWiQwAAHAsEhkAAOBYJDIAbG/QoEFavHix1WEAsCGe7AtAkjR79mxVVVVp7dq1VofSzFdffaWePXuqR48eVofSIjv/2QFdHRUZAJbx+/2t2q9v376WJDGtjQ+AdUhkALTKxx9/rClTpqhXr15KS0vTj3/8Y3399dfm9nXr1mnixIlKTk5Wamqq/uVf/kWff/65uf2LL76Qy+XSiy++qIsuukgJCQl6/vnnNXv2bE2bNk2PPvqo+vfvr9TUVOXn50clEUdfWnK5XPrtb3+r6dOnq0ePHhoyZIhee+21qHhfe+01DRkyRAkJCbr44ou1atUquVwuVVVVHfMzulwuLV26VD/84Q/Vs2dP3X///QoGg5ozZ44GDx6sxMREDR06VL/61a/MY+666y6tWrVKf/jDH+RyueRyubRx40ZJUnl5ua666iolJycrJSVFU6dO1RdffNG2EwCgRSQyAE6oqqpKkydP1jnnnKPt27dr3bp1qqys1FVXXWXuU1tbqwULFmj79u0qLi6W2+3W9OnTFQqFosa67bbb9NOf/lT/93//p9zcXEnShg0b9Pnnn2vDhg1atWqVnnnmGT3zzDPHjenuu+/WVVddpQ8//FD//M//rJkzZ+rQoUOSpD179uiKK67QtGnTtGPHDt100026/fbbW/VZ77rrLk2fPl0fffSRrr/+eoVCIQ0YMEBr1qzRp59+qkWLFuk///M/9dJLL0mSFi5cqKuuukqXXnqpDhw4oAMHDuj888+X3+9Xbm6uevfurXfeeUebN29Wr169dOmll6q+vr61f/QATsTaybcB2MWsWbOMqVOntrjt3nvvNS655JKodeXl5YYko6ysrMVjvvrqK0OS8dFHHxmGYRh79uwxJBmLFy9u9r4DBw40AoGAue7KK680rr76avP1wIEDjSeeeMJ8Lcm44447zNdHjhwxJBl//vOfDcMwjF/84hfGyJEjo97n9ttvNyQZf//731v+AwiPO2/evGNuj8jPzzfy8vKiPsPRf3bPPfecMXToUCMUCpnrfD6fkZiYaKxfv/6E7wGgdajIADihHTt2aMOGDerVq5e5DBs2TJLMy0e7du3SNddco9NPP11JSUkaNGiQJGnfvn1RY40dO7bZ+CNGjJDH4zFf9+/fXwcPHjxuTKNGjTK/79mzp5KSksxjysrKNG7cuKj9x48f36rP2lJ8RUVFGjNmjPr27atevXpp+fLlzT7X0Xbs2KHdu3erd+/e5p9ZSkqK6urqoi6
5ATg1MVYHAMD+jhw5ossvv1wPPfRQs239+/eXJF1++eUaOHCgnn76aWVkZCgUCmnkyJHNLqP07Nmz2RixsbFRr10uV7NLUu1xTGscHd8LL7yghQsX6rHHHlN2drZ69+6tRx55RFu3bj3uOEeOHNGYMWP0/PPPN9vWt2/fU44TQAMSGQAndO655+r3v/+9Bg0apJiY5v9sfPPNNyorK9PTTz+tCy+8UJL07rvvdnaYpqFDh+pPf/pT1Lpt27a1aazNmzfr/PPP109+8hNz3dEVlbi4OAWDwah15557rl588UX169dPSUlJbXpvACfGpSUApurqan3wwQdRS3l5ufLz83Xo0CFdc8012rZtmz7//HOtX79e1113nYLBoPr06aPU1FQtX75cu3fv1ltvvaUFCxZY9jluuukmffbZZ/rFL36hnTt36qWXXjKbh10u10mNNWTIEG3fvl3r16/Xzp07deeddzZLigYNGqQPP/xQZWVl+vrrr+X3+zVz5kyddtppmjp1qt555x3t2bNHGzdu1K233qovv/yyvT4q0O2RyAAwbdy4Ueecc07UcvfddysjI0ObN29WMBjUJZdcorPOOkvz5s1TcnKy3G633G63XnjhBZWWlmrkyJGaP3++HnnkEcs+x+DBg/Xyyy/rlVde0ahRo7R06VLzrqX4+PiTGuumm27SjBkzdPXVV2vChAn65ptvoqozknTDDTdo6NChGjt2rPr27avNmzerR48eevvtt5WVlaUZM2Zo+PDhmjNnjurq6qjQAO2IJ/sC6Bbuv/9+LVu2TOXl5VaHAqAd0SMDoEtasmSJxo0bp9TUVG3evFmPPPKIbrnlFqvDAtDOSGQAdEm7du3Sfffdp0OHDikrK0v/8R//oYKCAqvDAtDOuLQEAAAci2ZfAADgWCQyAADAsUhkAACAY5HIAAAAxyKRAQAAjkUiAwAAHItEBgAAOBaJDAAAcCwSGQAA4Fj/Hx7smltLpAd7AAAAAElFTkSuQmCC\n",
+ "text/plain": [
+ "
\n",
+ "\n",
+ "Fundamentally, [Fine-Tuning Scheduler](https://finetuning-scheduler.readthedocs.io/en/stable/index.html) enables\n",
+ "scheduled, multi-phase, fine-tuning of foundation models. Gradual unfreezing (i.e. thawing) can help maximize\n",
+ "foundation model knowledge retention while allowing (typically upper layers of) the model to\n",
+ "optimally adapt to new tasks during transfer learning [1, 2, 3]\n",
+ "\n",
+ "
\n",
+ "\n",
+ "The [FinetuningScheduler](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler) callback orchestrates the gradual unfreezing\n",
+ "of models via a fine-tuning schedule that is either implicitly generated (the default) or explicitly provided by the user\n",
+ "(more computationally efficient). Fine-tuning phase transitions are driven by\n",
+ "[FTSEarlyStopping](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts_supporters.html#finetuning_scheduler.fts_supporters.FTSEarlyStopping)\n",
+ "criteria (a multi-phase extension of ``EarlyStopping`` packaged with FinetuningScheduler), user-specified epoch transitions or a composition of the two (the default mode).\n",
+ "A [FinetuningScheduler](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler) training session completes when the\n",
+ "final phase of the schedule has its stopping criteria met. See\n",
+ "the [early stopping documentation](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html) for more details on that callback's configuration.\n",
+ "\n",
+ "{height=\"272px\" width=\"376px\"}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "07db61a0",
+ "metadata": {
+ "papermill": {
+ "duration": 0.014737,
+ "end_time": "2023-10-04T01:00:31.527404",
+ "exception": false,
+ "start_time": "2023-10-04T01:00:31.512667",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "\n",
+ "## Basic Usage\n",
+ "\n",
+ "If no fine-tuning schedule is provided by the user, [FinetuningScheduler](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler) will generate a\n",
+ "[default schedule](#The-Default-Fine-Tuning-Schedule) and proceed to fine-tune according to the generated schedule,\n",
+ "using default [FTSEarlyStopping](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts_supporters.html#finetuning_scheduler.fts_supporters.FTSEarlyStopping) and [FTSCheckpoint](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts_supporters.html#finetuning_scheduler.fts_supporters.FTSCheckpoint) callbacks with ``monitor=val_loss``.\n",
+ "\n",
+ "\n",
+ "\n",
+ "```python\n",
+ "import lightning as L\n",
+ "from finetuning_scheduler import FinetuningScheduler\n",
+ "trainer = L.Trainer(callbacks=[FinetuningScheduler()])\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7deee3b5",
+ "metadata": {
+ "papermill": {
+ "duration": 0.013461,
+ "end_time": "2023-10-04T01:00:31.593130",
+ "exception": false,
+ "start_time": "2023-10-04T01:00:31.579669",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## The Default Fine-Tuning Schedule\n",
+ "\n",
+ "Schedule definition is facilitated via the [gen_ft_schedule](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts_supporters.html#finetuning_scheduler.fts_supporters.ScheduleImplMixin.gen_ft_schedule) method which dumps a default fine-tuning schedule (by default using a naive, 2-parameters per level heuristic) which can be adjusted as\n",
+ "desired by the user and/or subsequently passed to the callback. Using the default/implicitly generated schedule will likely be less computationally efficient than a user-defined fine-tuning schedule but is useful for exploring a model's fine-tuning behavior and can serve as a good baseline for subsequent explicit schedule refinement.\n",
+ "While the current version of [FinetuningScheduler](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler) only supports single optimizer and (optional) lr_scheduler configurations, per-phase maximum learning rates can be set as demonstrated in the next section."
+ ]
+ },
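+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To make the schedule format concrete, a generated ``*_ft_schedule.yaml`` looks roughly like the sketch below: fine-tuning phases are zero-indexed top-level keys, each listing the parameter names (or patterns) to thaw in that phase. The parameter names here are hypothetical; a real file enumerates your module's actual parameters:\n",
+ "\n",
+ "```yaml\n",
+ "0:\n",
+ "  params:\n",
+ "  - model.classifier.bias\n",
+ "  - model.classifier.weight\n",
+ "1:\n",
+ "  params:\n",
+ "  - model.pooler.dense.bias\n",
+ "  - model.pooler.dense.weight\n",
+ "```"
+ ]
+ },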
+ {
+ "cell_type": "markdown",
+ "id": "bc71b2b5",
+ "metadata": {
+ "papermill": {
+ "duration": 0.012689,
+ "end_time": "2023-10-04T01:00:31.619528",
+ "exception": false,
+ "start_time": "2023-10-04T01:00:31.606839",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Specifying a Fine-Tuning Schedule\n",
+ "\n",
+ "To specify a fine-tuning schedule, it's convenient to first generate the default schedule and then alter the thawed/unfrozen parameter groups associated with each fine-tuning phase as desired. Fine-tuning phases are zero-indexed and executed in ascending order.\n",
+ "\n",
+ "1. First, generate the default schedule to ``Trainer.log_dir``. It will be named after your\n",
+ " ``LightningModule`` subclass with the suffix ``_ft_schedule.yaml``.\n",
+ "\n",
+ "```python\n",
+ " import lightning as L\n",
+ " from finetuning_scheduler import FinetuningScheduler\n",
+ " trainer = L.Trainer(callbacks=[FinetuningScheduler(gen_ft_sched_only=True)])\n",
+ "```\n",
+ "\n",
+ "2. Alter the schedule as desired.\n",
+ "\n",
+ "{height=\"327px\" width=\"800px\"}\n",
+ "\n",
+ "3. Once the fine-tuning schedule has been altered as desired, pass it to\n",
+ " [FinetuningScheduler](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler) to commence scheduled training:\n",
+ "\n",
+ "```python\n",
+ "import lightning as L\n",
+ "from finetuning_scheduler import FinetuningScheduler\n",
+ "\n",
+ "trainer = L.Trainer(callbacks=[FinetuningScheduler(ft_schedule=\"/path/to/my/schedule/my_schedule.yaml\")])\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "62af6cce",
+ "metadata": {
+ "papermill": {
+ "duration": 0.012779,
+ "end_time": "2023-10-04T01:00:31.645053",
+ "exception": false,
+ "start_time": "2023-10-04T01:00:31.632274",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Early-Stopping and Epoch-Driven Phase Transition Criteria\n",
+ "\n",
+ "\n",
+ "By default, [FTSEarlyStopping](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts_supporters.html#finetuning_scheduler.fts_supporters.FTSEarlyStopping) and epoch-driven\n",
+ "transition criteria are composed. If a ``max_transition_epoch`` is specified for a given phase, the next fine-tuning phase will begin at that epoch unless [FTSEarlyStopping](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts_supporters.html#finetuning_scheduler.fts_supporters.FTSEarlyStopping) criteria are met first.\n",
+ "If [FinetuningScheduler.epoch_transitions_only](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler.params.epoch_transitions_only) is ``True``, [FTSEarlyStopping](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts_supporters.html#finetuning_scheduler.fts_supporters.FTSEarlyStopping) will not be used\n",
+ "and transitions will be exclusively epoch-driven.\n",
+ "\n",
+ "\n",
+ "
\n",
+ "\n",
+ "**Tip:** Use of regex expressions can be convenient for specifying more complex schedules. Also, a per-phase base maximum lr can be specified:\n",
+ "\n",
+ "{height=\"380px\" width=\"800px\"}\n",
+ "\n",
+ "
\n",
+ "\n",
+ "\n",
+ "\n",
+ "The end-to-end example in this notebook ([Scheduled Fine-Tuning For SuperGLUE](#Scheduled-Fine-Tuning-For-SuperGLUE)) uses [FinetuningScheduler](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler) in explicit mode to fine-tune a small foundation model on the [RTE](https://huggingface.co/datasets/viewer/?dataset=super_glue&config=rte) task of [SuperGLUE](https://super.gluebenchmark.com/).\n",
+ "Please see the [official Fine-Tuning Scheduler documentation](https://finetuning-scheduler.readthedocs.io/en/stable/index.html) if you are interested in a similar [CLI-based example](https://finetuning-scheduler.readthedocs.io/en/stable/index.html#example-scheduled-fine-tuning-for-superglue) using the LightningCLI."
+ ]
+ },
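+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For example, purely epoch-driven transitions can be selected with the ``epoch_transitions_only`` flag described above (a minimal sketch; it assumes every phase in your schedule defines a ``max_transition_epoch``):\n",
+ "\n",
+ "```python\n",
+ "import lightning as L\n",
+ "from finetuning_scheduler import FinetuningScheduler\n",
+ "\n",
+ "# FTSEarlyStopping is not used; phases advance only at the configured epochs.\n",
+ "trainer = L.Trainer(callbacks=[FinetuningScheduler(epoch_transitions_only=True)])\n",
+ "```"
+ ]
+ },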
+ {
+ "cell_type": "markdown",
+ "id": "06c0d293",
+ "metadata": {
+ "papermill": {
+ "duration": 0.012695,
+ "end_time": "2023-10-04T01:00:31.670485",
+ "exception": false,
+ "start_time": "2023-10-04T01:00:31.657790",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Resuming Scheduled Fine-Tuning Training Sessions\n",
+ "\n",
+ "Resumption of scheduled fine-tuning training is identical to the continuation of\n",
+ "[other training sessions](https://lightning.ai/docs/pytorch/stable/common/trainer.html) with the caveat that the provided checkpoint must have been saved by a [FinetuningScheduler](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler) session.\n",
+ "[FinetuningScheduler](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler) uses [FTSCheckpoint](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts_supporters.html#finetuning_scheduler.fts_supporters.FTSCheckpoint) (an extension of ``ModelCheckpoint``) to maintain schedule state with special metadata.\n",
+ "\n",
+ "\n",
+ "```python\n",
+ "import lightning as L\n",
+ "from finetuning_scheduler import FinetuningScheduler\n",
+ "trainer = L.Trainer(callbacks=[FinetuningScheduler()])\n",
+ "trainer.ckpt_path=\"some/path/to/my_checkpoint.ckpt\"\n",
+ "trainer.fit(...)\n",
+ "```\n",
+ "\n",
+ "Training will resume at the depth/level of the provided checkpoint according to the specified schedule. Schedules can be altered between training sessions but schedule compatibility is left to the user for maximal flexibility. If executing a user-defined schedule, typically the same schedule should be provided for the original and resumed training sessions.\n",
+ "\n",
+ "By default ([FinetuningScheduler.restore_best](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html?highlight=restore_best#finetuning_scheduler.fts.FinetuningScheduler.params.restore_best) is ``True``), [FinetuningScheduler](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler) will attempt to restore the best available checkpoint before fine-tuning depth transitions.\n",
+ "\n",
+ "```python\n",
+ "trainer = L.Trainer(callbacks=[FinetuningScheduler()])\n",
+ "trainer.ckpt_path=\"some/path/to/my_kth_best_checkpoint.ckpt\"\n",
+ "trainer.fit(...)\n",
+ "```\n",
+ "\n",
+ "Note that similar to the behavior of [ModelCheckpoint](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html), when resuming training with a\n",
+ "different [FTSCheckpoint](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts_supporters.html#finetuning_scheduler.fts_supporters.FTSCheckpoint) ``dirpath`` from the provided\n",
+ "checkpoint, the new training session's checkpoint state will be re-initialized at the resumption depth with the provided checkpoint being set as the best checkpoint."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9fd4b1a8",
+ "metadata": {
+ "papermill": {
+ "duration": 0.012728,
+ "end_time": "2023-10-04T01:00:31.696107",
+ "exception": false,
+ "start_time": "2023-10-04T01:00:31.683379",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "
\n",
+ "\n",
+ "**Note:** Currently, [FinetuningScheduler](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler) supports the following distributed strategy types:\n",
+ "\n",
+ "- ``ddp`` (and aliases ``ddp_find_unused_parameters_false``, ``ddp_find_unused_parameters_true``, ``ddp_spawn``, ``ddp_fork``, ``ddp_notebook``)\n",
+ "- ``fsdp`` (and alias ``fsdp_cpu_offload``)\n",
+ "\n",
+ "Custom or officially unsupported strategies can be used by setting [FinetuningScheduler.allow_untested](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html?highlight=allow_untested#finetuning_scheduler.fts.FinetuningScheduler.params.allow_untested) to ``True``.\n",
+ "Note that most currently unsupported strategies are so because they require varying degrees of modification to be compatible. For example, ``deepspeed`` will require a [StrategyAdapter](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.strategy_adapters.html#finetuning_scheduler.strategy_adapters.StrategyAdapter) to be written (similar to the one for ``FSDP``, [FSDPStrategyAdapter](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.strategy_adapters.html#finetuning_scheduler.strategy_adapters.FSDPStrategyAdapter)) before support can be added (PRs welcome!),\n",
+ "while ``tpu_spawn`` would require an override of the current broadcast method to include python objects.\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e1282094",
+ "metadata": {
+ "papermill": {
+ "duration": 0.012755,
+ "end_time": "2023-10-04T01:00:31.721575",
+ "exception": false,
+ "start_time": "2023-10-04T01:00:31.708820",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Scheduled Fine-Tuning For SuperGLUE\n",
+ "\n",
+ "The following example demonstrates the use of [FinetuningScheduler](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler) to fine-tune a small foundation model on the [RTE](https://huggingface.co/datasets/viewer/?dataset=super_glue&config=rte) task of [SuperGLUE](https://super.gluebenchmark.com/). Iterative early-stopping will be applied according to a user-specified schedule.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "73da1a1d",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-04T01:00:31.749268Z",
+ "iopub.status.busy": "2023-10-04T01:00:31.748867Z",
+ "iopub.status.idle": "2023-10-04T01:00:34.666615Z",
+ "shell.execute_reply": "2023-10-04T01:00:34.665312Z"
+ },
+ "papermill": {
+ "duration": 2.934229,
+ "end_time": "2023-10-04T01:00:34.668840",
+ "exception": false,
+ "start_time": "2023-10-04T01:00:31.734611",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import warnings\n",
+ "from datetime import datetime\n",
+ "from typing import Any, Dict, Optional\n",
+ "\n",
+ "import datasets\n",
+ "import evaluate"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "55549b1f",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-04T01:00:34.706343Z",
+ "iopub.status.busy": "2023-10-04T01:00:34.705931Z",
+ "iopub.status.idle": "2023-10-04T01:00:36.819413Z",
+ "shell.execute_reply": "2023-10-04T01:00:36.818110Z"
+ },
+ "papermill": {
+ "duration": 2.134867,
+ "end_time": "2023-10-04T01:00:36.821957",
+ "exception": false,
+ "start_time": "2023-10-04T01:00:34.687090",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Import the `FinetuningScheduler` PyTorch Lightning extension module we want to use. This will import all necessary callbacks.\n",
+ "import finetuning_scheduler as fts # isort: split\n",
+ "\n",
+ "import lightning as L\n",
+ "import sentencepiece as sp # noqa: F401 # isort: split\n",
+ "import torch\n",
+ "from datasets import logging as datasets_logging\n",
+ "from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint\n",
+ "from lightning.pytorch.loggers.tensorboard import TensorBoardLogger\n",
+ "from lightning.pytorch.utilities import rank_zero_warn\n",
+ "from torch.optim.adamw import AdamW\n",
+ "from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts\n",
+ "from torch.utils.data import DataLoader\n",
+ "from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer\n",
+ "from transformers import logging as transformers_logging\n",
+ "from transformers.tokenization_utils_base import BatchEncoding\n",
+ "\n",
+ "# set notebook-level variables\n",
+ "TASK_NUM_LABELS = {\"boolq\": 2, \"rte\": 2}\n",
+ "DEFAULT_TASK = \"rte\"\n",
+ "\n",
+ "# reduce hf logging verbosity to focus on tutorial-relevant code/messages\n",
+ "for hflogger in [transformers_logging, datasets_logging]:\n",
+ " hflogger.set_verbosity_error()\n",
+ "# ignore warnings related tokenizers_parallelism/DataLoader parallelism trade-off and\n",
+ "# expected logging behavior\n",
+ "for warnf in [\n",
+ " r\".*does not have many workers.*\",\n",
+ " r\".*The number of training samples.*\",\n",
+ " r\".*converting to a fast.*\",\n",
+ " r\".*number of training batches.*\",\n",
+ "]:\n",
+ " warnings.filterwarnings(\"ignore\", warnf)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "f6d43eda",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-04T01:00:36.878554Z",
+ "iopub.status.busy": "2023-10-04T01:00:36.877476Z",
+ "iopub.status.idle": "2023-10-04T01:00:36.890455Z",
+ "shell.execute_reply": "2023-10-04T01:00:36.889486Z"
+ },
+ "papermill": {
+ "duration": 0.041598,
+ "end_time": "2023-10-04T01:00:36.892086",
+ "exception": false,
+ "start_time": "2023-10-04T01:00:36.850488",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "class RteBoolqDataModule(L.LightningDataModule):\n",
+ " \"\"\"A ``LightningDataModule`` designed for both the RTE or BoolQ SuperGLUE Hugging Face datasets.\"\"\"\n",
+ "\n",
+ " TASK_TEXT_FIELD_MAP = {\"rte\": (\"premise\", \"hypothesis\"), \"boolq\": (\"question\", \"passage\")}\n",
+ " LOADER_COLUMNS = (\n",
+ " \"datasets_idx\",\n",
+ " \"input_ids\",\n",
+ " \"token_type_ids\",\n",
+ " \"attention_mask\",\n",
+ " \"start_positions\",\n",
+ " \"end_positions\",\n",
+ " \"labels\",\n",
+ " )\n",
+ "\n",
+ " def __init__(\n",
+ " self,\n",
+ " model_name_or_path: str,\n",
+ " task_name: str = DEFAULT_TASK,\n",
+ " max_seq_length: int = 128,\n",
+ " train_batch_size: int = 16,\n",
+ " eval_batch_size: int = 16,\n",
+ " tokenizers_parallelism: bool = True,\n",
+ " **dataloader_kwargs: Any,\n",
+ " ):\n",
+ " r\"\"\"Initialize the ``LightningDataModule`` designed for both the RTE or BoolQ SuperGLUE Hugging Face datasets.\n",
+ "\n",
+ " Args:\n",
+ " model_name_or_path (str):\n",
+ " Can be either:\n",
+ " - A string, the ``model id`` of a pretrained model hosted inside a model repo on huggingface.co.\n",
+ " Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced\n",
+ " under a user or organization name, like ``dbmdz/bert-base-german-cased``.\n",
+ " - A path to a ``directory`` containing model weights saved using\n",
+ " :meth:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.\n",
+ " task_name (str, optional): Name of the SuperGLUE task to execute. This module supports 'rte' or 'boolq'.\n",
+ " Defaults to DEFAULT_TASK which is 'rte'.\n",
+ " max_seq_length (int, optional): Length to which we will pad sequences or truncate input. Defaults to 128.\n",
+ " train_batch_size (int, optional): Training batch size. Defaults to 16.\n",
+ " eval_batch_size (int, optional): Batch size to use for validation and testing splits. Defaults to 16.\n",
+ " tokenizers_parallelism (bool, optional): Whether to use parallelism in the tokenizer. Defaults to True.\n",
+ " \\**dataloader_kwargs: Arguments passed when initializing the dataloader.\n",
+ " \"\"\"\n",
+ " super().__init__()\n",
+ " task_name = task_name if task_name in TASK_NUM_LABELS.keys() else DEFAULT_TASK\n",
+ " self.text_fields = self.TASK_TEXT_FIELD_MAP[task_name]\n",
+ " self.dataloader_kwargs = {\n",
+ " \"num_workers\": dataloader_kwargs.get(\"num_workers\", 0),\n",
+ " \"pin_memory\": dataloader_kwargs.get(\"pin_memory\", False),\n",
+ " }\n",
+ " self.save_hyperparameters()\n",
+ " os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\" if self.hparams.tokenizers_parallelism else \"false\"\n",
+ " self.tokenizer = AutoTokenizer.from_pretrained(\n",
+ " self.hparams.model_name_or_path, use_fast=True, local_files_only=False\n",
+ " )\n",
+ "\n",
+ " def prepare_data(self):\n",
+ " \"\"\"Load the SuperGLUE dataset.\"\"\"\n",
+ " # N.B. PL calls prepare_data from a single process (rank 0) so do not use it to assign\n",
+ " # state (e.g. self.x=y)\n",
+ " datasets.load_dataset(\"super_glue\", self.hparams.task_name)\n",
+ "\n",
+ " def setup(self, stage):\n",
+ " \"\"\"Setup our dataset splits for training/validation.\"\"\"\n",
+ " self.dataset = datasets.load_dataset(\"super_glue\", self.hparams.task_name)\n",
+ " for split in self.dataset.keys():\n",
+ " self.dataset[split] = self.dataset[split].map(\n",
+ " self._convert_to_features, batched=True, remove_columns=[\"label\"]\n",
+ " )\n",
+ " self.columns = [c for c in self.dataset[split].column_names if c in self.LOADER_COLUMNS]\n",
+ " self.dataset[split].set_format(type=\"torch\", columns=self.columns)\n",
+ "\n",
+ " self.eval_splits = [x for x in self.dataset.keys() if \"validation\" in x]\n",
+ "\n",
+ " def train_dataloader(self):\n",
+ " return DataLoader(self.dataset[\"train\"], batch_size=self.hparams.train_batch_size, **self.dataloader_kwargs)\n",
+ "\n",
+ " def val_dataloader(self):\n",
+ " return DataLoader(self.dataset[\"validation\"], batch_size=self.hparams.eval_batch_size, **self.dataloader_kwargs)\n",
+ "\n",
+ " def _convert_to_features(self, example_batch: datasets.arrow_dataset.LazyDict) -> BatchEncoding:\n",
+ " \"\"\"Convert raw text examples to a :class:`~transformers.tokenization_utils_base.BatchEncoding` container\n",
+ " (derived from python dict) of features that includes helpful methods for translating between word/character\n",
+ " space and token space.\n",
+ "\n",
+ " Args:\n",
+ " example_batch ([type]): The set of examples to convert to token space.\n",
+ "\n",
+ " Returns:\n",
+ " ``BatchEncoding``: A batch of encoded examples (note default tokenizer batch_size=1000).\n",
+ " \"\"\"\n",
+ " text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))\n",
+ " # Tokenize the text/text pairs\n",
+ " features = self.tokenizer.batch_encode_plus(\n",
+ " text_pairs, max_length=self.hparams.max_seq_length, padding=\"longest\", truncation=True\n",
+ " )\n",
+ " # Rename label to labels to make it easier to pass to model forward\n",
+ " features[\"labels\"] = example_batch[\"label\"]\n",
+ " return features"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "36ebe774",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-04T01:00:36.929207Z",
+ "iopub.status.busy": "2023-10-04T01:00:36.928956Z",
+ "iopub.status.idle": "2023-10-04T01:00:36.945376Z",
+ "shell.execute_reply": "2023-10-04T01:00:36.944252Z"
+ },
+ "lines_to_next_cell": 2,
+ "papermill": {
+ "duration": 0.036949,
+ "end_time": "2023-10-04T01:00:36.947067",
+ "exception": false,
+ "start_time": "2023-10-04T01:00:36.910118",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "class RteBoolqModule(L.LightningModule):\n",
+ " \"\"\"A ``LightningModule`` that can be used to fine-tune a foundation model on either the RTE or BoolQ SuperGLUE\n",
+ " tasks using Hugging Face implementations of a given model and the `SuperGLUE Hugging Face dataset.\"\"\"\n",
+ "\n",
+ " def __init__(\n",
+ " self,\n",
+ " model_name_or_path: str,\n",
+ " optimizer_init: Dict[str, Any],\n",
+ " lr_scheduler_init: Dict[str, Any],\n",
+ " model_cfg: Optional[Dict[str, Any]] = None,\n",
+ " task_name: str = DEFAULT_TASK,\n",
+ " experiment_tag: str = \"default\",\n",
+ " ):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " model_name_or_path (str): Path to pretrained model or identifier from https://huggingface.co/models.\n",
+ " optimizer_init (Dict[str, Any]): The desired optimizer configuration.\n",
+ " lr_scheduler_init (Dict[str, Any]): The desired learning rate scheduler config.\n",
+ " model_cfg (Optional[Dict[str, Any]], optional): Defines overrides of the default model config. Defaults to\n",
+ " ``None``.\n",
+ " task_name (str, optional): The SuperGLUE task to execute, one of ``'rte'``, ``'boolq'``. Defaults to \"rte\".\n",
+ " experiment_tag (str, optional): The tag to use for the experiment and tensorboard logs. Defaults to\n",
+ " \"default\".\n",
+ " \"\"\"\n",
+ " super().__init__()\n",
+ " if task_name not in TASK_NUM_LABELS.keys():\n",
+ " rank_zero_warn(f\"Invalid task_name {task_name!r}. Proceeding with the default task: {DEFAULT_TASK!r}\")\n",
+ " task_name = DEFAULT_TASK\n",
+ " self.num_labels = TASK_NUM_LABELS[task_name]\n",
+ " self.model_cfg = model_cfg or {}\n",
+ " conf = AutoConfig.from_pretrained(model_name_or_path, num_labels=self.num_labels, local_files_only=False)\n",
+ " self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=conf)\n",
+ " self.model.config.update(self.model_cfg) # apply model config overrides\n",
+ " self.init_hparams = {\n",
+ " \"optimizer_init\": optimizer_init,\n",
+ " \"lr_scheduler_init\": lr_scheduler_init,\n",
+ " \"model_config\": self.model.config,\n",
+ " \"model_name_or_path\": model_name_or_path,\n",
+ " \"task_name\": task_name,\n",
+ " \"experiment_id\": f\"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{experiment_tag}\",\n",
+ " }\n",
+ " self.save_hyperparameters(self.init_hparams)\n",
+ " self.metric = evaluate.load(\"super_glue\", self.hparams.task_name, experiment_id=self.hparams.experiment_id)\n",
+ " self.no_decay = [\"bias\", \"LayerNorm.weight\"]\n",
+ "\n",
+ " @property\n",
+ " def finetuningscheduler_callback(self) -> fts.FinetuningScheduler:\n",
+ " fts_callback = [c for c in self.trainer.callbacks if isinstance(c, fts.FinetuningScheduler)]\n",
+ " return fts_callback[0] if fts_callback else None\n",
+ "\n",
+ " def forward(self, **inputs):\n",
+ " return self.model(**inputs)\n",
+ "\n",
+ " def training_step(self, batch, batch_idx: int):\n",
+ " loss = self(**batch)[0]\n",
+ " self.log(\"train_loss\", loss, prog_bar=True)\n",
+ " return loss\n",
+ "\n",
+ " def on_train_epoch_end(self):\n",
+ " if self.finetuningscheduler_callback:\n",
+ " self.log(\"finetuning_schedule_depth\", float(self.finetuningscheduler_callback.curr_depth))\n",
+ "\n",
+ " def validation_step(self, batch, batch_idx, dataloader_idx=0):\n",
+ " outputs = self(**batch)\n",
+ " val_loss, logits = outputs[:2]\n",
+ " if self.num_labels >= 1:\n",
+ " preds = torch.argmax(logits, axis=1)\n",
+ " elif self.num_labels == 1:\n",
+ " preds = logits.squeeze()\n",
+ " labels = batch[\"labels\"]\n",
+ " self.log(\"val_loss\", val_loss, prog_bar=True)\n",
+ " metric_dict = self.metric.compute(predictions=preds, references=labels)\n",
+ " self.log_dict(metric_dict, prog_bar=True)\n",
+ "\n",
+ " def configure_optimizers(self):\n",
+ " # With FTS >= 2.0, ``FinetuningScheduler`` simplifies initial optimizer configuration by ensuring the optimizer\n",
+ " # configured here will optimize the parameters (and only those parameters) scheduled to be optimized in phase 0\n",
+ " # of the current fine-tuning schedule. This auto-configuration can be disabled if desired by setting\n",
+ " # ``enforce_phase0_params`` to ``False``.\n",
+ " optimizer = AdamW(params=self.model.parameters(), **self.hparams.optimizer_init)\n",
+ " scheduler = {\n",
+ " \"scheduler\": CosineAnnealingWarmRestarts(optimizer, **self.hparams.lr_scheduler_init),\n",
+ " \"interval\": \"epoch\",\n",
+ " }\n",
+ " return [optimizer], [scheduler]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5bd6e1dc",
+ "metadata": {
+ "papermill": {
+ "duration": 0.017551,
+ "end_time": "2023-10-04T01:00:36.982270",
+ "exception": false,
+ "start_time": "2023-10-04T01:00:36.964719",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "### Our Training Sessions\n",
+ "\n",
+ "We'll be comparing three different fine-tuning training configurations. Every configuration in this example depends\n",
+ "upon a shared set of defaults, only differing in their respective fine-tuning schedules.\n",
+ "\n",
+ "| Experiment Tag | Training Scenario Description |\n",
+ "|:-----------------:| ---------------------------------------------------------------------- |\n",
+ "| ``fts_explicit`` | Training with a fine-tuning schedule explicitly provided by the user |\n",
+ "| ``nofts_baseline``| A baseline fine-tuning training session (without scheduled fine-tuning) |\n",
+ "| ``fts_implicit`` | Training with an implicitly generated fine-tuning schedule (the default) |\n",
+ "\n",
+ "Let's begin by configuring the ``fts_explicit`` scenario. We'll subsequently run the other two scenarios for\n",
+ "comparison."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "3188c155",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-04T01:00:37.010180Z",
+ "iopub.status.busy": "2023-10-04T01:00:37.009800Z",
+ "iopub.status.idle": "2023-10-04T01:00:37.017080Z",
+ "shell.execute_reply": "2023-10-04T01:00:37.015947Z"
+ },
+ "papermill": {
+ "duration": 0.022944,
+ "end_time": "2023-10-04T01:00:37.018630",
+ "exception": false,
+ "start_time": "2023-10-04T01:00:36.995686",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Let's create a fine-tuning schedule for our model and run an explicitly scheduled fine-tuning training scenario with it\n",
+ "# Please see the [FinetuningScheduler documentation](https://finetuning-scheduler.readthedocs.io/en/stable/index.html) for a full description of the schedule format\n",
+ "\n",
+ "\n",
+ "ft_schedule_yaml = \"\"\"\n",
+ "0:\n",
+ " params:\n",
+ " - model.classifier.bias\n",
+ " - model.classifier.weight\n",
+ " - model.pooler.dense.bias\n",
+ " - model.pooler.dense.weight\n",
+ " - model.deberta.encoder.LayerNorm.bias\n",
+ " - model.deberta.encoder.LayerNorm.weight\n",
+ " - model.deberta.encoder.rel_embeddings.weight\n",
+ " - model.deberta.encoder.layer.{0,11}.(output|attention|intermediate).*\n",
+ "1:\n",
+ " params:\n",
+ " - model.deberta.embeddings.LayerNorm.bias\n",
+ " - model.deberta.embeddings.LayerNorm.weight\n",
+ "2:\n",
+ " params:\n",
+ " - model.deberta.embeddings.word_embeddings.weight\n",
+ "\"\"\"\n",
+ "ft_schedule_name = \"RteBoolqModule_ft_schedule_deberta_base.yaml\"\n",
+ "# Let's write the schedule to a file so we can simulate loading an explicitly defined fine-tuning\n",
+ "# schedule.\n",
+ "with open(ft_schedule_name, \"w\") as f:\n",
+ " f.write(ft_schedule_yaml)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "d32e3c9d",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-04T01:00:37.046576Z",
+ "iopub.status.busy": "2023-10-04T01:00:37.045729Z",
+ "iopub.status.idle": "2023-10-04T01:00:38.872986Z",
+ "shell.execute_reply": "2023-10-04T01:00:38.871591Z"
+ },
+ "papermill": {
+ "duration": 1.843878,
+ "end_time": "2023-10-04T01:00:38.875595",
+ "exception": false,
+ "start_time": "2023-10-04T01:00:37.031717",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Global seed set to 42\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "7c6bad07fce443efb64239924c3c67ca",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading (…)okenizer_config.json: 0%| | 0.00/52.0 [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "11fa1564b8fa4190936effa34750e802",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading (…)lve/main/config.json: 0%| | 0.00/579 [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "fb912cd28b2a4cc6b6b40f6d221bd168",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading spm.model: 0%| | 0.00/2.46M [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "datasets.logging.disable_progress_bar()\n",
+ "L.seed_everything(42)\n",
+ "dm = RteBoolqDataModule(model_name_or_path=\"microsoft/deberta-v3-base\", tokenizers_parallelism=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "21a8166d",
+ "metadata": {
+ "papermill": {
+ "duration": 0.025588,
+ "end_time": "2023-10-04T01:00:38.930823",
+ "exception": false,
+ "start_time": "2023-10-04T01:00:38.905235",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "### Optimizer Configuration\n",
+ "\n",
+ "Though other optimizers can arguably yield some marginal advantage contingent on the context,\n",
+ "the Adam optimizer (and the [AdamW version](https://pytorch.org/docs/stable/_modules/torch/optim/adamw.html#AdamW) which\n",
+ "implements decoupled weight decay) remains robust to hyperparameter choices and is commonly used for fine-tuning\n",
+ "foundation language models. See (Sivaprasad et al., 2020) and (Mosbach, Andriushchenko & Klakow, 2020) for theoretical and systematic empirical justifications of Adam and its use in fine-tuning\n",
+ "large transformer-based language models. The values used here have some justification\n",
+ "in the referenced literature but have been largely empirically determined and while a good\n",
+ "starting point could be could be further tuned.\n",
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "4e3732b0",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-04T01:00:38.967519Z",
+ "iopub.status.busy": "2023-10-04T01:00:38.966692Z",
+ "iopub.status.idle": "2023-10-04T01:00:38.972096Z",
+ "shell.execute_reply": "2023-10-04T01:00:38.971224Z"
+ },
+ "papermill": {
+ "duration": 0.023705,
+ "end_time": "2023-10-04T01:00:38.973452",
+ "exception": false,
+ "start_time": "2023-10-04T01:00:38.949747",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "optimizer_init = {\"weight_decay\": 1e-05, \"eps\": 1e-07, \"lr\": 1e-05}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "225dfe6d",
+ "metadata": {
+ "lines_to_next_cell": 2,
+ "papermill": {
+ "duration": 0.013358,
+ "end_time": "2023-10-04T01:00:39.000165",
+ "exception": false,
+ "start_time": "2023-10-04T01:00:38.986807",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "### LR Scheduler Configuration\n",
+ "\n",
+ "The [CosineAnnealingWarmRestarts scheduler](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingWarmRestarts.html?highlight=cosineannealingwarm#torch.optim.lr_scheduler.CosineAnnealingWarmRestarts) nicely fits with our iterative fine-tuning since it does not depend upon a global max_epoch\n",
+ "value. The importance of initial warmup is reduced due to the innate warmup effect of Adam bias correction [5]\n",
+ "and the gradual thawing we are performing. Note that commonly used LR schedulers that depend on providing\n",
+ "max_iterations/epochs (e.g. the\n",
+ "[CosineWarmupScheduler](https://github.com/Lightning-AI/tutorials/blob/0c325829101d5a6ebf32ed99bbf5b09badf04a59/course_UvA-DL/05-transformers-and-MH-attention/Transformers_MHAttention.py#L688)\n",
+ "used in other pytorch-lightning tutorials) also work with FinetuningScheduler. Though the LR scheduler is theoretically\n",
+ "justified (Loshchilov & Hutter, 2016), the particular values provided here are primarily empircally driven.\n",
+ "\n",
+ "[FinetuningScheduler](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler) also supports both optimizer and LR scheduler\n",
+ "reinitialization in explicit and implicit finetuning schedule modes. See the advanced usage documentation ([LR scheduler reinitialization](https://finetuning-scheduler.readthedocs.io/en/stable/advanced/lr_scheduler_reinitialization.html), [optimizer reinitialization](https://finetuning-scheduler.readthedocs.io/en/stable/advanced/optimizer_reinitialization.html)) for explanations and demonstration of the extension's support for more complex requirements.\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "72ba258b",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-04T01:00:39.028023Z",
+ "iopub.status.busy": "2023-10-04T01:00:39.027622Z",
+ "iopub.status.idle": "2023-10-04T01:00:39.032984Z",
+ "shell.execute_reply": "2023-10-04T01:00:39.031883Z"
+ },
+ "papermill": {
+ "duration": 0.020776,
+ "end_time": "2023-10-04T01:00:39.034353",
+ "exception": false,
+ "start_time": "2023-10-04T01:00:39.013577",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "lr_scheduler_init = {\"T_0\": 1, \"T_mult\": 2, \"eta_min\": 1e-07}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "730e4cfe",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-04T01:00:39.063256Z",
+ "iopub.status.busy": "2023-10-04T01:00:39.062452Z",
+ "iopub.status.idle": "2023-10-04T01:00:49.302267Z",
+ "shell.execute_reply": "2023-10-04T01:00:49.301385Z"
+ },
+ "papermill": {
+ "duration": 10.256623,
+ "end_time": "2023-10-04T01:00:49.304522",
+ "exception": false,
+ "start_time": "2023-10-04T01:00:39.047899",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "473160f448d9410d83df085bd4afb760",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading pytorch_model.bin: 0%| | 0.00/371M [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "714e6201592d499e8c72619afc9b1a1a",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading builder script: 0%| | 0.00/9.64k [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "12246c0087114d80bbc4c6e2cd18b9f2",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading extra modules: 0%| | 0.00/3.72k [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Load our lightning module...\n",
+ "lightning_module_kwargs = {\n",
+ " \"model_name_or_path\": \"microsoft/deberta-v3-base\",\n",
+ " \"optimizer_init\": optimizer_init,\n",
+ " \"lr_scheduler_init\": lr_scheduler_init,\n",
+ "}\n",
+ "model = RteBoolqModule(**lightning_module_kwargs, experiment_tag=\"fts_explicit\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "26cb3d16",
+ "metadata": {
+ "papermill": {
+ "duration": 0.022265,
+ "end_time": "2023-10-04T01:00:49.358272",
+ "exception": false,
+ "start_time": "2023-10-04T01:00:49.336007",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "### Callback Configuration\n",
+ "\n",
+ "The only callback required to invoke the [FinetuningScheduler](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler) is the [FinetuningScheduler](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler) callback itself.\n",
+ "Default versions of [FTSCheckpoint](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts_supporters.html#finetuning_scheduler.fts_supporters.FTSCheckpoint) and [FTSEarlyStopping](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts_supporters.html#finetuning_scheduler.fts_supporters.FTSEarlyStopping)\n",
+ "(if not specifying ``epoch_only_transitions``) will be included ([as discussed above](#Basic-Usage)) if not provided\n",
+ "in the callbacks list. For demonstration purposes I'm including example configurations of all three callbacks below."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "271963d3",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-04T01:00:49.393036Z",
+ "iopub.status.busy": "2023-10-04T01:00:49.392599Z",
+ "iopub.status.idle": "2023-10-04T01:00:49.402246Z",
+ "shell.execute_reply": "2023-10-04T01:00:49.401406Z"
+ },
+ "papermill": {
+ "duration": 0.026823,
+ "end_time": "2023-10-04T01:00:49.403795",
+ "exception": false,
+ "start_time": "2023-10-04T01:00:49.376972",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# let's save our callback configurations for the explicit scenario since we'll be reusing the same\n",
+ "# configurations for the implicit and nofts_baseline scenarios (except the config for the\n",
+ "# FinetuningScheduler callback itself of course in the case of nofts_baseline)\n",
+ "earlystopping_kwargs = {\"monitor\": \"val_loss\", \"min_delta\": 0.001, \"patience\": 2}\n",
+ "checkpoint_kwargs = {\"monitor\": \"val_loss\", \"save_top_k\": 1}\n",
+ "fts_kwargs = {\"max_depth\": 1}\n",
+ "callbacks = [\n",
+ " fts.FinetuningScheduler(ft_schedule=ft_schedule_name, **fts_kwargs),\n",
+ " fts.FTSEarlyStopping(**earlystopping_kwargs),\n",
+ " fts.FTSCheckpoint(**checkpoint_kwargs),\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "14dffa1f",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-04T01:00:49.433384Z",
+ "iopub.status.busy": "2023-10-04T01:00:49.432533Z",
+ "iopub.status.idle": "2023-10-04T01:00:49.437803Z",
+ "shell.execute_reply": "2023-10-04T01:00:49.436991Z"
+ },
+ "papermill": {
+ "duration": 0.021485,
+ "end_time": "2023-10-04T01:00:49.439146",
+ "exception": false,
+ "start_time": "2023-10-04T01:00:49.417661",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "logger = TensorBoardLogger(\"lightning_logs\", name=\"fts_explicit\")\n",
+ "# optionally start tensorboard and monitor progress graphically while viewing multi-phase fine-tuning specific training\n",
+ "# logs in the cell output below by uncommenting the next 2 lines\n",
+ "# %load_ext tensorboard\n",
+ "# %tensorboard --logdir lightning_logs\n",
+ "# disable progress bar by default to focus on multi-phase training logs. Set to True to re-enable if desired\n",
+ "enable_progress_bar = False"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "71e47603",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-04T01:00:49.468452Z",
+ "iopub.status.busy": "2023-10-04T01:00:49.468088Z",
+ "iopub.status.idle": "2023-10-04T01:02:20.778752Z",
+ "shell.execute_reply": "2023-10-04T01:02:20.777585Z"
+ },
+ "papermill": {
+ "duration": 91.328395,
+ "end_time": "2023-10-04T01:02:20.781461",
+ "exception": false,
+ "start_time": "2023-10-04T01:00:49.453066",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using 16bit Automatic Mixed Precision (AMP)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "GPU available: True (cuda), used: True\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "TPU available: False, using: 0 TPU cores\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "IPU available: False, using: 0 IPUs\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "HPU available: False, using: 0 HPUs\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Note given the computation associated w/ the multiple phases of fine-tuning demonstrated, this notebook is best used with an accelerator\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Missing logger folder: lightning_logs/fts_explicit\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "fine-tuning schedule dumped to lightning_logs/fts_explicit/version_0/RteBoolqModule_ft_schedule.yaml.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [4,5]\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "After executing the provided `configure_optimizers` method, the optimizer state differs from the configuration FinetuningScheduler expected at the beginning of scheduled fine-tuning (phase 0).\n",
+ "Since `enforce_phase0_params` is currently set to `True` (the default), FinetuningScheduler has reconfigured the optimizer to optimize the parameters (and only those parameters) scheduled to be optimized in phase 0 of the current fine-tuning schedule.\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ " | Name | Type | Params\n",
+ "-------------------------------------------------------------\n",
+ "0 | model | DebertaV2ForSequenceClassification | 184 M \n",
+ "-------------------------------------------------------------\n",
+ "86.0 M Trainable params\n",
+ "98.4 M Non-trainable params\n",
+ "184 M Total params\n",
+ "737.695 Total estimated model params size (MB)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Restoring states from the checkpoint path at lightning_logs/fts_explicit/version_0/checkpoints/epoch=1-step=312.ckpt\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Restored all states from the checkpoint at lightning_logs/fts_explicit/version_0/checkpoints/epoch=1-step=312.ckpt\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Multi-phase fine-tuned training continuing at level 1.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Given the current configuration of `max_depth` (1), this training session will now end when the early stopping conditions are met or `max_epochs` (100) is reached.\n"
+ ]
+ }
+ ],
+ "source": [
+ "\n",
+ "\n",
+ "def train() -> None:\n",
+ " trainer = L.Trainer(\n",
+ " enable_progress_bar=enable_progress_bar,\n",
+ " max_epochs=100,\n",
+ " precision=\"16-mixed\",\n",
+ " accelerator=\"auto\",\n",
+ " devices=1,\n",
+ " callbacks=callbacks,\n",
+ " logger=logger,\n",
+ " )\n",
+ " trainer.fit(model, datamodule=dm)\n",
+ "\n",
+ "\n",
+ "print(\n",
+ " \"Note given the computation associated w/ the multiple phases of fine-tuning demonstrated, this notebook is best used with an accelerator\"\n",
+ ")\n",
+ "train()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cf8a19a2",
+ "metadata": {
+ "lines_to_next_cell": 2,
+ "papermill": {
+ "duration": 0.027017,
+ "end_time": "2023-10-04T01:02:20.842272",
+ "exception": false,
+ "start_time": "2023-10-04T01:02:20.815255",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "### Running the Baseline and Implicit Fine-Tuning Scenarios\n",
+ "\n",
+ "Let's now compare our ``nofts_baseline`` and ``fts_implicit`` scenarios with the ``fts_explicit`` one we just ran.\n",
+ "\n",
+ "We'll need to update our callbacks list, using the core PL ``EarlyStopping`` and ``ModelCheckpoint`` callbacks for the\n",
+ "``nofts_baseline`` (which operate identically to their FTS analogs apart from the recursive training support).\n",
+ "For both core Lightning and user-registered callbacks, we can define our callbacks using a dictionary as we do\n",
+ "with the LightningCLI. This allows us to avoid managing imports and support more complex configuration separated from\n",
+ "code.\n",
+ "\n",
+ "Note that we'll be using identical callback configurations to the ``fts_explicit`` scenario. Keeping [max_depth](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html?highlight=max_depth#finetuning_scheduler.fts.FinetuningScheduler.params.max_depth) for\n",
+ "the implicit schedule will limit fine-tuning to just the last 4 parameters of the model, which is only a small fraction\n",
+ "of the parameters you'd want to tune for maximum performance. Since the implicit schedule is quite computationally\n",
+ "intensive and most useful for exploring model behavior, leaving [max_depth](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html?highlight=max_depth#finetuning_scheduler.fts.FinetuningScheduler.params.max_depth) 1 allows us to demo implicit mode\n",
+ "behavior while keeping the computational cost and runtime of this notebook reasonable. To review how a full implicit\n",
+ "mode run compares to the ``nofts_baseline`` and ``fts_explicit`` scenarios, please see the the following\n",
+ "[tensorboard experiment summary](https://tensorboard.dev/experiment/n7U8XhrzRbmvVzC4SQSpWw/)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "d604905a",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-04T01:02:20.879250Z",
+ "iopub.status.busy": "2023-10-04T01:02:20.879059Z",
+ "iopub.status.idle": "2023-10-04T01:02:20.885929Z",
+ "shell.execute_reply": "2023-10-04T01:02:20.884996Z"
+ },
+ "papermill": {
+ "duration": 0.024311,
+ "end_time": "2023-10-04T01:02:20.887249",
+ "exception": false,
+ "start_time": "2023-10-04T01:02:20.862938",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "nofts_callbacks = [EarlyStopping(**earlystopping_kwargs), ModelCheckpoint(**checkpoint_kwargs)]\n",
+ "fts_implicit_callbacks = [\n",
+ " fts.FinetuningScheduler(**fts_kwargs),\n",
+ " fts.FTSEarlyStopping(**earlystopping_kwargs),\n",
+ " fts.FTSCheckpoint(**checkpoint_kwargs),\n",
+ "]\n",
+ "scenario_callbacks = {\"nofts_baseline\": nofts_callbacks, \"fts_implicit\": fts_implicit_callbacks}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "29ee874b",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-10-04T01:02:20.917777Z",
+ "iopub.status.busy": "2023-10-04T01:02:20.917414Z",
+ "iopub.status.idle": "2023-10-04T01:04:10.217300Z",
+ "shell.execute_reply": "2023-10-04T01:04:10.216227Z"
+ },
+ "papermill": {
+ "duration": 109.31813,
+ "end_time": "2023-10-04T01:04:10.220002",
+ "exception": false,
+ "start_time": "2023-10-04T01:02:20.901872",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using 16bit Automatic Mixed Precision (AMP)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "GPU available: True (cuda), used: True\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "TPU available: False, using: 0 TPU cores\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "IPU available: False, using: 0 IPUs\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "HPU available: False, using: 0 HPUs\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Beginning training the 'nofts_baseline' scenario\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Missing logger folder: lightning_logs/nofts_baseline\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [4,5]\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ " | Name | Type | Params\n",
+ "-------------------------------------------------------------\n",
+ "0 | model | DebertaV2ForSequenceClassification | 184 M \n",
+ "-------------------------------------------------------------\n",
+ "184 M Trainable params\n",
+ "0 Non-trainable params\n",
+ "184 M Total params\n",
+ "737.695 Total estimated model params size (MB)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using 16bit Automatic Mixed Precision (AMP)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "GPU available: True (cuda), used: True\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "TPU available: False, using: 0 TPU cores\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "IPU available: False, using: 0 IPUs\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "HPU available: False, using: 0 HPUs\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Beginning training the 'fts_implicit' scenario\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Missing logger folder: lightning_logs/fts_implicit\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "fine-tuning schedule dumped to lightning_logs/fts_implicit/version_0/RteBoolqModule_ft_schedule.yaml.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Generated default fine-tuning schedule 'lightning_logs/fts_implicit/version_0/RteBoolqModule_ft_schedule.yaml' for iterative fine-tuning\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [4,5]\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "After executing the provided `configure_optimizers` method, the optimizer state differs from the configuration FinetuningScheduler expected at the beginning of scheduled fine-tuning (phase 0).\n",
+ "Since `enforce_phase0_params` is currently set to `True` (the default), FinetuningScheduler has reconfigured the optimizer to optimize the parameters (and only those parameters) scheduled to be optimized in phase 0 of the current fine-tuning schedule.\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ " | Name | Type | Params\n",
+ "-------------------------------------------------------------\n",
+ "0 | model | DebertaV2ForSequenceClassification | 184 M \n",
+ "-------------------------------------------------------------\n",
+ "1.5 K Trainable params\n",
+ "184 M Non-trainable params\n",
+ "184 M Total params\n",
+ "737.695 Total estimated model params size (MB)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Restoring states from the checkpoint path at lightning_logs/fts_implicit/version_0/checkpoints/epoch=0-step=156.ckpt\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Restored all states from the checkpoint at lightning_logs/fts_implicit/version_0/checkpoints/epoch=0-step=156.ckpt\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Multi-phase fine-tuned training continuing at level 1.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Given the current configuration of `max_depth` (1), this training session will now end when the early stopping conditions are met or `max_epochs` (100) is reached.\n"
+ ]
+ }
+ ],
+ "source": [
+ "for scenario_name, scenario_callbacks in scenario_callbacks.items():\n",
+ " model = RteBoolqModule(**lightning_module_kwargs, experiment_tag=scenario_name)\n",
+ " logger = TensorBoardLogger(\"lightning_logs\", name=scenario_name)\n",
+ " callbacks = scenario_callbacks\n",
+ " print(f\"Beginning training the '{scenario_name}' scenario\")\n",
+ " train()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c33eb874",
+ "metadata": {
+ "lines_to_next_cell": 0,
+ "papermill": {
+ "duration": 0.027773,
+ "end_time": "2023-10-04T01:04:10.284669",
+ "exception": false,
+ "start_time": "2023-10-04T01:04:10.256896",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "### Reviewing the Training Results\n",
+ "\n",
+ "See the [tensorboard experiment summaries](https://tensorboard.dev/experiment/n7U8XhrzRbmvVzC4SQSpWw/) to get a sense\n",
+ "of the relative computational and performance tradeoffs associated with these [FinetuningScheduler](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler) configurations.\n",
+ "The summary compares a full ``fts_implicit`` execution to ``fts_explicit`` and ``nofts_baseline`` scenarios using DDP\n",
+ "training with 2 GPUs. The full logs/schedules for all three scenarios are available\n",
+ "[here](https://drive.google.com/file/d/1LrUcisRLHeJgh_BDOOD_GUBPp5iHAkoR/view?usp=sharing) and the checkpoints\n",
+ "produced in the scenarios [here](https://drive.google.com/file/d/1t7myBgcqcZ9ax_IT9QVk-vFH_l_o5UXB/view?usp=sharing)\n",
+ "(caution, ~3.5GB).\n",
+ "\n",
+ "[{height=\"315px\" width=\"492px\"}](https://tensorboard.dev/experiment/n7U8XhrzRbmvVzC4SQSpWw/#scalars&_smoothingWeight=0&runSelectionState=eyJmdHNfZXhwbGljaXQiOnRydWUsIm5vZnRzX2Jhc2VsaW5lIjpmYWxzZSwiZnRzX2ltcGxpY2l0IjpmYWxzZX0%3D)\n",
+ "[{height=\"316px\" width=\"505px\"}](https://tensorboard.dev/experiment/n7U8XhrzRbmvVzC4SQSpWw/#scalars&_smoothingWeight=0&runSelectionState=eyJmdHNfZXhwbGljaXQiOmZhbHNlLCJub2Z0c19iYXNlbGluZSI6dHJ1ZSwiZnRzX2ltcGxpY2l0IjpmYWxzZX0%3D)\n",
+ "\n",
+ "Note that given execution context differences, there could be a modest variation in performance from the tensorboard summaries generated by this notebook.\n",
+ "\n",
+ "[FinetuningScheduler](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler) expands the space of possible fine-tuning schedules and the composition of more sophisticated schedules can\n",
+ "yield marginal fine-tuning performance gains. That stated, it should be emphasized the primary utility of [FinetuningScheduler](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler) is to grant\n",
+ "greater fine-tuning flexibility for model exploration in research. For example, glancing at DeBERTa-v3's implicit training\n",
+ "run, a critical tuning transition point is immediately apparent:\n",
+ "\n",
+ "[{height=\"272px\" width=\"494px\"}](https://tensorboard.dev/experiment/n7U8XhrzRbmvVzC4SQSpWw/#scalars&_smoothingWeight=0&runSelectionState=eyJmdHNfZXhwbGljaXQiOmZhbHNlLCJub2Z0c19iYXNlbGluZSI6ZmFsc2UsImZ0c19pbXBsaWNpdCI6dHJ1ZX0%3D)\n",
+ "\n",
+ "Our `val_loss` begins a precipitous decline at step 3119 which corresponds to phase 17 in the schedule. Referring to our\n",
+ "schedule, in phase 17 we're beginning tuning the attention parameters of our 10th encoder layer (of 11). Interesting!\n",
+ "Though beyond the scope of this tutorial, it might be worth investigating these dynamics further and\n",
+ "[FinetuningScheduler](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler) allows one to do just that quite easily.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0d7af80a",
+ "metadata": {
+ "lines_to_next_cell": 0,
+ "papermill": {
+ "duration": 0.015729,
+ "end_time": "2023-10-04T01:04:10.321609",
+ "exception": false,
+ "start_time": "2023-10-04T01:04:10.305880",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "\n",
+ "Note that though this example is intended to capture a common usage scenario, substantial variation is expected\n",
+ "among use cases and models.\n",
+ "In summary, [FinetuningScheduler](https://finetuning-scheduler.readthedocs.io/en/stable/api/finetuning_scheduler.fts.html#finetuning_scheduler.fts.FinetuningScheduler) provides increased fine-tuning flexibility that can be useful in a variety of\n",
+ "contexts from exploring model tuning behavior to maximizing performance."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "78b6eb63",
+ "metadata": {
+ "papermill": {
+ "duration": 0.015777,
+ "end_time": "2023-10-04T01:04:10.353220",
+ "exception": false,
+ "start_time": "2023-10-04T01:04:10.337443",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Footnotes\n",
+ "\n",
+ "- [Howard, J., & Ruder, S. (2018)](https://arxiv.org/pdf/1801.06146.pdf). Fine-tuned Language\n",
+ " Models for Text Classification. ArXiv, abs/1801.06146. [↩](#Scheduled-Fine-Tuning-with-the-Fine-Tuning-Scheduler-Extension)\n",
+ "- [Chronopoulou, A., Baziotis, C., & Potamianos, A. (2019)](https://arxiv.org/pdf/1902.10547.pdf).\n",
+ " An embarrassingly simple approach for transfer learning from pretrained language models. arXiv\n",
+ " preprint arXiv:1902.10547. [↩](#Scheduled-Fine-Tuning-with-the-Fine-Tuning-Scheduler-Extension)\n",
+ "- [Peters, M. E., Ruder, S., & Smith, N. A. (2019)](https://arxiv.org/pdf/1903.05987.pdf). To tune or not to\n",
+ " tune? adapting pretrained representations to diverse tasks. arXiv preprint arXiv:1903.05987. [↩](#Scheduled-Fine-Tuning-with-the-Fine-Tuning-Scheduler-Extension)\n",
+ "- [Sivaprasad, P. T., Mai, F., Vogels, T., Jaggi, M., & Fleuret, F. (2020)](https://arxiv.org/pdf/1910.11758.pdf).\n",
+ " Optimizer benchmarking needs to account for hyperparameter tuning. In International Conference on Machine Learning\n",
+ "(pp. 9036-9045). PMLR. [↩](#Optimizer-Configuration)\n",
+ "- [Mosbach, M., Andriushchenko, M., & Klakow, D. (2020)](https://arxiv.org/pdf/2006.04884.pdf). On the stability of\n",
+ "fine-tuning bert: Misconceptions, explanations, and strong baselines. arXiv preprint arXiv:2006.04884. [↩](#Optimizer-Configuration)\n",
+ "- [Loshchilov, I., & Hutter, F. (2016)](https://arxiv.org/pdf/1608.03983.pdf). Sgdr: Stochastic gradient descent with\n",
+ "warm restarts. arXiv preprint arXiv:1608.03983. [↩](#LR-Scheduler-Configuration)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d6053aea",
+ "metadata": {
+ "papermill": {
+ "duration": 0.015895,
+ "end_time": "2023-10-04T01:04:10.385214",
+ "exception": false,
+ "start_time": "2023-10-04T01:04:10.369319",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Congratulations - Time to Join the Community!\n",
+ "\n",
+ "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning\n",
+ "movement, you can do so in the following ways!\n",
+ "\n",
+ "### Star [Lightning](https://github.com/Lightning-AI/lightning) on GitHub\n",
+ "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool\n",
+ "tools we're building.\n",
+ "\n",
+ "### Join our [Slack](https://www.pytorchlightning.ai/community)!\n",
+ "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself\n",
+ "and share your interests in `#general` channel\n",
+ "\n",
+ "\n",
+ "### Contributions !\n",
+ "The best way to contribute to our community is to become a code contributor! At any time you can go to\n",
+ "[Lightning](https://github.com/Lightning-AI/lightning) or [Bolt](https://github.com/Lightning-AI/lightning-bolts)\n",
+ "GitHub Issues page and filter for \"good first issue\".\n",
+ "\n",
+ "* [Lightning good first issue](https://github.com/Lightning-AI/lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
+ "* [Bolt good first issue](https://github.com/Lightning-AI/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
+ "* You can also contribute your own notebooks with useful examples !\n",
+ "\n",
+ "### Great thanks from the entire Pytorch Lightning Team for your interest !\n",
+ "\n",
+ "[{height=\"60px\" width=\"240px\"}](https://pytorchlightning.ai)"
+ ]
+ }
+ ],
+ "metadata": {
+ "jupytext": {
+ "cell_metadata_filter": "colab_type,colab,id,-all",
+ "formats": "ipynb,py:percent",
+ "main_language": "python"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ },
+ "papermill": {
+ "default_parameters": {},
+ "duration": 227.374811,
+ "end_time": "2023-10-04T01:04:12.669591",
+ "environment_variables": {},
+ "exception": null,
+ "input_path": "lightning_examples/finetuning-scheduler/finetuning-scheduler.ipynb",
+ "output_path": ".notebooks/lightning_examples/finetuning-scheduler.ipynb",
+ "parameters": {},
+ "start_time": "2023-10-04T01:00:25.294780",
+ "version": "2.4.0"
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "state": {
+ "0045c9e4a65349f4b34b1d3328af2030": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "background": null,
+ "description_width": "",
+ "font_size": null,
+ "text_color": null
+ }
+ },
+ "04559d1a68184a5bb77b963279b18964": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_4ad991ecf5bc4e5a89ca6cd5e38f0702",
+ "placeholder": "",
+ "style": "IPY_MODEL_a7e70b518cb04a4c8a02b5162e725420",
+ "tabbable": null,
+ "tooltip": null,
+ "value": "Downloading spm.model: 100%"
+ }
+ },
+ "095b6a1381b64332bc8a4c226327bfba": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "background": null,
+ "description_width": "",
+ "font_size": null,
+ "text_color": null
+ }
+ },
+ "0bf65bc277464cd29651a43f7eadbe32": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_4044f3641689494280f460d5c0c50958",
+ "placeholder": "",
+ "style": "IPY_MODEL_0045c9e4a65349f4b34b1d3328af2030",
+ "tabbable": null,
+ "tooltip": null,
+ "value": "Downloading (…)lve/main/config.json: 100%"
+ }
+ },
+ "11fa1564b8fa4190936effa34750e802": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_0bf65bc277464cd29651a43f7eadbe32",
+ "IPY_MODEL_92fd2de9fe234b5b87c21c687cae4dcf",
+ "IPY_MODEL_58d941f718264bbdb144ed2b756e125e"
+ ],
+ "layout": "IPY_MODEL_7863820e3d2a42638d05a17ce33d9e23",
+ "tabbable": null,
+ "tooltip": null
+ }
+ },
+ "12246c0087114d80bbc4c6e2cd18b9f2": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_ade14e595c2a40d58069c3ba20f6704e",
+ "IPY_MODEL_73ff1711ddcd4a1b971a05087255e17b",
+ "IPY_MODEL_1ea95007d81e45aabafe10406d27f494"
+ ],
+ "layout": "IPY_MODEL_5bd9e3c200d345ea920cb30cd062b5f7",
+ "tabbable": null,
+ "tooltip": null
+ }
+ },
+ "1b96320a9b864da58fbce65c67beff12": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "1ea95007d81e45aabafe10406d27f494": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_eef49f2317a54c1ea26e1012d2b80c53",
+ "placeholder": "",
+ "style": "IPY_MODEL_fb71138dced94a259755cedf48c700ee",
+ "tabbable": null,
+ "tooltip": null,
+ "value": " 3.72k/3.72k [00:00<00:00, 607kB/s]"
+ }
+ },
+ "300ab00550724d42b403f3086097f349": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_9761eb5388384cddbcb20b28527f7019",
+ "max": 52.0,
+ "min": 0.0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_39f571a38ad045c8940f3dafedd43cd8",
+ "tabbable": null,
+ "tooltip": null,
+ "value": 52.0
+ }
+ },
+ "302bc617b537419f968fc3a93543d1b8": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_b2557bf0946e4b648e5c60ca49fbd09f",
+ "placeholder": "",
+ "style": "IPY_MODEL_32cb3c775dc5423fb0e908d0302b761d",
+ "tabbable": null,
+ "tooltip": null,
+ "value": "Downloading builder script: 100%"
+ }
+ },
+ "32cb3c775dc5423fb0e908d0302b761d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "background": null,
+ "description_width": "",
+ "font_size": null,
+ "text_color": null
+ }
+ },
+ "361fbe61b3d04a8ba9061f147cdb9c65": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "39f571a38ad045c8940f3dafedd43cd8": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "4044f3641689494280f460d5c0c50958": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "473160f448d9410d83df085bd4afb760": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_cd2097905c974f6bb7701750424dda9b",
+ "IPY_MODEL_cb0f5c72f4aa4005b98e79e9ce7ed7aa",
+ "IPY_MODEL_d6f9e0f2911c408d8db81271d3eba168"
+ ],
+ "layout": "IPY_MODEL_e4c840e6b8e54a4e81aca7d66b708ffc",
+ "tabbable": null,
+ "tooltip": null
+ }
+ },
+ "4ad991ecf5bc4e5a89ca6cd5e38f0702": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "5000b81d9bc649d58ff3fbb098e05ba6": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "51e56e6d98334f97ba5ce0ac4e8b022f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "525309b45da74927a28f00c6e7c585a2": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "58d941f718264bbdb144ed2b756e125e": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_72754e179be04771bce8290dc5cf69b5",
+ "placeholder": "",
+ "style": "IPY_MODEL_bd0cfbcf38144ab4932009bd001db2b4",
+ "tabbable": null,
+ "tooltip": null,
+ "value": " 579/579 [00:00<00:00, 50.1kB/s]"
+ }
+ },
+ "5bd9e3c200d345ea920cb30cd062b5f7": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "60af281fbe3a4382827d4589b4c521dd": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "background": null,
+ "description_width": "",
+ "font_size": null,
+ "text_color": null
+ }
+ },
+ "6532ec7c0a6646c892c40a969c1c4281": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "691e6cd7d33b487e9efe0dc09b3afe31": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "714e6201592d499e8c72619afc9b1a1a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_302bc617b537419f968fc3a93543d1b8",
+ "IPY_MODEL_c7b39992515c44edb12e395f5fc825a5",
+ "IPY_MODEL_bd339db0b1944741aad01b778cada398"
+ ],
+ "layout": "IPY_MODEL_361fbe61b3d04a8ba9061f147cdb9c65",
+ "tabbable": null,
+ "tooltip": null
+ }
+ },
+ "72754e179be04771bce8290dc5cf69b5": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "73ff1711ddcd4a1b971a05087255e17b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_ecbe440d001840ffb30a19eb89c19224",
+ "max": 3724.0,
+ "min": 0.0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_9722bae57230451b8e7a2d75c7ecc738",
+ "tabbable": null,
+ "tooltip": null,
+ "value": 3724.0
+ }
+ },
+ "7863820e3d2a42638d05a17ce33d9e23": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "7c6bad07fce443efb64239924c3c67ca": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_e30bed358a524ff9b847273cfbb02d9f",
+ "IPY_MODEL_300ab00550724d42b403f3086097f349",
+ "IPY_MODEL_a25b673d34a84db39d0d64061265eef9"
+ ],
+ "layout": "IPY_MODEL_d649794571c94841b27791adad7a1c27",
+ "tabbable": null,
+ "tooltip": null
+ }
+ },
+ "803da9ca8e614686b24aa71b2413ddc1": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "8043050af1354c9e8904b82579011fd0": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "8ba060a6a73440ceac5860a3f512e192": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "8be70ce7e278417ea036cf7a802a9e69": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "92fd2de9fe234b5b87c21c687cae4dcf": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_525309b45da74927a28f00c6e7c585a2",
+ "max": 579.0,
+ "min": 0.0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_bf6d47d177d945dfa01145b3780e1680",
+ "tabbable": null,
+ "tooltip": null,
+ "value": 579.0
+ }
+ },
+ "9671ec27d8aa45b594f45743e529e204": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_a748911351e243ea8ae3c4be9067ed29",
+ "max": 2464616.0,
+ "min": 0.0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_a5fe4b43bd6a416f959db3b18a14e75a",
+ "tabbable": null,
+ "tooltip": null,
+ "value": 2464616.0
+ }
+ },
+ "9722bae57230451b8e7a2d75c7ecc738": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "975f9d8c8483412388be188fad051077": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "background": null,
+ "description_width": "",
+ "font_size": null,
+ "text_color": null
+ }
+ },
+ "9761eb5388384cddbcb20b28527f7019": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "a1a56537c1274537a173c82ce76fb004": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "a25b673d34a84db39d0d64061265eef9": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_e38ee12e1d644e61b07d34c2975f75bd",
+ "placeholder": "",
+ "style": "IPY_MODEL_dd1a41dbc6fb4897bb7f51f50e787b52",
+ "tabbable": null,
+ "tooltip": null,
+ "value": " 52.0/52.0 [00:00<00:00, 4.15kB/s]"
+ }
+ },
+ "a5fe4b43bd6a416f959db3b18a14e75a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "a748911351e243ea8ae3c4be9067ed29": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "a7e70b518cb04a4c8a02b5162e725420": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "background": null,
+ "description_width": "",
+ "font_size": null,
+ "text_color": null
+ }
+ },
+ "ade14e595c2a40d58069c3ba20f6704e": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_691e6cd7d33b487e9efe0dc09b3afe31",
+ "placeholder": "",
+ "style": "IPY_MODEL_975f9d8c8483412388be188fad051077",
+ "tabbable": null,
+ "tooltip": null,
+ "value": "Downloading extra modules: 100%"
+ }
+ },
+ "b2557bf0946e4b648e5c60ca49fbd09f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "b78b2619981447d1bde12251af48aab3": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "background": null,
+ "description_width": "",
+ "font_size": null,
+ "text_color": null
+ }
+ },
+ "bd0cfbcf38144ab4932009bd001db2b4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "background": null,
+ "description_width": "",
+ "font_size": null,
+ "text_color": null
+ }
+ },
+ "bd339db0b1944741aad01b778cada398": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_51e56e6d98334f97ba5ce0ac4e8b022f",
+ "placeholder": "",
+ "style": "IPY_MODEL_60af281fbe3a4382827d4589b4c521dd",
+ "tabbable": null,
+ "tooltip": null,
+ "value": " 9.64k/9.64k [00:00<00:00, 1.29MB/s]"
+ }
+ },
+ "bf6d47d177d945dfa01145b3780e1680": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "c7b39992515c44edb12e395f5fc825a5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_8be70ce7e278417ea036cf7a802a9e69",
+ "max": 9644.0,
+ "min": 0.0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_a1a56537c1274537a173c82ce76fb004",
+ "tabbable": null,
+ "tooltip": null,
+ "value": 9644.0
+ }
+ },
+ "c7b3fa42d3a1443e9df3c08e90c68ae0": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "background": null,
+ "description_width": "",
+ "font_size": null,
+ "text_color": null
+ }
+ },
+ "cb0f5c72f4aa4005b98e79e9ce7ed7aa": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_6532ec7c0a6646c892c40a969c1c4281",
+ "max": 371146213.0,
+ "min": 0.0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_1b96320a9b864da58fbce65c67beff12",
+ "tabbable": null,
+ "tooltip": null,
+ "value": 371146213.0
+ }
+ },
+ "cd2097905c974f6bb7701750424dda9b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_8ba060a6a73440ceac5860a3f512e192",
+ "placeholder": "",
+ "style": "IPY_MODEL_095b6a1381b64332bc8a4c226327bfba",
+ "tabbable": null,
+ "tooltip": null,
+ "value": "Downloading pytorch_model.bin: 100%"
+ }
+ },
+ "d649794571c94841b27791adad7a1c27": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "d6f9e0f2911c408d8db81271d3eba168": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_8043050af1354c9e8904b82579011fd0",
+ "placeholder": "",
+ "style": "IPY_MODEL_b78b2619981447d1bde12251af48aab3",
+ "tabbable": null,
+ "tooltip": null,
+ "value": " 371M/371M [00:07<00:00, 48.7MB/s]"
+ }
+ },
+ "d7138c99622f4e608247e641553235db": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "dd1a41dbc6fb4897bb7f51f50e787b52": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "background": null,
+ "description_width": "",
+ "font_size": null,
+ "text_color": null
+ }
+ },
+ "e30bed358a524ff9b847273cfbb02d9f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_5000b81d9bc649d58ff3fbb098e05ba6",
+ "placeholder": "",
+ "style": "IPY_MODEL_c7b3fa42d3a1443e9df3c08e90c68ae0",
+ "tabbable": null,
+ "tooltip": null,
+ "value": "Downloading (…)okenizer_config.json: 100%"
+ }
+ },
+ "e38ee12e1d644e61b07d34c2975f75bd": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e4c840e6b8e54a4e81aca7d66b708ffc": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "ecbe440d001840ffb30a19eb89c19224": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "edeb549f9e9447f8bb31d9c7838406fe": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "background": null,
+ "description_width": "",
+ "font_size": null,
+ "text_color": null
+ }
+ },
+ "eef49f2317a54c1ea26e1012d2b80c53": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "fb71138dced94a259755cedf48c700ee": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "background": null,
+ "description_width": "",
+ "font_size": null,
+ "text_color": null
+ }
+ },
+ "fb912cd28b2a4cc6b6b40f6d221bd168": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_04559d1a68184a5bb77b963279b18964",
+ "IPY_MODEL_9671ec27d8aa45b594f45743e529e204",
+ "IPY_MODEL_feab3d184baa43d59c96ad9d2bb63681"
+ ],
+ "layout": "IPY_MODEL_803da9ca8e614686b24aa71b2413ddc1",
+ "tabbable": null,
+ "tooltip": null
+ }
+ },
+ "feab3d184baa43d59c96ad9d2bb63681": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_d7138c99622f4e608247e641553235db",
+ "placeholder": "",
+ "style": "IPY_MODEL_edeb549f9e9447f8bb31d9c7838406fe",
+ "tabbable": null,
+ "tooltip": null,
+ "value": " 2.46M/2.46M [00:00<00:00, 19.5MB/s]"
+ }
+ }
+ },
+ "version_major": 2,
+ "version_minor": 0
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
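
Aside: the `papermill` block in the notebook metadata above records how this executed copy was produced (papermill 2.4.0, a ~227 s run, with the recorded `input_path`/`output_path`). A minimal sketch of an equivalent invocation — an assumption for illustration, not the repo's actual publishing script:

```python
# Hedged sketch: re-execute the source notebook the way the papermill metadata
# above describes; the paths mirror its input_path/output_path fields verbatim.
import papermill as pm

pm.execute_notebook(
    input_path="lightning_examples/finetuning-scheduler/finetuning-scheduler.ipynb",
    output_path=".notebooks/lightning_examples/finetuning-scheduler.ipynb",
    kernel_name="python3",  # assumed; the kernel name is not recorded in the metadata
)
```
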
diff --git a/.notebooks/lightning_examples/finetuning-scheduler.yaml b/.notebooks/lightning_examples/finetuning-scheduler.yaml
new file mode 100644
index 000000000..b2d8d2e84
--- /dev/null
+++ b/.notebooks/lightning_examples/finetuning-scheduler.yaml
@@ -0,0 +1,42 @@
+title: Fine-Tuning Scheduler
+author: "[Dan Dale](https://github.com/speediedan)"
+created: 2021-11-29
+updated: 2023-04-06
+license: CC BY-SA
+build: 0
+tags:
+ - Fine-Tuning
+description:
+ "This notebook introduces the [Fine-Tuning Scheduler](https://finetuning-scheduler.readthedocs.io/en/stable/index.html)
+ extension
+
+ and demonstrates using it to fine-tune a small foundation model on the
+
+ [RTE](https://huggingface.co/datasets/viewer/?dataset=super_glue&config=rte) task
+ of
+
+ [SuperGLUE](https://super.gluebenchmark.com/) with iterative early-stopping defined
+ according to a user-specified
+
+ schedule. It uses Hugging Face's ``datasets`` and ``transformers`` libraries to
+ retrieve the relevant benchmark data
+
+ and foundation model weights. The required dependencies are installed via the finetuning-scheduler
+ ``[examples]`` extra.
+
+ "
+requirements:
+ - finetuning-scheduler[examples]>=2.0.0
+ - torch>=1.12.1
+accelerator:
+ - GPU
+environment:
+ - torchmetrics==1.2.0
+ - ipython==8.16.1
+ - finetuning-scheduler==2.0.9
+ - pytorch-lightning==2.0.9.post0
+ - setuptools==68.2.2
+ - urllib3==2.0.6
+ - matplotlib==3.8.0
+ - torch==2.0.1+cu118
+published: "2023-10-04T01:04:14.627522"
diff --git a/.notebooks/lightning_examples/mnist-hello-world.yaml b/.notebooks/lightning_examples/mnist-hello-world.yaml
index 8df527d66..7f9c58aa7 100644
--- a/.notebooks/lightning_examples/mnist-hello-world.yaml
+++ b/.notebooks/lightning_examples/mnist-hello-world.yaml
@@ -5,29 +5,29 @@ updated: 2023-05-15
license: CC BY-SA
build: 0
tags:
-- Image
+ - Image
description: In this notebook, we'll go over the basics of Lightning by preparing
models to train on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).
requirements:
-- torchvision
-- torchmetrics >=0.11.0
-- pandas
-- seaborn
-- lightning>=2.0.0
+ - torchvision
+ - torchmetrics >=0.11.0
+ - pandas
+ - seaborn
+ - lightning>=2.0.0
accelerator:
-- CPU
-- GPU
+ - CPU
+ - GPU
environment:
-- pytorch-lightning==1.5.3
-- torchvision==0.15.2
-- torch==2.0.1
-- numpy==1.26.4
-- torchmetrics==1.2.1
-- setuptools==69.0.3
-- urllib3==2.2.2
-- pandas==2.2.2
-- ipython==8.16.1
-- matplotlib==3.8.4
-- lightning==2.3.3
-- seaborn==0.13.2
-published: '2024-07-20T00:10:28.589696'
+ - pytorch-lightning==1.5.3
+ - torchvision==0.15.2
+ - torch==2.0.1
+ - numpy==1.26.4
+ - torchmetrics==1.2.1
+ - setuptools==69.0.3
+ - urllib3==2.2.2
+ - pandas==2.2.2
+ - ipython==8.16.1
+ - matplotlib==3.8.4
+ - lightning==2.3.3
+ - seaborn==0.13.2
+published: "2024-07-20T00:10:28.589696"
diff --git a/.notebooks/lightning_examples/mnist-tpu-training.yaml b/.notebooks/lightning_examples/mnist-tpu-training.yaml
index 0cf78f929..70fce688d 100644
--- a/.notebooks/lightning_examples/mnist-tpu-training.yaml
+++ b/.notebooks/lightning_examples/mnist-tpu-training.yaml
@@ -5,24 +5,24 @@ updated: 2023-05-15
license: CC BY-SA
build: 0
tags:
-- Image
+ - Image
description: In this notebook, we'll train a model on TPUs. Updating one Trainer flag
is all you need for that. The most up-to-date documentation on TPU training can
be found [here](https://lightning.ai/docs/pytorch/stable/accelerators/tpu.html).
requirements:
-- torchvision
-- lightning>=2.0.0
+ - torchvision
+ - lightning>=2.0.0
accelerator:
-- TPU
+ - TPU
environment:
-- setuptools==69.0.3
-- numpy==1.26.4
-- lightning==2.3.3
-- urllib3==2.2.2
-- torch==2.3.1
-- ipython==8.16.1
-- torchmetrics==1.2.1
-- pytorch-lightning==1.5.3
-- matplotlib==3.8.4
-- torchvision==0.18.1
-published: '2024-07-20T00:13:16.860265'
+ - setuptools==69.0.3
+ - numpy==1.26.4
+ - lightning==2.3.3
+ - urllib3==2.2.2
+ - torch==2.3.1
+ - ipython==8.16.1
+ - torchmetrics==1.2.1
+ - pytorch-lightning==1.5.3
+ - matplotlib==3.8.4
+ - torchvision==0.18.1
+published: "2024-07-20T00:13:16.860265"
diff --git a/.notebooks/lightning_examples/reinforce-learning-DQN.yaml b/.notebooks/lightning_examples/reinforce-learning-DQN.yaml
index b122e5d93..e304892e9 100644
--- a/.notebooks/lightning_examples/reinforce-learning-DQN.yaml
+++ b/.notebooks/lightning_examples/reinforce-learning-DQN.yaml
@@ -5,8 +5,8 @@ updated: 2021-12-03
license: CC BY-SA
build: 2
tags:
-- RL
-description: 'Main takeaways:
+ - RL
+description: "Main takeaways:
1. RL has the same flow as previous models we have seen, with a few additions
@@ -17,28 +17,28 @@ description: 'Main takeaways:
3. Each training step has the agent taking an action in the environment
and storing the experience in the IterableDataset
- '
+ "
requirements:
-- gym <0.24
-- pygame
-- pandas
-- seaborn
-- lightning>=2.0.0
+ - gym <0.24
+ - pygame
+ - pandas
+ - seaborn
+ - lightning>=2.0.0
accelerator:
-- CPU
-- GPU
+ - CPU
+ - GPU
environment:
-- lightning==2.3.3
-- pytorch-lightning==1.5.3
-- pandas==2.2.2
-- matplotlib==3.8.4
-- seaborn==0.13.2
-- setuptools==69.0.3
-- torchmetrics==1.2.1
-- numpy==1.26.4
-- torch==2.0.1+cu118
-- pygame==2.6.0
-- gym==0.23.1
-- urllib3==2.2.2
-- ipython==8.16.1
-published: '2024-07-20T00:14:43.650667'
+ - lightning==2.3.3
+ - pytorch-lightning==1.5.3
+ - pandas==2.2.2
+ - matplotlib==3.8.4
+ - seaborn==0.13.2
+ - setuptools==69.0.3
+ - torchmetrics==1.2.1
+ - numpy==1.26.4
+ - torch==2.0.1+cu118
+ - pygame==2.6.0
+ - gym==0.23.1
+ - urllib3==2.2.2
+ - ipython==8.16.1
+published: "2024-07-20T00:14:43.650667"
diff --git a/.notebooks/lightning_examples/text-transformers.ipynb b/.notebooks/lightning_examples/text-transformers.ipynb
new file mode 100644
index 000000000..677e734f7
--- /dev/null
+++ b/.notebooks/lightning_examples/text-transformers.ipynb
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5dc91e4fd43cd4bf2884c85520077e7d4d25375ee86de11c2f8195379a8607e4
+size 874367
diff --git a/.notebooks/lightning_examples/text-transformers.yaml b/.notebooks/lightning_examples/text-transformers.yaml
new file mode 100644
index 000000000..29c52ac50
--- /dev/null
+++ b/.notebooks/lightning_examples/text-transformers.yaml
@@ -0,0 +1,41 @@
+title: Finetune Transformers Models with PyTorch Lightning
+author: Lightning.ai
+created: 2021-01-31
+updated: 2023-03-17
+license: CC BY-SA
+build: 0
+tags:
+ - Text
+description: "This notebook will use HuggingFace's `datasets` library to get data,
+ which will be wrapped in a `LightningDataModule`.
+
+ Then, we write a class to perform text classification on any dataset from the [GLUE
+ Benchmark](https://gluebenchmark.com/).
+
+ (We show just CoLA and MRPC due to compute/disk constraints.)
+
+ "
+requirements:
+ - transformers
+ - datasets
+ - scipy
+ - scikit-learn
+ - torchtext>=0.9
+ - lightning>=2.0.0
+accelerator:
+ - GPU
+environment:
+ - lightning==2.0.9.post0
+ - scipy==1.11.3
+ - urllib3==2.0.6
+ - torch==2.0.1
+ - datasets==2.14.5
+ - torchtext==0.15.2
+ - torchmetrics==1.2.0
+ - setuptools==68.2.2
+ - ipython==8.16.1
+ - scikit-learn==1.3.1
+ - pytorch-lightning==2.0.9.post0
+ - transformers==4.34.0
+ - matplotlib==3.8.0
+published: "2023-10-12T02:41:16.771129"
diff --git a/.notebooks/lightning_examples/warp-drive.ipynb b/.notebooks/lightning_examples/warp-drive.ipynb
new file mode 100644
index 000000000..39b91d272
--- /dev/null
+++ b/.notebooks/lightning_examples/warp-drive.ipynb
@@ -0,0 +1,1519 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "f98c4184",
+ "metadata": {
+ "papermill": {
+ "duration": 0.31986,
+ "end_time": "2022-05-17T23:32:06.666655",
+ "exception": false,
+ "start_time": "2022-05-17T23:32:06.346795",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "\n",
+ "# Multi-agent Reinforcement Learning With WarpDrive\n",
+ "\n",
+ "* **Author:** Sunil Srinivasa (sunil.srinivasa@salesforce.com), Tian Lan (tian.lan@salesforce.com), Huan Wang (huan.wang@salesforce.com) and Stephan Zheng(stephan.zheng@salesforce.com)\n",
+ "* **License:** BSD 3-Clause \"New\" or \"Revised\" License\n",
+ "* **Generated:** 2022-05-18T01:28:32.002384\n",
+ "\n",
+ "This notebook introduces multi-agent reinforcement learning (MARL) with WarpDrive (Lan et al. https://arxiv.org/abs/2108.13976). WarpDrive is a flexible, lightweight, and easy-to-use open-source framework that implements end-to-end deep MARL on GPUs. WarpDrive enables orders-of-magnitude speedups compared to CPU-GPU implementations, using the parallelization capability of GPUs and several design choices to minimize communication overhead. WarpDrive also prioritizes user-friendliness - it has utility functions to easily build MARL environments in CUDA and quality-of-life tools to run end-to-end MARL using just a few lines of code, and is compatible with PyTorch.\n",
+ "WarpDrive includes the following resources. code - https://github.com/salesforce/warp-drive documentation - http://opensource.salesforce.com/warp-drive/, and white paper - https://arxiv.org/abs/2108.13976.\n",
+ "\n",
+ "---\n",
+ "Open in [{height=\"20px\" width=\"117px\"}](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/warp-drive.ipynb)\n",
+ "\n",
+ "Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
+ "| Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/stable/)\n",
+ "| Join us [on Slack](https://www.pytorchlightning.ai/community)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "358bebff",
+ "metadata": {
+ "papermill": {
+ "duration": 0.461423,
+ "end_time": "2022-05-17T23:32:07.530601",
+ "exception": false,
+ "start_time": "2022-05-17T23:32:07.069178",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Setup\n",
+ "This notebook requires some packages besides pytorch-lightning."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "060104be",
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "execution": {
+ "iopub.execute_input": "2022-05-17T23:32:08.437379Z",
+ "iopub.status.busy": "2022-05-17T23:32:08.436649Z",
+ "iopub.status.idle": "2022-05-17T23:32:12.595594Z",
+ "shell.execute_reply": "2022-05-17T23:32:12.594658Z"
+ },
+ "id": "LfrJLKPFyhsK",
+ "lines_to_next_cell": 0,
+ "papermill": {
+ "duration": 4.624539,
+ "end_time": "2022-05-17T23:32:12.597709",
+ "exception": false,
+ "start_time": "2022-05-17T23:32:07.973170",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "! pip install --quiet \"ffmpeg-python\" \"rl-warp-drive>=1.6.5\" \"setuptools==59.5.0\" \"ipython[notebook]\" \"torch>=1.8\" \"torch==1.10.*\" \"torchvision==0.11.*\" \"torchtext==0.11.*\" \"torchmetrics>=0.7\" \"pytorch-lightning>=1.4\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "24cc6b67",
+ "metadata": {
+ "papermill": {
+ "duration": 0.441952,
+ "end_time": "2022-05-17T23:32:13.320857",
+ "exception": false,
+ "start_time": "2022-05-17T23:32:12.878905",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "**⚠️ PLEASE NOTE:**\n",
+ "This notebook runs on a GPU runtime. If running on Colab, choose Runtime > Change runtime type from the menu, then select `GPU` in the 'Hardware accelerator' dropdown menu."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "690a892f",
+ "metadata": {
+ "papermill": {
+ "duration": 0.261087,
+ "end_time": "2022-05-17T23:32:13.782483",
+ "exception": false,
+ "start_time": "2022-05-17T23:32:13.521396",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Introduction"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3b3f5506",
+ "metadata": {
+ "papermill": {
+ "duration": 0.402247,
+ "end_time": "2022-05-17T23:32:14.385761",
+ "exception": false,
+ "start_time": "2022-05-17T23:32:13.983514",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "This tutorial provides a demonstration of a multi-agent Reinforcement Learning (RL) training loop with [WarpDrive](https://github.com/salesforce/warp-drive). WarpDrive is a flexible, lightweight, and easy-to-use RL framework that implements end-to-end deep multi-agent RL on a GPU (Graphics Processing Unit). Using the extreme parallelization capability of GPUs, it enables [orders-of-magnitude faster RL](https://arxiv.org/abs/2108.13976) compared to common implementations that blend CPU simulations and GPU models. WarpDrive is extremely efficient as it runs simulations across multiple agents and multiple environment replicas all in parallel and completely eliminates the back-and-forth data copying between the CPU and the GPU during every step. As such, WarpDrive\n",
+ "- Can simulate 1000s of agents in each environment and thousands of environments in parallel, harnessing the extreme parallelism capability of GPUs.\n",
+ "- Eliminates communication between CPU and GPU, and also within the GPU, as read and write operations occur in-place.\n",
+ "- Is fully compatible with Pytorch, a highly flexible and very fast deep learning framework.\n",
+ "- Implements parallel action sampling on CUDA C, which is ~3x faster than using Pytorch’s sampling methods.\n",
+ "- Allows for large-scale distributed training on multiple GPUs.\n",
+ "\n",
+ "Below is an overview of WarpDrive’s layout of computational and data structures on a single GPU.\n",
+ "\n",
+ "Computations are organized into blocks, with multiple threads in each block. Each block runs a simulation environment and each thread\n",
+ "simulates an agent in an environment. Blocks can access the shared GPU memory that stores simulation data and neural network policy models. A DataManager and FunctionManager enable defining multi-agent RL GPU-workflows with Python APIs. For more details, please read out white [paper](https://arxiv.org/abs/2108.13976).\n",
+ "\n",
+ "The Warpdrive framework comprises several utility functions that help easily implement any (OpenAI-)*gym-style* RL environment, and furthermore, provides quality-of-life tools to train it end-to-end using just a few lines of code. You may familiarize yourself with WarpDrive with the help of these [tutorials](https://github.com/salesforce/warp-drive/tree/master/tutorials).\n",
+ "\n",
+ "We invite everyone to **contribute to WarpDrive**, including adding new multi-agent environments, proposing new features and reporting issues on our open source [repository](https://github.com/salesforce/warp-drive).\n",
+ "\n",
+ "We have integrated WarpDrive with the [Pytorch Lightning](https://www.pytorchlightning.ai/) framework, which greatly reduces the trainer boilerplate code, and improves training modularity and flexibility. It abstracts away most of the engineering pieces of code, so users can focus on research and building models, and iterate on experiments really fast. Pytorch Lightning also provides support for easily running the model on any hardware, performing distributed training, model checkpointing, performance profiling, logging and visualization.\n",
+ "\n",
+ "Below, we demonstrate how to use WarpDrive and PytorchLightning together to train a game of [Tag](https://github.com/salesforce/warp-drive/blob/master/example_envs/tag_continuous/tag_continuous.py) where multiple *tagger* agents are trying to run after and tag multiple other *runner* agents. Here's a sample depiction of the game of Tag with $100$ runners and $5$ taggers.\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "977591fe",
+ "metadata": {
+ "papermill": {
+ "duration": 0.402082,
+ "end_time": "2022-05-17T23:32:15.170382",
+ "exception": false,
+ "start_time": "2022-05-17T23:32:14.768300",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Dependencies"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "5bcb4efa",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2022-05-17T23:32:16.019075Z",
+ "iopub.status.busy": "2022-05-17T23:32:16.018379Z",
+ "iopub.status.idle": "2022-05-17T23:32:36.870713Z",
+ "shell.execute_reply": "2022-05-17T23:32:36.869948Z"
+ },
+ "lines_to_next_cell": 2,
+ "papermill": {
+ "duration": 21.300039,
+ "end_time": "2022-05-17T23:32:36.872611",
+ "exception": false,
+ "start_time": "2022-05-17T23:32:15.572572",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.8/dist-packages/numpy/core/getlimits.py:499: UserWarning: The value of the smallest subnormal for type is zero.\n",
+ " setattr(self, word, getattr(machar, word).flat[0])\n",
+ "/usr/local/lib/python3.8/dist-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for type is zero.\n",
+ " return self._float_to_str(self.smallest_subnormal)\n",
+ "/usr/local/lib/python3.8/dist-packages/numpy/core/getlimits.py:499: UserWarning: The value of the smallest subnormal for type is zero.\n",
+ " setattr(self, word, getattr(machar, word).flat[0])\n",
+ "/usr/local/lib/python3.8/dist-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for type is zero.\n",
+ " return self._float_to_str(self.smallest_subnormal)\n",
+ "/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.\n",
+ " interpolation: int = Image.BILINEAR,\n",
+ "/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/torchvision/transforms/functional_pil.py:296: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.\n",
+ " interpolation: int = Image.NEAREST,\n",
+ "/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/torchvision/transforms/functional_pil.py:329: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.\n",
+ " interpolation: int = Image.BICUBIC,\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.8/dist-packages/comet_ml/monkey_patching.py:19: DeprecationWarning: the imp module is deprecated in favour of importlib; see the module's documentation for alternative uses\n",
+ " import imp\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.8/dist-packages/mlflow/types/schema.py:48: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. \n",
+ "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
+ " binary = (7, np.dtype(\"bytes\"), \"BinaryType\", np.object)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:root:Bagua cannot detect bundled NCCL library, Bagua will try to use system NCCL instead. If you encounter any error, please run `import bagua_core; bagua_core.install_deps()` or the `bagua_install_deps.py` script to install bundled libraries.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.8/dist-packages/sklearn/utils/multiclass.py:14: DeprecationWarning: Please use `spmatrix` from the `scipy.sparse` namespace, the `scipy.sparse.base` namespace is deprecated.\n",
+ " from scipy.sparse.base import spmatrix\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/utils/optimize.py:18: DeprecationWarning: Please use `line_search_wolfe2` from the `scipy.optimize` namespace, the `scipy.optimize.linesearch` namespace is deprecated.\n",
+ " from scipy.optimize.linesearch import line_search_wolfe2, line_search_wolfe1\n",
+ "/usr/local/lib/python3.8/dist-packages/sklearn/utils/optimize.py:18: DeprecationWarning: Please use `line_search_wolfe1` from the `scipy.optimize` namespace, the `scipy.optimize.linesearch` namespace is deprecated.\n",
+ " from scipy.optimize.linesearch import line_search_wolfe2, line_search_wolfe1\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pycuda/compyte/dtypes.py:120: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n",
+ "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
+ " reg.get_or_register_dtype(\"bool\", np.bool)\n"
+ ]
+ }
+ ],
+ "source": [
+ "import logging\n",
+ "\n",
+ "import torch\n",
+ "from example_envs.tag_continuous.tag_continuous import TagContinuous\n",
+ "from pytorch_lightning import Trainer\n",
+ "from warp_drive.env_wrapper import EnvWrapper\n",
+ "from warp_drive.training.pytorch_lightning import CUDACallback, PerfStatsCallback, WarpDriveModule\n",
+ "\n",
+ "# Uncomment below for enabling animation visualizations.\n",
+ "# from example_envs.utils.generate_rollout_animation import generate_tag_env_rollout_animation\n",
+ "# from IPython.display import HTML"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "3cc83a08",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2022-05-17T23:32:36.981331Z",
+ "iopub.status.busy": "2022-05-17T23:32:36.980733Z",
+ "iopub.status.idle": "2022-05-17T23:32:36.986904Z",
+ "shell.execute_reply": "2022-05-17T23:32:36.986302Z"
+ },
+ "papermill": {
+ "duration": 0.062486,
+ "end_time": "2022-05-17T23:32:36.988311",
+ "exception": false,
+ "start_time": "2022-05-17T23:32:36.925825",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "assert torch.cuda.device_count() > 0, \"This notebook only runs on a GPU!\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "3a90d200",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2022-05-17T23:32:37.096466Z",
+ "iopub.status.busy": "2022-05-17T23:32:37.096037Z",
+ "iopub.status.idle": "2022-05-17T23:32:37.099888Z",
+ "shell.execute_reply": "2022-05-17T23:32:37.099289Z"
+ },
+ "papermill": {
+ "duration": 0.058613,
+ "end_time": "2022-05-17T23:32:37.101324",
+ "exception": false,
+ "start_time": "2022-05-17T23:32:37.042711",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Set logger level e.g., DEBUG, INFO, WARNING, ERROR.\n",
+ "logging.getLogger().setLevel(logging.ERROR)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "737c9d3f",
+ "metadata": {
+ "papermill": {
+ "duration": 0.054713,
+ "end_time": "2022-05-17T23:32:37.208268",
+ "exception": false,
+ "start_time": "2022-05-17T23:32:37.153555",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Specify a set of run configurations for your experiments\n",
+ "\n",
+ "The run configuration is a dictionary comprising the environment parameters, the trainer and the policy network settings, as well as configurations for saving.\n",
+ "\n",
+ "For our experiment, we consider an environment wherein $5$ taggers and $100$ runners play the game of [Tag](https://github.com/salesforce/warp-drive/blob/master/example_envs/tag_continuous/tag_continuous.py) on a $20 \\times 20$ plane. The game lasts $200$ timesteps. Each agent chooses it's own acceleration and turn actions at every timestep, and we use mechanics to determine how the agents move over the grid. When a tagger gets close to a runner, the runner is tagged, and is eliminated from the game. For the configuration below, the runners and taggers have the same unit skill levels, or top speeds.\n",
+ "\n",
+ "We train the agents using $50$ environments or simulations running in parallel. With WarpDrive, each simulation runs on separate GPU blocks.\n",
+ "\n",
+ "There are two separate policy networks used for the tagger and runner agents. Each network is a fully-connected model with two layers each of $256$ dimensions. We use the Advantage Actor Critic (A2C) algorithm for training. WarpDrive also currently provides the option to use the Proximal Policy Optimization (PPO) algorithm instead."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "ab5f0721",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2022-05-17T23:32:37.315571Z",
+ "iopub.status.busy": "2022-05-17T23:32:37.314914Z",
+ "iopub.status.idle": "2022-05-17T23:32:37.322945Z",
+ "shell.execute_reply": "2022-05-17T23:32:37.322338Z"
+ },
+ "papermill": {
+ "duration": 0.063451,
+ "end_time": "2022-05-17T23:32:37.324426",
+ "exception": false,
+ "start_time": "2022-05-17T23:32:37.260975",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "run_config = dict(\n",
+ " name=\"tag_continuous\",\n",
+ " # Environment settings.\n",
+ " env=dict(\n",
+ " # number of taggers in the environment\n",
+ " num_taggers=5,\n",
+ " # number of runners in the environment\n",
+ " num_runners=100,\n",
+ " # length of the (square) grid on which the game is played\n",
+ " grid_length=20.0,\n",
+ " # episode length in timesteps\n",
+ " episode_length=200,\n",
+ " # maximum acceleration\n",
+ " max_acceleration=0.1,\n",
+ " # minimum acceleration\n",
+ " min_acceleration=-0.1,\n",
+ " # maximum turn (in radians)\n",
+ " max_turn=2.35, # 3pi/4 radians\n",
+ " # minimum turn (in radians)\n",
+ " min_turn=-2.35, # -3pi/4 radians\n",
+ " # number of discretized accelerate actions\n",
+ " num_acceleration_levels=10,\n",
+ " # number of discretized turn actions\n",
+ " num_turn_levels=10,\n",
+ " # skill level for the tagger\n",
+ " skill_level_tagger=1.0,\n",
+ " # skill level for the runner\n",
+ " skill_level_runner=1.0,\n",
+ " # each agent sees the full (or partial) information of the world\n",
+ " use_full_observation=False,\n",
+ " # flag to indicate if a runner stays in the game after getting tagged\n",
+ " runner_exits_game_after_tagged=True,\n",
+ " # number of other agents each agent can see\n",
+ " # used in the case use_full_observation is False\n",
+ " num_other_agents_observed=10,\n",
+ " # positive reward for a tagger upon tagging a runner\n",
+ " tag_reward_for_tagger=10.0,\n",
+ " # negative reward for a runner upon getting tagged\n",
+ " tag_penalty_for_runner=-10.0,\n",
+ " # reward at the end of the game for a runner that isn't tagged\n",
+ " end_of_game_reward_for_runner=1.0,\n",
+ " # distance margin between a tagger and runner\n",
+ " # to consider the runner as being 'tagged'\n",
+ " tagging_distance=0.02,\n",
+ " ),\n",
+ " # Trainer settings.\n",
+ " trainer=dict(\n",
+ " # number of environment replicas (number of GPU blocks used)\n",
+ " num_envs=50,\n",
+ " # total batch size used for training per iteration (across all the environments)\n",
+ " train_batch_size=10000,\n",
+ " # total number of episodes to run the training for\n",
+ " # This can be set arbitrarily high!\n",
+ " num_episodes=500,\n",
+ " ),\n",
+ " # Policy network settings.\n",
+ " policy=dict(\n",
+ " runner=dict(\n",
+ " # flag indicating whether the model needs to be trained\n",
+ " to_train=True,\n",
+ " # algorithm used to train the policy\n",
+ " algorithm=\"A2C\",\n",
+ " # discount rate\n",
+ " gamma=0.98,\n",
+ " # learning rate\n",
+ " lr=0.005,\n",
+ " # policy model settings\n",
+ " model=dict(type=\"fully_connected\", fc_dims=[256, 256], model_ckpt_filepath=\"\"),\n",
+ " ),\n",
+ " tagger=dict(\n",
+ " to_train=True,\n",
+ " algorithm=\"A2C\",\n",
+ " gamma=0.98,\n",
+ " lr=0.002,\n",
+ " model=dict(type=\"fully_connected\", fc_dims=[256, 256], model_ckpt_filepath=\"\"),\n",
+ " ),\n",
+ " ),\n",
+ " # Checkpoint saving setting.\n",
+ " saving=dict(\n",
+ " # how often (in iterations) to print the metrics\n",
+ " metrics_log_freq=10,\n",
+ " # how often (in iterations) to save the model parameters\n",
+ " model_params_save_freq=5000,\n",
+ " # base folder used for saving\n",
+ " basedir=\"/tmp\",\n",
+ " # experiment name\n",
+ " name=\"continuous_tag\",\n",
+ " # experiment tag\n",
+ " tag=\"example\",\n",
+ " ),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4ca733ab",
+ "metadata": {
+ "papermill": {
+ "duration": 0.054553,
+ "end_time": "2022-05-17T23:32:37.435282",
+ "exception": false,
+ "start_time": "2022-05-17T23:32:37.380729",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Instantiate the WarpDrive Module\n",
+ "\n",
+ "In order to instantiate the WarpDrive module, we first use an environment wrapper to specify that the environment needs to be run on the GPU (via the `use_cuda` flag). Also, agents in the environment can share policy models; so we specify a dictionary to map each policy network model to the list of agent ids using that model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "92adbc29",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2022-05-17T23:32:37.543960Z",
+ "iopub.status.busy": "2022-05-17T23:32:37.543329Z",
+ "iopub.status.idle": "2022-05-17T23:32:49.152811Z",
+ "shell.execute_reply": "2022-05-17T23:32:49.152128Z"
+ },
+ "lines_to_next_cell": 2,
+ "papermill": {
+ "duration": 11.66556,
+ "end_time": "2022-05-17T23:32:49.154310",
+ "exception": false,
+ "start_time": "2022-05-17T23:32:37.488750",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Global seed set to 1652830369\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Create a wrapped environment object via the EnvWrapper.\n",
+ "# Ensure that use_cuda is set to True (in order to run on the GPU).\n",
+ "env_wrapper = EnvWrapper(\n",
+ " TagContinuous(**run_config[\"env\"]),\n",
+ " num_envs=run_config[\"trainer\"][\"num_envs\"],\n",
+ " use_cuda=True,\n",
+ ")\n",
+ "\n",
+ "# Agents can share policy models: this dictionary maps policy model names to agent ids.\n",
+ "policy_tag_to_agent_id_map = {\n",
+ " \"tagger\": list(env_wrapper.env.taggers),\n",
+ " \"runner\": list(env_wrapper.env.runners),\n",
+ "}\n",
+ "\n",
+ "wd_module = WarpDriveModule(\n",
+ " env_wrapper=env_wrapper,\n",
+ " config=run_config,\n",
+ " policy_tag_to_agent_id_map=policy_tag_to_agent_id_map,\n",
+ " verbose=True,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1afa2a55",
+ "metadata": {
+ "papermill": {
+ "duration": 0.05218,
+ "end_time": "2022-05-17T23:32:49.258040",
+ "exception": false,
+ "start_time": "2022-05-17T23:32:49.205860",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Visualizing an episode roll-out before training\n",
+ "\n",
+ "We have created a helper function (see below) to visualize an episode rollout. Internally, this function uses the WarpDrive module's `fetch_episode_states` API to fetch the data arrays on the GPU for the duration of an entire episode. Specifically, we fetch the state arrays pertaining to agents' x and y locations on the plane and indicators on which agents are still active in the game. Note that this function may be invoked at any time during training, and it will use the state of the policy models at that time to sample actions and generate the visualization."
+ ]
+ },
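+  {
+   "cell_type": "markdown",
+   "id": "f0e1d2c3",
+   "metadata": {},
+   "source": [
+    "As a minimal sketch of this pattern (the state names \"loc_x\" and \"loc_y\" below are assumptions for illustration; check the environment's registered data arrays for the exact names), the fetched episode data could be plotted directly:\n",
+    "\n",
+    "```python\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "# Hypothetical sketch: fetch per-timestep agent locations for a full episode.\n",
+    "# Each returned array is assumed to have one row per timestep and one column per agent.\n",
+    "episode_states = wd_module.fetch_episode_states([\"loc_x\", \"loc_y\"])\n",
+    "loc_x, loc_y = episode_states[\"loc_x\"], episode_states[\"loc_y\"]\n",
+    "\n",
+    "# Scatter-plot the agent positions at the first timestep.\n",
+    "plt.scatter(loc_x[0], loc_y[0], s=5)\n",
+    "plt.title(\"Agent locations at t=0\")\n",
+    "plt.show()\n",
+    "```"
+   ]
+  },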
+ {
+ "cell_type": "markdown",
+ "id": "0fc02c61",
+ "metadata": {
+ "papermill": {
+ "duration": 0.051218,
+ "end_time": "2022-05-17T23:32:49.361244",
+ "exception": false,
+ "start_time": "2022-05-17T23:32:49.310026",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "The animation below shows a sample realization of the game episode before training, i.e., with randomly chosen agent actions. The $5$ taggers are marked in pink, while the $100$ blue agents are the runners. Both the taggers and runners move around randomly and about half the runners remain at the end of the episode."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "83ee0c27",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2022-05-17T23:32:49.465764Z",
+ "iopub.status.busy": "2022-05-17T23:32:49.465325Z",
+ "iopub.status.idle": "2022-05-17T23:32:49.468831Z",
+ "shell.execute_reply": "2022-05-17T23:32:49.468240Z"
+ },
+ "papermill": {
+ "duration": 0.057836,
+ "end_time": "2022-05-17T23:32:49.470284",
+ "exception": false,
+ "start_time": "2022-05-17T23:32:49.412448",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Uncomment below for enabling animation visualizations.\n",
+ "# anim = generate_tag_env_rollout_animation(wd_module, fps=25)\n",
+ "# HTML(anim.to_html5_video())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7fb4c051",
+ "metadata": {
+ "papermill": {
+ "duration": 0.051534,
+ "end_time": "2022-05-17T23:32:49.573348",
+ "exception": false,
+ "start_time": "2022-05-17T23:32:49.521814",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Create the Lightning Trainer\n",
+ "\n",
+ "Next, we create the trainer for training the WarpDrive model. We add the `performance stats` callbacks to the trainer to view the throughput performance of WarpDrive."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "3c024365",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2022-05-17T23:32:49.680332Z",
+ "iopub.status.busy": "2022-05-17T23:32:49.679501Z",
+ "iopub.status.idle": "2022-05-17T23:32:49.690576Z",
+ "shell.execute_reply": "2022-05-17T23:32:49.689994Z"
+ },
+ "papermill": {
+ "duration": 0.065403,
+ "end_time": "2022-05-17T23:32:49.692022",
+ "exception": false,
+ "start_time": "2022-05-17T23:32:49.626619",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "GPU available: True, used: True\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "TPU available: False, using: 0 TPU cores\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "IPU available: False, using: 0 IPUs\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "HPU available: False, using: 0 HPUs\n"
+ ]
+ }
+ ],
+ "source": [
+ "log_freq = run_config[\"saving\"][\"metrics_log_freq\"]\n",
+ "\n",
+ "# Define callbacks.\n",
+ "cuda_callback = CUDACallback(module=wd_module)\n",
+ "perf_stats_callback = PerfStatsCallback(\n",
+ " batch_size=wd_module.training_batch_size,\n",
+ " num_iters=wd_module.num_iters,\n",
+ " log_freq=log_freq,\n",
+ ")\n",
+ "\n",
+ "# Instantiate the PytorchLightning trainer with the callbacks.\n",
+ "# Also, set the number of gpus to 1, since this notebook uses just a single GPU.\n",
+ "num_gpus = 1\n",
+ "num_episodes = run_config[\"trainer\"][\"num_episodes\"]\n",
+ "episode_length = run_config[\"env\"][\"episode_length\"]\n",
+ "training_batch_size = run_config[\"trainer\"][\"train_batch_size\"]\n",
+ "num_epochs = num_episodes * episode_length / training_batch_size\n",
+ "\n",
+ "trainer = Trainer(\n",
+ " accelerator=\"gpu\",\n",
+ " devices=num_gpus,\n",
+ " callbacks=[cuda_callback, perf_stats_callback],\n",
+ " max_epochs=num_epochs,\n",
+ " log_every_n_steps=1,\n",
+ " reload_dataloaders_every_n_epochs=1,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "8515b530",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2022-05-17T23:32:49.804904Z",
+ "iopub.status.busy": "2022-05-17T23:32:49.804426Z",
+ "iopub.status.idle": "2022-05-17T23:32:51.387677Z",
+ "shell.execute_reply": "2022-05-17T23:32:51.386870Z"
+ },
+ "papermill": {
+ "duration": 1.639884,
+ "end_time": "2022-05-17T23:32:51.389188",
+ "exception": false,
+ "start_time": "2022-05-17T23:32:49.749304",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Start tensorboard.\n",
+ "%load_ext tensorboard\n",
+ "%tensorboard --logdir lightning_logs/"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b1636a7c",
+ "metadata": {
+ "papermill": {
+ "duration": 0.058623,
+ "end_time": "2022-05-17T23:32:51.506077",
+ "exception": false,
+ "start_time": "2022-05-17T23:32:51.447454",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Train the WarpDrive Module\n",
+ "\n",
+ "Finally, we invoke training.\n",
+ "\n",
+ "Note: please scroll up to the tensorboard cell to visualize the curves during training."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "11ef667b",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2022-05-17T23:32:51.624927Z",
+ "iopub.status.busy": "2022-05-17T23:32:51.624490Z",
+ "iopub.status.idle": "2022-05-17T23:33:00.758063Z",
+ "shell.execute_reply": "2022-05-17T23:33:00.757403Z"
+ },
+ "papermill": {
+ "duration": 9.195354,
+ "end_time": "2022-05-17T23:33:00.759654",
+ "exception": false,
+ "start_time": "2022-05-17T23:32:51.564300",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:376: LightningDeprecationWarning: The `Callback.on_batch_start` hook was deprecated in v1.6 and will be removed in v1.8. Please use `Callback.on_train_batch_start` instead.\n",
+ " rank_zero_deprecation(\n",
+ "/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:376: LightningDeprecationWarning: The `Callback.on_batch_end` hook was deprecated in v1.6 and will be removed in v1.8. Please use `Callback.on_train_batch_end` instead.\n",
+ " rank_zero_deprecation(\n",
+ "Missing logger folder: /__w/1/s/lightning_logs\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ " | Name | Type | Params\n",
+ "------------------------------\n",
+ "------------------------------\n",
+ "0 Trainable params\n",
+ "0 Non-trainable params\n",
+ "0 Total params\n",
+ "0.000 Total estimated model params size (MB)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:240: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
+ " rank_zero_warn(\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "cbe0355f05164620a246f0b89eb0d301",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Training: 0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "========================================\n",
+ "Metrics for policy 'runner'\n",
+ "========================================\n",
+ "VF loss coefficient : 0.01000\n",
+ "Entropy coefficient : 0.05000\n",
+ "Total loss : -1.51269\n",
+ "Policy loss : -1.31748\n",
+ "Value function loss : 4.30106\n",
+ "Mean rewards : -0.02525\n",
+ "Max. rewards : 1.00000\n",
+ "Min. rewards : -10.00000\n",
+ "Mean value function : -0.86170\n",
+ "Mean advantages : -0.27768\n",
+ "Mean (norm.) advantages : -0.27768\n",
+ "Mean (discounted) returns : -1.13938\n",
+ "Mean normalized returns : -1.13938\n",
+ "Mean entropy : 4.76451\n",
+ "Variance explained by the value function: 0.11032\n",
+ "Std. of action_0 over agents : 3.04816\n",
+ "Std. of action_0 over envs : 3.04446\n",
+ "Std. of action_0 over time : 3.04757\n",
+ "Std. of action_1 over agents : 3.23549\n",
+ "Std. of action_1 over envs : 3.23271\n",
+ "Std. of action_1 over time : 3.23722\n",
+ "Current timestep : 90000.00000\n",
+ "Gradient norm : 0.05845\n",
+ "Mean episodic reward : -408.38889\n",
+ "[Device 0]: Saving the results to the file '/tmp/continuous_tag/example/1652830363/results.json' \n",
+ "[Device 0]: Saving the 'runner' torch model to the file: '/tmp/continuous_tag/example/1652830363/runner_90000.state_dict'. \n",
+ "[Device 0]: Saving the 'tagger' torch model to the file: '/tmp/continuous_tag/example/1652830363/tagger_80000.state_dict'. \n",
+ "========================================\n",
+ "Metrics for policy 'tagger'\n",
+ "========================================\n",
+ "VF loss coefficient : 0.01000\n",
+ "Entropy coefficient : 0.05000\n",
+ "Total loss : 79.46014\n",
+ "Policy loss : 75.07774\n",
+ "Value function loss : 460.96414\n",
+ "Mean rewards : 0.53500\n",
+ "Max. rewards : 20.00000\n",
+ "Min. rewards : 0.00000\n",
+ "Mean value function : 3.43005\n",
+ "Mean advantages : 16.50640\n",
+ "Mean (norm.) advantages : 16.50640\n",
+ "Mean (discounted) returns : 19.93644\n",
+ "Mean normalized returns : 19.93644\n",
+ "Mean entropy : 4.54485\n",
+ "Variance explained by the value function: -0.00764\n",
+ "Std. of action_0 over agents : 3.04688\n",
+ "Std. of action_0 over envs : 3.19368\n",
+ "Std. of action_0 over time : 3.19806\n",
+ "Std. of action_1 over agents : 2.74155\n",
+ "Std. of action_1 over envs : 2.85016\n",
+ "Std. of action_1 over time : 2.85594\n",
+ "Current timestep : 90000.00000\n",
+ "Gradient norm : 1.21257\n",
+ "Mean episodic reward : 449.24444\n",
+ "[Device 0]: Saving the results to the file '/tmp/continuous_tag/example/1652830363/results.json' \n",
+ "[Device 0]: Saving the 'runner' torch model to the file: '/tmp/continuous_tag/example/1652830363/runner_90000.state_dict'. \n",
+ "[Device 0]: Saving the 'tagger' torch model to the file: '/tmp/continuous_tag/example/1652830363/tagger_90000.state_dict'. \n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:229: UserWarning: You called `self.log('Current timestep_runner', ...)` in your `training_step` but the value needs to be floating point. Converting it to torch.float32.\n",
+ " warning_cache.warn(\n",
+ "/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:229: UserWarning: You called `self.log('Current timestep_tagger', ...)` in your `training_step` but the value needs to be floating point. Converting it to torch.float32.\n",
+ " warning_cache.warn(\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "========================================\n",
+ "Metrics for policy 'runner'\n",
+ "========================================\n",
+ "VF loss coefficient : 0.01000\n",
+ "Entropy coefficient : 0.05000\n",
+ "Total loss : -1.06076\n",
+ "Policy loss : -0.86573\n",
+ "Value function loss : 4.28389\n",
+ "Mean rewards : -0.02681\n",
+ "Max. rewards : 1.00000\n",
+ "Min. rewards : -10.00000\n",
+ "Mean value function : -1.03110\n",
+ "Mean advantages : -0.18345\n",
+ "Mean (norm.) advantages : -0.18345\n",
+ "Mean (discounted) returns : -1.21455\n",
+ "Mean normalized returns : -1.21455\n",
+ "Mean entropy : 4.75726\n",
+ "Variance explained by the value function: 0.13849\n",
+ "Std. of action_0 over agents : 3.08665\n",
+ "Std. of action_0 over envs : 3.08295\n",
+ "Std. of action_0 over time : 3.08616\n",
+ "Std. of action_1 over agents : 3.21539\n",
+ "Std. of action_1 over envs : 3.21178\n",
+ "Std. of action_1 over time : 3.21630\n",
+ "Current timestep : 100000.00000\n",
+ "Gradient norm : 0.05899\n",
+ "Mean episodic reward : -536.14000\n",
+ "[Device 0]: Saving the results to the file '/tmp/continuous_tag/example/1652830363/results.json' \n",
+ "========================================\n",
+ "Metrics for policy 'tagger'\n",
+ "========================================\n",
+ "VF loss coefficient : 0.01000\n",
+ "Entropy coefficient : 0.05000\n",
+ "Total loss : 77.55455\n",
+ "Policy loss : 72.94509\n",
+ "Value function loss : 482.91556\n",
+ "Mean rewards : 0.56020\n",
+ "Max. rewards : 20.00000\n",
+ "Min. rewards : 0.00000\n",
+ "Mean value function : 4.44337\n",
+ "Mean advantages : 16.58761\n",
+ "Mean (norm.) advantages : 16.58761\n",
+ "Mean (discounted) returns : 21.03099\n",
+ "Mean normalized returns : 21.03099\n",
+ "Mean entropy : 4.39390\n",
+ "Variance explained by the value function: -0.00993\n",
+ "Std. of action_0 over agents : 2.94368\n",
+ "Std. of action_0 over envs : 3.11596\n",
+ "Std. of action_0 over time : 3.12263\n",
+ "Std. of action_1 over agents : 2.66070\n",
+ "Std. of action_1 over envs : 2.78366\n",
+ "Std. of action_1 over time : 2.79009\n",
+ "Current timestep : 100000.00000\n",
+ "Gradient norm : 1.13135\n",
+ "Mean episodic reward : 560.20000\n",
+ "[Device 0]: Saving the results to the file '/tmp/continuous_tag/example/1652830363/results.json' \n",
+ "========================================\n",
+ "Speed performance stats\n",
+ "========================================\n",
+ "Iteration : 10 / 10 \n",
+ "Mean training time per iter (ms) : 131.28\n",
+ "Mean steps per sec (training time) : 76172.00\n",
+ "\n",
+ "\n",
+ "Training is complete!\n"
+ ]
+ }
+ ],
+ "source": [
+ "trainer.fit(wd_module)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5b4259b0",
+ "metadata": {
+ "papermill": {
+ "duration": 0.056368,
+ "end_time": "2022-05-17T23:33:00.876445",
+ "exception": false,
+ "start_time": "2022-05-17T23:33:00.820077",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Visualize an episode-rollout after training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "dcac436c",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2022-05-17T23:33:00.995986Z",
+ "iopub.status.busy": "2022-05-17T23:33:00.995563Z",
+ "iopub.status.idle": "2022-05-17T23:33:00.998889Z",
+ "shell.execute_reply": "2022-05-17T23:33:00.998313Z"
+ },
+ "papermill": {
+ "duration": 0.064773,
+ "end_time": "2022-05-17T23:33:01.000323",
+ "exception": false,
+ "start_time": "2022-05-17T23:33:00.935550",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Uncomment below for enabling animation visualizations.\n",
+ "# anim = generate_tag_env_rollout_animation(wd_module, fps=25)\n",
+ "# HTML(anim.to_html5_video())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b5c4e079",
+ "metadata": {
+ "papermill": {
+ "duration": 0.0599,
+ "end_time": "2022-05-17T23:33:01.119188",
+ "exception": false,
+ "start_time": "2022-05-17T23:33:01.059288",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "Note: In the configuration above, we have set the trainer to only train on $500$ rollout episodes, but you can increase the `num_episodes` configuration parameter to train further. As more training happens, the runners learn to escape the taggers, and the taggers learn to chase after the runner. Sometimes, the taggers also collaborate to team-tag runners. A good number of episodes to train on (for the configuration we have used) is $2$M or higher."
+ ]
+ },
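+  {
+   "cell_type": "markdown",
+   "id": "a9b8c7d6",
+   "metadata": {},
+   "source": [
+    "Per the `saving` settings in `run_config`, WarpDrive periodically writes results and policy state dicts under `basedir/name/tag/<timestamp>/` (the training logs above show, e.g., `/tmp/continuous_tag/example/1652830363/tagger_90000.state_dict`). As a minimal sketch, such a saved policy state dict can later be inspected or reloaded with plain PyTorch; the empty `model_ckpt_filepath` field in the policy model settings presumably accepts such a path to resume from saved parameters. The path below is copied from this run's logs and is illustrative; use the one printed by your own run:\n",
+    "\n",
+    "```python\n",
+    "import torch\n",
+    "\n",
+    "# Illustrative path taken from the training logs of this particular run.\n",
+    "tagger_state = torch.load(\"/tmp/continuous_tag/example/1652830363/tagger_90000.state_dict\")\n",
+    "# `tagger_state` is an ordinary state dict that a matching policy network can load.\n",
+    "```"
+   ]
+  },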
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "ee7d9a19",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2022-05-17T23:33:01.240893Z",
+ "iopub.status.busy": "2022-05-17T23:33:01.240028Z",
+ "iopub.status.idle": "2022-05-17T23:33:01.244499Z",
+ "shell.execute_reply": "2022-05-17T23:33:01.243851Z"
+ },
+ "papermill": {
+ "duration": 0.066518,
+ "end_time": "2022-05-17T23:33:01.245965",
+ "exception": false,
+ "start_time": "2022-05-17T23:33:01.179447",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[Device 0]: Trainer exits gracefully \n"
+ ]
+ }
+ ],
+ "source": [
+ "# Finally, close the WarpDrive module to clear up the CUDA memory heap\n",
+ "wd_module.graceful_close()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b4bb1540",
+ "metadata": {
+ "papermill": {
+ "duration": 0.055238,
+ "end_time": "2022-05-17T23:33:01.360822",
+ "exception": false,
+ "start_time": "2022-05-17T23:33:01.305584",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Congratulations - Time to Join the Community!\n",
+ "\n",
+ "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning\n",
+ "movement, you can do so in the following ways!\n",
+ "\n",
+ "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n",
+ "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool\n",
+ "tools we're building.\n",
+ "\n",
+ "### Join our [Slack](https://www.pytorchlightning.ai/community)!\n",
+ "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself\n",
+ "and share your interests in `#general` channel\n",
+ "\n",
+ "\n",
+ "### Contributions !\n",
+ "The best way to contribute to our community is to become a code contributor! At any time you can go to\n",
+ "[Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/lightning-bolts)\n",
+ "GitHub Issues page and filter for \"good first issue\".\n",
+ "\n",
+ "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
+ "* [Bolt good first issue](https://github.com/PyTorchLightning/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
+ "* You can also contribute your own notebooks with useful examples !\n",
+ "\n",
+ "### Great thanks from the entire Pytorch Lightning Team for your interest !\n",
+ "\n",
+ "[{height=\"60px\" width=\"240px\"}](https://pytorchlightning.ai)"
+ ]
+ }
+ ],
+ "metadata": {
+ "jupytext": {
+ "cell_metadata_filter": "id,colab,colab_type,-all",
+ "formats": "ipynb,py:percent",
+ "main_language": "python"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ },
+ "papermill": {
+ "default_parameters": {},
+ "duration": 59.679822,
+ "end_time": "2022-05-17T23:33:04.340472",
+ "environment_variables": {},
+ "exception": null,
+ "input_path": "lightning_examples/warp-drive/multi_agent_rl.ipynb",
+ "output_path": ".notebooks/lightning_examples/warp-drive.ipynb",
+ "parameters": {},
+ "start_time": "2022-05-17T23:32:04.660650",
+ "version": "2.3.4"
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "state": {
+ "0729ac1b08fc45debfb674e70712987d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_4fddedb84bdd4344ba7545999cce0902",
+ "placeholder": "",
+ "style": "IPY_MODEL_819914d87a4741a8b96fd75b6c174c55",
+ "value": " 1/1 [00:00<00:00, 1.14it/s, loss=37, v_num=0, loss_runner=-1.06, loss_tagger=77.60]"
+ }
+ },
+ "2fd8c257870c447f93d048e20048d9ae": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "43850dc51bf042fbb47c45f7faf21af8": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_acdd1b6a8a0440429dfe3a356525642f",
+ "placeholder": "",
+ "style": "IPY_MODEL_2fd8c257870c447f93d048e20048d9ae",
+ "value": "Epoch 9: 100%"
+ }
+ },
+ "4fddedb84bdd4344ba7545999cce0902": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "819914d87a4741a8b96fd75b6c174c55": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "acdd1b6a8a0440429dfe3a356525642f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "c35131815ad744ca90087abdede11116": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_d566a65627a54c5a9b1aded8ceab518d",
+ "max": 1.0,
+ "min": 0.0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_e869975f28ac4bbebb02fbe32759546e",
+ "value": 1.0
+ }
+ },
+ "cbe0355f05164620a246f0b89eb0d301": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_43850dc51bf042fbb47c45f7faf21af8",
+ "IPY_MODEL_c35131815ad744ca90087abdede11116",
+ "IPY_MODEL_0729ac1b08fc45debfb674e70712987d"
+ ],
+ "layout": "IPY_MODEL_dd1a2546a7814b70b43a86624d0e5c43"
+ }
+ },
+ "d566a65627a54c5a9b1aded8ceab518d": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": "2",
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "dd1a2546a7814b70b43a86624d0e5c43": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": "inline-flex",
+ "flex": null,
+ "flex_flow": "row wrap",
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": "100%"
+ }
+ },
+ "e869975f28ac4bbebb02fbe32759546e": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ }
+ },
+ "version_major": 2,
+ "version_minor": 0
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/.notebooks/lightning_examples/warp-drive.yaml b/.notebooks/lightning_examples/warp-drive.yaml
new file mode 100644
index 000000000..aa25415ee
--- /dev/null
+++ b/.notebooks/lightning_examples/warp-drive.yaml
@@ -0,0 +1,41 @@
+title: Multi-agent Reinforcement Learning With WarpDrive
+author: Sunil Srinivasa (sunil.srinivasa@salesforce.com), Tian Lan (tian.lan@salesforce.com),
+ Huan Wang (huan.wang@salesforce.com) and Stephan Zheng (stephan.zheng@salesforce.com)
+created: 2022-03-01
+license: BSD 3-Clause "New" or "Revised" License
+tags:
+ - Reinforcement Learning
+ - Multi-agent
+ - GPU
+description: "This notebook introduces multi-agent reinforcement learning (MARL) with
+ WarpDrive (Lan et al. https://arxiv.org/abs/2108.13976). WarpDrive is a flexible,
+ lightweight, and easy-to-use open-source framework that implements end-to-end deep
+ MARL on GPUs. WarpDrive enables orders-of-magnitude speedups compared to CPU-GPU
+ implementations, using the parallelization capability of GPUs and several design
+ choices to minimize communication overhead. WarpDrive also prioritizes user-friendliness
+ - it has utility functions to easily build MARL environments in CUDA and quality-of-life
+ tools to run end-to-end MARL using just a few lines of code, and is compatible with
+ PyTorch.
+
+ WarpDrive includes the following resources: code - https://github.com/salesforce/warp-drive,
+ documentation - http://opensource.salesforce.com/warp-drive/, and white paper -
+ https://arxiv.org/abs/2108.13976."
+requirements:
+ - rl-warp-drive>=1.6.5
+ - ffmpeg-python
+ - torch==1.10.*
+ - torchvision==0.11.*
+ - torchtext==0.11.*
+accelerator:
+ - GPU
+environment:
+ - setuptools==59.5.0
+ - torchmetrics==0.7.2
+ - torch==1.10.2+cu111
+ - ffmpeg-python==0.2.0
+ - ipython==8.1.1
+ - rl-warp-drive==1.6.5
+ - torchtext==0.11.2
+ - torchvision==0.11.3+cu111
+ - pytorch-lightning==1.6.3
+published: "2022-05-18T01:33:06.584850"
diff --git a/.notebooks/sample-template.png b/.notebooks/sample-template.png
new file mode 100644
index 000000000..e778f650d
Binary files /dev/null and b/.notebooks/sample-template.png differ
diff --git a/.notebooks/templates/img-classify.ipynb b/.notebooks/templates/img-classify.ipynb
new file mode 100644
index 000000000..783ca6fdc
--- /dev/null
+++ b/.notebooks/templates/img-classify.ipynb
@@ -0,0 +1,3362 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "a871843b",
+ "metadata": {
+ "papermill": {
+ "duration": 0.005577,
+ "end_time": "2023-01-05T12:54:51.487657",
+ "exception": false,
+ "start_time": "2023-01-05T12:54:51.482080",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "\n",
+ "# Simple image classification with Lightning Flash\n",
+ "\n",
+ "* **Author:** PL team\n",
+ "* **License:** CC BY-SA\n",
+ "* **Generated:** 2023-01-05T13:50:53.263007\n",
+ "\n",
+ "This is a template to show simple image classification case if for some reason accelerator is required.\n",
+ "\n",
+ "\n",
+ "---\n",
+ "Open in [{height=\"20px\" width=\"117px\"}](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/templates/img-classify.ipynb)\n",
+ "\n",
+ "Give us a ⭐ [on Github](https://www.github.com/Lightning-AI/lightning/)\n",
+ "| Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/stable/)\n",
+ "| Join us [on Slack](https://www.pytorchlightning.ai/community)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1f4252b5",
+ "metadata": {
+ "papermill": {
+ "duration": 0.002017,
+ "end_time": "2023-01-05T12:54:51.494790",
+ "exception": false,
+ "start_time": "2023-01-05T12:54:51.492773",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## Setup\n",
+ "This notebook requires some packages besides pytorch-lightning."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "0f45212e",
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "execution": {
+ "iopub.execute_input": "2023-01-05T12:54:51.500626Z",
+ "iopub.status.busy": "2023-01-05T12:54:51.500272Z",
+ "iopub.status.idle": "2023-01-05T12:54:55.843123Z",
+ "shell.execute_reply": "2023-01-05T12:54:55.841656Z"
+ },
+ "id": "LfrJLKPFyhsK",
+ "lines_to_next_cell": 0,
+ "papermill": {
+ "duration": 4.348882,
+ "end_time": "2023-01-05T12:54:55.845928",
+ "exception": false,
+ "start_time": "2023-01-05T12:54:51.497046",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[31mERROR: Cannot install lightning-flash[image]==0.7.0, lightning-flash[image]==0.7.1, lightning-flash[image]==0.7.2, lightning-flash[image]==0.7.3, lightning-flash[image]==0.7.4, lightning-flash[image]==0.7.5, lightning-flash[image]==0.8.0, lightning-flash[image]==0.8.1, lightning-flash[image]==0.8.1.post0 and setuptools==65.6.3 because these package versions have conflicting dependencies.\u001b[0m\u001b[31m\r\n",
+ "\u001b[0m\u001b[31mERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts\u001b[0m\u001b[31m\r\n",
+ "\u001b[0m"
+ ]
+ }
+ ],
+ "source": [
+ "! pip install --quiet \"matplotlib>=3.0\" \"numpy<1.24\" \"lightning-flash[image]>=0.7\" \"pandas>=1.0\" \"seaborn\" \"torchmetrics>=0.7, <0.12\" \"setuptools==65.6.3\" \"pytorch-lightning>=1.4, <1.9\" \"ipython[notebook]>=8.0.0, <8.9.0\" \"torch>=1.8.1, <1.14.0\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "62428f5f",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-05T12:54:55.856737Z",
+ "iopub.status.busy": "2023-01-05T12:54:55.856369Z",
+ "iopub.status.idle": "2023-01-05T12:55:03.503933Z",
+ "shell.execute_reply": "2023-01-05T12:55:03.502614Z"
+ },
+ "papermill": {
+ "duration": 7.655352,
+ "end_time": "2023-01-05T12:55:03.506602",
+ "exception": false,
+ "start_time": "2023-01-05T12:54:55.851250",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.9/dist-packages/pl_bolts/callbacks/data_monitor.py:20: UnderReviewWarning: The feature warn_missing_pkg is currently marked under review. The compatibility with other Lightning projects is not guaranteed and API may change at any time. The API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html\n",
+ " warn_missing_pkg(\"wandb\")\n",
+ "/usr/local/lib/python3.9/dist-packages/pl_bolts/models/self_supervised/amdim/amdim_module.py:35: UnderReviewWarning: The feature generate_power_seq is currently marked under review. The compatibility with other Lightning projects is not guaranteed and API may change at any time. The API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html\n",
+ " \"lr_options\": generate_power_seq(LEARNING_RATE_CIFAR, 11),\n",
+ "/usr/local/lib/python3.9/dist-packages/pl_bolts/models/self_supervised/amdim/amdim_module.py:93: UnderReviewWarning: The feature FeatureMapContrastiveTask is currently marked under review. The compatibility with other Lightning projects is not guaranteed and API may change at any time. The API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html\n",
+ " contrastive_task: Union[FeatureMapContrastiveTask] = FeatureMapContrastiveTask(\"01, 02, 11\"),\n",
+ "/usr/local/lib/python3.9/dist-packages/pl_bolts/losses/self_supervised_learning.py:234: UnderReviewWarning: The feature AmdimNCELoss is currently marked under review. The compatibility with other Lightning projects is not guaranteed and API may change at any time. The API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html\n",
+ " self.nce_loss = AmdimNCELoss(tclip)\n",
+ "/usr/local/lib/python3.9/dist-packages/pl_bolts/datamodules/experience_source.py:18: UnderReviewWarning: The feature warn_missing_pkg is currently marked under review. The compatibility with other Lightning projects is not guaranteed and API may change at any time. The API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html\n",
+ " warn_missing_pkg(\"gym\")\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.9/dist-packages/torchvision/models/_utils.py:252: UserWarning: Accessing the model URLs via the internal dictionary of the module is deprecated since 0.13 and may be removed in the future. Please access them via the appropriate Weights Enum instead.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/tmp/ipykernel_3082/3275308287.py:8: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n",
+ " from IPython.core.display import display\n"
+ ]
+ }
+ ],
+ "source": [
+ "import os\n",
+ "\n",
+ "import flash\n",
+ "import matplotlib.pyplot as plt\n",
+ "import pandas as pd\n",
+ "import seaborn as sn\n",
+ "from flash.image import ImageClassificationData, ImageClassifier\n",
+ "from IPython.core.display import display\n",
+ "from pytorch_lightning.loggers import CSVLogger\n",
+ "\n",
+ "PATH_DATASETS = os.environ.get(\"PATH_DATASETS\", \".\")\n",
+ "# this dataset is automatically downloaded and extracted based on meta link\n",
+ "# this archive includes the one more level - folder with the same name\n",
+ "DATA_HYMENOPLERA = os.path.join(PATH_DATASETS, \"hymenoptera_data\", \"hymenoptera_data\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ecea04c9",
+ "metadata": {
+ "papermill": {
+ "duration": 0.002871,
+ "end_time": "2023-01-05T12:55:03.514965",
+ "exception": false,
+ "start_time": "2023-01-05T12:55:03.512094",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## 1. Create the DataModule"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "2efb0103",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-05T12:55:03.522508Z",
+ "iopub.status.busy": "2023-01-05T12:55:03.521217Z",
+ "iopub.status.idle": "2023-01-05T12:55:03.536817Z",
+ "shell.execute_reply": "2023-01-05T12:55:03.535783Z"
+ },
+ "papermill": {
+ "duration": 0.021881,
+ "end_time": "2023-01-05T12:55:03.539155",
+ "exception": false,
+ "start_time": "2023-01-05T12:55:03.517274",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.9/dist-packages/IPython/core/interactiveshell.py:3442: FutureWarning: Please pass an instantiated object of the `InputTransform` class. Passing the Class and keyword arguments separately has been deprecated since v0.8.0 and will be removed in v0.9.0.\n",
+ " exec(code_obj, self.user_global_ns, self.user_ns)\n",
+ "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/utilities/apply_func.py:31: LightningDeprecationWarning: `pytorch_lightning.utilities.apply_func.apply_to_collection` has been deprecated in v1.8.0 and will be removed in v1.10.0. Please use `lightning_utilities.core.apply_func.apply_to_collection` instead.\n",
+ " rank_zero_deprecation(\n"
+ ]
+ }
+ ],
+ "source": [
+ "datamodule = ImageClassificationData.from_folders(\n",
+ " train_folder=f\"{DATA_HYMENOPLERA}/train/\",\n",
+ " val_folder=f\"{DATA_HYMENOPLERA}/val/\",\n",
+ " batch_size=1024,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1fef4061",
+ "metadata": {
+ "papermill": {
+ "duration": 0.005107,
+ "end_time": "2023-01-05T12:55:03.549464",
+ "exception": false,
+ "start_time": "2023-01-05T12:55:03.544357",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## 2. Build the task"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "9ce86d71",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-05T12:55:03.556038Z",
+ "iopub.status.busy": "2023-01-05T12:55:03.555663Z",
+ "iopub.status.idle": "2023-01-05T12:55:06.308050Z",
+ "shell.execute_reply": "2023-01-05T12:55:06.306853Z"
+ },
+ "papermill": {
+ "duration": 2.758147,
+ "end_time": "2023-01-05T12:55:06.310384",
+ "exception": false,
+ "start_time": "2023-01-05T12:55:03.552237",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Downloading: \"https://download.pytorch.org/models/resnet18-f37072fd.pth\" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "2f025b7cfecd4047a4015ee418077219",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0.00/44.7M [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "model = ImageClassifier(backbone=\"resnet18\", labels=datamodule.labels)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8c225450",
+ "metadata": {
+ "papermill": {
+ "duration": 0.002742,
+ "end_time": "2023-01-05T12:55:06.319422",
+ "exception": false,
+ "start_time": "2023-01-05T12:55:06.316680",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## 3. Create the trainer and finetune the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "e3e2e789",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-05T12:55:06.325924Z",
+ "iopub.status.busy": "2023-01-05T12:55:06.325523Z",
+ "iopub.status.idle": "2023-01-05T12:55:16.133772Z",
+ "shell.execute_reply": "2023-01-05T12:55:16.132799Z"
+ },
+ "papermill": {
+ "duration": 9.814424,
+ "end_time": "2023-01-05T12:55:16.136412",
+ "exception": false,
+ "start_time": "2023-01-05T12:55:06.321988",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:441: LightningDeprecationWarning: Setting `Trainer(gpus=1)` is deprecated in v1.7 and will be removed in v2.0. Please use `Trainer(accelerator='gpu', devices=1)` instead.\n",
+ " rank_zero_deprecation(\n",
+ "GPU available: True (cuda), used: True\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "TPU available: False, using: 0 TPU cores\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "IPU available: False, using: 0 IPUs\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "HPU available: False, using: 0 HPUs\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Missing logger folder: logs/lightning_logs\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2,3]\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ " | Name | Type | Params\n",
+ "-------------------------------------------------\n",
+ "0 | train_metrics | ModuleDict | 0 \n",
+ "1 | val_metrics | ModuleDict | 0 \n",
+ "2 | test_metrics | ModuleDict | 0 \n",
+ "3 | adapter | DefaultAdapter | 11.2 M\n",
+ "-------------------------------------------------\n",
+ "10.6 K Trainable params\n",
+ "11.2 M Non-trainable params\n",
+ "11.2 M Total params\n",
+ "44.710 Total estimated model params size (MB)\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "bfa47955216e4933b85ccd2b4afc4b35",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Sanity Checking: 0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
+ " rank_zero_warn(\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
+ " rank_zero_warn(\n",
+ "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py:1595: PossibleUserWarning: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
+ " rank_zero_warn(\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "29fe61b61a00477aa9164485896b7d73",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Training: 0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0fa34b020e1e4bf98928fd3bcdc4e2e5",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: 0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "90ca39717e1b444eb074af9a45feb342",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: 0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "58e70e73e9dd4d9992d9a2c0aa3a0677",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: 0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "`Trainer.fit` stopped: `max_epochs=3` reached.\n"
+ ]
+ }
+ ],
+ "source": [
+ "logger = CSVLogger(save_dir=\"logs/\")\n",
+ "trainer = flash.Trainer(logger=logger, max_epochs=3, gpus=1)\n",
+ "trainer.finetune(model, datamodule=datamodule, strategy=\"freeze\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "00c5554d",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-05T12:55:16.151546Z",
+ "iopub.status.busy": "2023-01-05T12:55:16.150921Z",
+ "iopub.status.idle": "2023-01-05T12:55:16.518232Z",
+ "shell.execute_reply": "2023-01-05T12:55:16.517123Z"
+ },
+ "papermill": {
+ "duration": 0.374885,
+ "end_time": "2023-01-05T12:55:16.520646",
+ "exception": false,
+ "start_time": "2023-01-05T12:55:16.145761",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "