diff --git a/examples/art/GLM-Classification_final.ipynb b/examples/art/GLM-Classification_final.ipynb new file mode 100644 index 0000000..8e7093e --- /dev/null +++ b/examples/art/GLM-Classification_final.ipynb @@ -0,0 +1,746 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "9bdcf10d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sat Jan 14 16:24:32 2023 \n", + "+-----------------------------------------------------------------------------+\n", + "| NVIDIA-SMI 470.129.06 Driver Version: 470.129.06 CUDA Version: 11.5 |\n", + "|-------------------------------+----------------------+----------------------+\n", + "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", + "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", + "| | | MIG M. |\n", + "|===============================+======================+======================|\n", + "| 0 Tesla V100-SXM2... Off | 00000000:65:01.0 Off | 0 |\n", + "| N/A 36C P0 37W / 300W | 0MiB / 32510MiB | 0% Default |\n", + "| | | N/A |\n", + "+-------------------------------+----------------------+----------------------+\n", + " \n", + "+-----------------------------------------------------------------------------+\n", + "| Processes: |\n", + "| GPU GI CI PID Type Process name GPU Memory |\n", + "| ID ID Usage |\n", + "|=============================================================================|\n", + "| No running processes found |\n", + "+-----------------------------------------------------------------------------+\n" + ] + } + ], + "source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "f252f5e7", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import Dataset, DataLoader\n", + "from scipy.linalg import block_diag\n", + "from typing import List\n", 
+ "from tqdm.notebook import tqdm\n", + "import numpy as np\n", + "from datasets import load_dataset\n", + "from promptsource.templates import DatasetTemplates\n", + "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig, get_linear_schedule_with_warmup\n", + "from sklearn.metrics import accuracy_score\n", + "import matplotlib.pyplot as plt\n", + "import logging\n", + "import sys\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "0eee992e", + "metadata": {}, + "outputs": [], + "source": [ + "# For the classification task, in a Seq2Seq model like GLM, we need to calculate the conditional probability of choices for the given context.\n", + "# Remember to refer to code example (https://github.com/THUDM/GLM#classification) in GLM's repo.\n", + "\n", + "# The `cond_log_prob` could be used for both multiple-choice problems (i.e., classification) and text generation (i.e., summarization).\n", + "def cond_log_prob_single_sample(context, choices):\n", + " \"\"\"\n", + " Compute conditonal probability for one or more continuation/infilling options, single-sample only.\n", + " General solution to all classification/multiple-choice tasks.\n", + " :param context: prompted inputs. For example, \"One plus one equals two, is it correct? Answer: [MASK]\"\n", + " :param choices: classification labels or choices. 
For example, [\"No\", \"Yes\"]\n", + " \"\"\"\n", + " context_id = tokenizer(context)['input_ids']\n", + " probs = []\n", + " for choice in choices:\n", + " choice_id = tokenizer(' ' + choice)['input_ids'][1:-1] # Feature of SentencePiece tokenizer\n", + " input_ids = torch.tensor(context_id + [tokenizer.sop_token_id] + choice_id[:-1], dtype=torch.long)\n", + " attention_mask = torch.tril(torch.ones(len(input_ids), len(input_ids), dtype=torch.long))\n", + " attention_mask[:len(context_id), :len(context_id)] = 1\n", + " mask_position = context_id.index(tokenizer.mask_token_id)\n", + " position_id = torch.cat([torch.arange(len(context_id)), torch.ones(len(choice_id)) * mask_position])\n", + " block_position_id = torch.cat([torch.zeros(len(context_id)), torch.arange(1, 1 + len(choice_id))])\n", + " position_id = torch.stack((position_id, block_position_id), dim=0).long()\n", + " logits = model.forward(input_ids=input_ids.view(1, -1).cuda(),\n", + " attention_mask=attention_mask.unsqueeze(0).unsqueeze(0).cuda(),\n", + " position_ids=position_id.view(1, 2, -1).cuda())['logits']\n", + " logits = F.log_softmax(logits, dim=-1)\n", + " probs.append(logits[0, range(len(context_id), len(context_id) + len(choice_id)), choice_id].sum())\n", + " return torch.stack(probs)\n", + "\n", + "# print(\"Single sample:\", cond_log_prob_single_sample(\"One plus one equals two, is it correct? Answer: [MASK]\", [\"No\", \"Yes\"]))\n", + "\n", + "\n", + "# Forward results by single sample is slow. 
The following codes organize a batch of inputs to speed up training.\n", + "def build_multiple_choice_sample(context, choices):\n", + " context_id = tokenizer(context)['input_ids']\n", + "\n", + " division = len(context_id)\n", + " mask_position = context_id.index(tokenizer.mask_token_id)\n", + "\n", + " token = np.array(context_id, dtype=np.int64)\n", + " attention_mask = [np.ones((division, division), dtype=np.int64)]\n", + " position_id = np.arange(division, dtype=np.int64)\n", + " block_position_id = np.zeros(division, dtype=np.int64)\n", + "\n", + " choice_target_id = []\n", + " choice_id = []\n", + "\n", + " for choice_str in choices:\n", + " choice = np.array(tokenizer(choice_str)['input_ids'][1:-1], dtype=np.int64)\n", + "\n", + " choice_id.append(choice)\n", + " choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))\n", + " attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))\n", + "\n", + " token = np.concatenate((token, [tokenizer.sop_token_id], choice[:-1]))\n", + " position_id = np.concatenate((position_id, [mask_position] * len(choice)))\n", + " block_position_id = np.concatenate((block_position_id, np.arange(1, 1 + len(choice), dtype=np.int64)))\n", + "\n", + " attention_mask = block_diag(*attention_mask)\n", + " attention_mask[division:, :division] = 1\n", + "\n", + " return {\n", + " \"token\": token,\n", + " \"position_id\": np.stack((position_id, block_position_id)),\n", + " \"attention_mask\": attention_mask,\n", + " \"choices\": choice_id,\n", + " \"choice_target_ids\": choice_target_id\n", + " }\n", + "\n", + "\n", + "def pad_batch(tokens, position_ids, attention_mask, max_seq_length):\n", + " pad_length = max_seq_length - len(tokens)\n", + " attention_mask = np.pad(\n", + " attention_mask,\n", + " pad_width=((0, pad_length),),\n", + " mode=\"constant\",\n", + " constant_values=0,\n", + " )\n", + " tokens = np.concatenate((tokens, np.zeros(pad_length, dtype=np.int64)))\n", + " 
position_ids = np.concatenate((position_ids, position_ids[..., -1:].repeat(pad_length, -1)), axis=-1)\n", + " return tokens, position_ids, attention_mask\n", + "\n", + "\n", + "def collate_fn(samples):\n", + " TILE = 16\n", + " length_to_pad = (max(map(lambda spl: len(spl[\"token\"]), samples)) + TILE - 1) // TILE * TILE\n", + "\n", + " token_batch, position_id_batch, attention_mask_batch = [], [], []\n", + " choices_batch, choice_target_ids_batch = [], []\n", + "\n", + " for sample in samples:\n", + " token, position_id, attention_mask = pad_batch(\n", + " sample[\"token\"], sample[\"position_id\"], sample[\"attention_mask\"], length_to_pad\n", + " )\n", + " token_batch.append(token)\n", + " position_id_batch.append(position_id)\n", + " attention_mask_batch.append(attention_mask)\n", + " choices_batch.append(sample[\"choices\"])\n", + " choice_target_ids_batch.append(sample[\"choice_target_ids\"])\n", + "\n", + " return {\n", + " \"tokens\": torch.tensor(np.array(token_batch), dtype=torch.int64),\n", + " \"position_ids\": torch.tensor(np.array(position_id_batch), dtype=torch.int64),\n", + " \"attention_mask\": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64),\n", + " \"choices\": choices_batch,\n", + " \"choice_target_ids\": choice_target_ids_batch,\n", + " }\n", + "\n", + "def cond_log_prob(context: List[str], choices: List[List[str]]) -> List[List[float]]:\n", + " \"\"\"\n", + " Compute conditonal probability for one or more continuation/infilling options.\n", + " :return The log probablity of each option.\n", + " \"\"\"\n", + " if not isinstance(context, list):\n", + " context = [context]\n", + " choices = [choices]\n", + " choices = [[(' ' + choice) for choice in choice_pair] for choice_pair in choices] # Feature of SentencePiece tokenizer\n", + "\n", + " samples = [build_multiple_choice_sample(ctx, ch) for ctx, ch in zip(context, choices)]\n", + "\n", + " batch = collate_fn(samples)\n", + "\n", + " logits = 
model.forward(input_ids=batch['tokens'].cuda(),\n", + " attention_mask=batch['attention_mask'].cuda().unsqueeze(1),\n", + " position_ids=batch['position_ids'].cuda())['logits']\n", + "\n", + " log_probs = []\n", + "\n", + " for output, choices, choice_target_ids in zip(F.log_softmax(logits, dim=-1), batch['choices'], batch['choice_target_ids']):\n", + " log_probs_single = []\n", + " for choice, choice_target_id in zip(choices, choice_target_ids):\n", + " tmp = output[choice_target_id, choice]\n", + " log_probs_single.append(tmp.sum())\n", + " log_probs.append(torch.stack(log_probs_single))\n", + "\n", + " return torch.stack(log_probs)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "abacd39a", + "metadata": {}, + "outputs": [], + "source": [ + "# Load `glm-roberta-large` model and tokenizer\n", + "\n", + "model_type = \"BAAI/glm-roberta-large\"\n", + "tokenizer = AutoTokenizer.from_pretrained(model_type, trust_remote_code=True, revision='main')\n", + "model = AutoModelForSeq2SeqLM.from_pretrained(model_type, trust_remote_code=True, revision='main').cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "69f29f75", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Found cached dataset art (/home/lk/.cache/huggingface/datasets/art/anli/0.1.0/e4b20acfcea873d587a87e817a63c02ce080bce28cd4c322dbd476fd07286b49)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prompt names: ['choose_hypothesis_options', 'choose_hypothesis_believable', 'choose_hypothesis', 'choose_hypothesis_desc', 'choose_hypothesis_likely']\n", + "Choices: [\"{{hypothesis_1| trim('.?!') }}\", \"{{hypothesis_2| trim('.?!') }}\"]\n" + ] + } + ], + "source": [ + "# Loading validation split of 'art' dataset using prompt from promptsource\n", + "\n", + "dataset = load_dataset(\"art\", split=\"validation\")\n", + "art_prompt = DatasetTemplates('art')\n", + "prompt = 
art_prompt[\"choose_hypothesis_desc\"]\n", + "print(\"Prompt names:\", [prompt.get_name() for prompt in art_prompt.templates.values()])\n", + "choices = prompt.answer_choices.split(' ||| ')\n", + "print(\"Choices:\", choices)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "1b496c0f", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6f9124aaea7a460b8a66f1f63d9668f3", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1532 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Plotting Training accuracy and loss with varying learning rates and reduced number of steps\n", + "\n", + "los = results[0]\n", + "los = [los[i] for i in range(0, len(los), 100)]\n", + "acc = results[2]\n", + "acc = [acc[i] for i in range(0, len(acc), 100)]\n", + "plt.ylim(0, 1)\n", + "plt.plot(list(range(len(los))), los) \n", + "plt.plot(list(range(len(acc))), acc)\n", + "plt.title('Training Loss and Accuracy')\n", + "plt.ylabel('Value')\n", + "plt.xlabel('Steps (Reduced)')\n", + "plt.legend(['Loss', 'Accuracy'], loc='upper right')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "f24981c7", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a41d9041d2f642e397c81eba15639029", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1532 [00:00 +Student ID: 2022280073 + +# Task Description + +'art' is a dataset consisting of observation and hypothesis pairs. Each entry of this dataset consists of two observations and two hypotheses. The task is to choose the best hypothesis based on the given observations. There is a label column in the dataset. The value in this column for each entry is either 1 or 2, referring to hypothesis 1 or hypothesis 2. 
The dataset has: + + +# Running Commands: + +Run the cells of the provided notebook from top to bottom to reproduce the results. + +Hyperparameters: + + +# Results: +To evaluate the model's performance, validation was performed both before and after fine-tuning. These are the results for the whole process. + + +# Reference: +``` +@inproceedings {Bhagavatula2020Abductive, + title = {Abductive Commonsense Reasoning}, + author = {Chandra Bhagavatula and Ronan Le Bras and Chaitanya Malaviya and Keisuke Sakaguchi and Ari Holtzman and Hannah Rashkin and Doug Downey and Wen-tau Yih and Yejin Choi}, + booktitle = {International Conference on Learning Representations}, + year = {2020}, + url = {https://openreview.net/forum?id=Byg1v1HKDB} +} +``` \ No newline at end of file diff --git a/examples/art/requirements.txt b/examples/art/requirements.txt new file mode 100644 index 0000000..11753da --- /dev/null +++ b/examples/art/requirements.txt @@ -0,0 +1,11 @@ +python==3.9.13 +pytorch==1.13.1 +cuda-toolkit==11.6.1 +transformers==4.24.0 +scipy==1.5.0 +datasets==2.7.0 +promptsource==0.2.3 +sentencepiece +scikit-learn +tqdm +jupyterlab \ No newline at end of file