"
+ ]
+ },
+ "execution_count": 43,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import wandb\n",
+ "\n",
+ "wandb.login()\n",
+ "wandb.init(project='odqa', # ์คํ๊ธฐ๋ก์ ๊ด๋ฆฌํ ํ๋ก์ ํธ ์ด๋ฆ\n",
+ " entity='nlp15', # ์ฌ์ฉ์๋ช
๋๋ ํ ์ด๋ฆ \n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 44,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "\n",
+ "#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "torch.cuda.is_available()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 45,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "training_args = TrainingArguments(\n",
+ " output_dir=\"outputs\",\n",
+ " do_train=True,\n",
+ " do_eval=True,\n",
+ " per_device_train_batch_size=batch_size,\n",
+ " per_device_eval_batch_size=batch_size,\n",
+ " num_train_epochs=num_train_epochs,\n",
+ " save_total_limit=1, # ์ ์ฅํ ์ฒดํฌํฌ์ธํธ์ ์ต๋ ์\n",
+ " #evaluation_strategy=\"steps\",\n",
+ " #eval_steps=500, # ๋ช ์คํ
๋ง๋ค ํ๊ฐํ ์ง ์ค์ \n",
+ " #logging_steps=500, # ๋ช ์คํ
๋ง๋ค ๋ก๊น
ํ ์ง ์ค์ ,\n",
+ " #load_best_model_at_end=True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer = QuestionAnsweringTrainer(\n",
+ " model=model,\n",
+ " args=training_args,\n",
+ " train_dataset=train_dataset,\n",
+ " eval_dataset=eval_dataset,\n",
+ " eval_examples=val_dataset,\n",
+ " tokenizer=tokenizer,\n",
+ " data_collator=default_data_collator, # ๋ณดํต default\n",
+ " post_process_function=post_processing_function, # function์ input์ผ๋ก ๋ฐ์!\n",
+ " compute_metrics=compute_metrics\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 47,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/conda/lib/python3.10/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
+ " warnings.warn(\n",
+ "***** Running training *****\n",
+ " Num examples = 62641\n",
+ " Num Epochs = 1\n",
+ " Instantaneous batch size per device = 16\n",
+ " Total train batch size (w. parallel, distributed & accumulation) = 16\n",
+ " Gradient Accumulation steps = 1\n",
+ " Total optimization steps = 3916\n",
+ " Number of trainable parameters = 356600834\n",
+ "Automatic Weights & Biases logging enabled, to disable set os.environ[\"WANDB_DISABLED\"] = \"true\"\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [3916/3916 1:43:53, Epoch 1/1]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | Step | \n",
+ " Training Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 500 | \n",
+ " 0.769300 | \n",
+ "
\n",
+ " \n",
+ " | 1000 | \n",
+ " 0.541500 | \n",
+ "
\n",
+ " \n",
+ " | 1500 | \n",
+ " 0.523800 | \n",
+ "
\n",
+ " \n",
+ " | 2000 | \n",
+ " 0.483600 | \n",
+ "
\n",
+ " \n",
+ " | 2500 | \n",
+ " 0.410800 | \n",
+ "
\n",
+ " \n",
+ " | 3000 | \n",
+ " 0.379300 | \n",
+ "
\n",
+ " \n",
+ " | 3500 | \n",
+ " 0.350300 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Saving model checkpoint to outputs/checkpoint-500\n",
+ "Configuration saved in outputs/checkpoint-500/config.json\n",
+ "Model weights saved in outputs/checkpoint-500/pytorch_model.bin\n",
+ "tokenizer config file saved in outputs/checkpoint-500/tokenizer_config.json\n",
+ "Special tokens file saved in outputs/checkpoint-500/special_tokens_map.json\n",
+ "Deleting older checkpoint [outputs/checkpoint-3500] due to args.save_total_limit\n",
+ "Saving model checkpoint to outputs/checkpoint-1000\n",
+ "Configuration saved in outputs/checkpoint-1000/config.json\n",
+ "Model weights saved in outputs/checkpoint-1000/pytorch_model.bin\n",
+ "tokenizer config file saved in outputs/checkpoint-1000/tokenizer_config.json\n",
+ "Special tokens file saved in outputs/checkpoint-1000/special_tokens_map.json\n",
+ "Deleting older checkpoint [outputs/checkpoint-500] due to args.save_total_limit\n",
+ "Saving model checkpoint to outputs/checkpoint-1500\n",
+ "Configuration saved in outputs/checkpoint-1500/config.json\n",
+ "Model weights saved in outputs/checkpoint-1500/pytorch_model.bin\n",
+ "tokenizer config file saved in outputs/checkpoint-1500/tokenizer_config.json\n",
+ "Special tokens file saved in outputs/checkpoint-1500/special_tokens_map.json\n",
+ "Deleting older checkpoint [outputs/checkpoint-1000] due to args.save_total_limit\n",
+ "Saving model checkpoint to outputs/checkpoint-2000\n",
+ "Configuration saved in outputs/checkpoint-2000/config.json\n",
+ "Model weights saved in outputs/checkpoint-2000/pytorch_model.bin\n",
+ "tokenizer config file saved in outputs/checkpoint-2000/tokenizer_config.json\n",
+ "Special tokens file saved in outputs/checkpoint-2000/special_tokens_map.json\n",
+ "Deleting older checkpoint [outputs/checkpoint-1500] due to args.save_total_limit\n",
+ "Saving model checkpoint to outputs/checkpoint-2500\n",
+ "Configuration saved in outputs/checkpoint-2500/config.json\n",
+ "Model weights saved in outputs/checkpoint-2500/pytorch_model.bin\n",
+ "tokenizer config file saved in outputs/checkpoint-2500/tokenizer_config.json\n",
+ "Special tokens file saved in outputs/checkpoint-2500/special_tokens_map.json\n",
+ "Deleting older checkpoint [outputs/checkpoint-2000] due to args.save_total_limit\n",
+ "Saving model checkpoint to outputs/checkpoint-3000\n",
+ "Configuration saved in outputs/checkpoint-3000/config.json\n",
+ "Model weights saved in outputs/checkpoint-3000/pytorch_model.bin\n",
+ "tokenizer config file saved in outputs/checkpoint-3000/tokenizer_config.json\n",
+ "Special tokens file saved in outputs/checkpoint-3000/special_tokens_map.json\n",
+ "Deleting older checkpoint [outputs/checkpoint-2500] due to args.save_total_limit\n",
+ "Saving model checkpoint to outputs/checkpoint-3500\n",
+ "Configuration saved in outputs/checkpoint-3500/config.json\n",
+ "Model weights saved in outputs/checkpoint-3500/pytorch_model.bin\n",
+ "tokenizer config file saved in outputs/checkpoint-3500/tokenizer_config.json\n",
+ "Special tokens file saved in outputs/checkpoint-3500/special_tokens_map.json\n",
+ "Deleting older checkpoint [outputs/checkpoint-3000] due to args.save_total_limit\n",
+ "\n",
+ "\n",
+ "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
+ "\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "train_result = trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "Run history:
| train/epoch | โโโโโ
โโโ |
| train/global_step | โโโโโ
โโโ |
| train/learning_rate | โโโโโโโ |
| train/loss | โโโโโโโ |
| train/total_flos | โ |
| train/train_loss | โ |
| train/train_runtime | โ |
| train/train_samples_per_second | โ |
| train/train_steps_per_second | โ |
Run summary:
| train/epoch | 1 |
| train/global_step | 3916 |
| train/learning_rate | 1e-05 |
| train/loss | 0.3503 |
| train/total_flos | 6.221469142067405e+16 |
| train/train_loss | 0.47893 |
| train/train_runtime | 6235.4731 |
| train/train_samples_per_second | 10.046 |
| train/train_steps_per_second | 0.628 |
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " View run dandy-rain-976 at: https://wandb.ai/nlp15/odqa/runs/ua6lvn4p
View project at: https://wandb.ai/nlp15/odqa
Synced 4 W&B file(s), 0 media file(s), 3 artifact file(s) and 0 other file(s)"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Find logs at: ./wandb/run-20241017_040517-ua6lvn4p/logs"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "wandb.finish()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "CNN_RobertaForQuestionAnswering(\n",
+ " (roberta): RobertaModel(\n",
+ " (embeddings): RobertaEmbeddings(\n",
+ " (word_embeddings): Embedding(32000, 1024, padding_idx=1)\n",
+ " (position_embeddings): Embedding(514, 1024, padding_idx=1)\n",
+ " (token_type_embeddings): Embedding(1, 1024)\n",
+ " (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (encoder): RobertaEncoder(\n",
+ " (layer): ModuleList(\n",
+ " (0-23): 24 x RobertaLayer(\n",
+ " (attention): RobertaAttention(\n",
+ " (self): RobertaSelfAttention(\n",
+ " (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): RobertaSelfOutput(\n",
+ " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): RobertaIntermediate(\n",
+ " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
+ " (intermediate_act_fn): GELUActivation()\n",
+ " )\n",
+ " (output): RobertaOutput(\n",
+ " (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
+ " (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (cnn_block1): CNN_block(\n",
+ " (conv1): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+ " (conv2): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,))\n",
+ " (relu): ReLU()\n",
+ " (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ " (cnn_block2): CNN_block(\n",
+ " (conv1): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+ " (conv2): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,))\n",
+ " (relu): ReLU()\n",
+ " (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ " (cnn_block3): CNN_block(\n",
+ " (conv1): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+ " (conv2): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,))\n",
+ " (relu): ReLU()\n",
+ " (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ " (cnn_block4): CNN_block(\n",
+ " (conv1): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+ " (conv2): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,))\n",
+ " (relu): ReLU()\n",
+ " (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ " (cnn_block5): CNN_block(\n",
+ " (conv1): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+ " (conv2): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,))\n",
+ " (relu): ReLU()\n",
+ " (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ " (qa_outputs): Linear(in_features=1024, out_features=2, bias=True)\n",
+ ")"
+ ]
+ },
+ "execution_count": 49,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## ํ๊น
ํ์ด์ค์ ๋ชจ๋ธ ์
๋ก๋"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 50,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+ "To disable this warning, you can either:\n",
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
+ "/bin/bash: sudo: command not found\n"
+ ]
+ }
+ ],
+ "source": [
+ "!sudo apt-get install git-lfs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Configuration saved in /tmp/tmp_r5qc71i/config.json\n",
+ "Model weights saved in /tmp/tmp_r5qc71i/pytorch_model.bin\n",
+ "Uploading the following files to ssunbear/klue_roberta_large_finetuned_korquad_v1: pytorch_model.bin,config.json\n",
+ "pytorch_model.bin: 100%|โโโโโโโโโโ| 1.43G/1.43G [00:48<00:00, 29.1MB/s] \n",
+ "tokenizer config file saved in /tmp/tmpxrg7mu12/tokenizer_config.json\n",
+ "Special tokens file saved in /tmp/tmpxrg7mu12/special_tokens_map.json\n",
+ "Uploading the following files to ssunbear/klue_roberta_large_finetuned_korquad_v1: tokenizer_config.json,vocab.txt,tokenizer.json,special_tokens_map.json\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "CommitInfo(commit_url='https://huggingface.co/ssunbear/klue_roberta_large_finetuned_korquad_v1/commit/bd3c260ee793ce46ab444055515bf02be74f2a80', commit_message='Upload tokenizer', commit_description='', oid='bd3c260ee793ce46ab444055515bf02be74f2a80', pr_url=None, repo_url=RepoUrl('https://huggingface.co/ssunbear/klue_roberta_large_finetuned_korquad_v1', endpoint='https://huggingface.co', repo_type='model', repo_id='ssunbear/klue_roberta_large_finetuned_korquad_v1'), pr_revision=None, pr_num=None)"
+ ]
+ },
+ "execution_count": 51,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "ename": "",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31mํ์ฌ ์
๋๋ ์ด์ ์
์์ ์ฝ๋๋ฅผ ์คํํ๋ ๋์ Kernel์ด ์ถฉ๋ํ์ต๋๋ค. \n",
+ "\u001b[1;31m์
์ ์ฝ๋๋ฅผ ๊ฒํ ํ์ฌ ๊ฐ๋ฅํ ์ค๋ฅ ์์ธ์ ์๋ณํ์ธ์. \n",
+ "\u001b[1;31m์์ธํ ๋ด์ฉ์ ๋ณด๋ ค๋ฉด ์ฌ๊ธฐ๋ฅผ ํด๋ฆญํ์ธ์. \n",
+ "\u001b[1;31m์์ธํ ๋ด์ฉ์ Jupyter ๋ก๊ทธ๋ฅผ ์ฐธ์กฐํ์ธ์."
+ ]
+ }
+ ],
+ "source": [
+ "from transformers import AutoModel\n",
+ "from transformers import AutoTokenizer\n",
+ "\n",
+ "\n",
+ "\n",
+ "# Huggingface Access Token\n",
+ "ACCESS_TOKEN = # ํ ํฐ์์ด๋ ์
๋ ฅํ์๋ฉด ๋ฉ๋๋ค\n",
+ "\n",
+ "# Upload to Huggingface\n",
+ "model.push_to_hub('klue_roberta_large_finetuned_korquad_v1', use_temp_dir=True, use_auth_token=ACCESS_TOKEN)\n",
+ "tokenizer.push_to_hub('klue_roberta_large_finetuned_korquad_v1', use_temp_dir=True, use_auth_token=ACCESS_TOKEN)\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Method 2\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- Method2. Method1์ mrc_train ๋ฐ์ดํฐ์
์ผ๋ก ํ๋ฒ ๋ fine-tuning(๋ชจ๋ธ ์ฌํธ์ถ) -> ssunbear/klue_roberta_large_finetuned_korquad_v2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Pre-trained ๋ชจ๋ธ ๋ถ๋ฌ์ค๊ธฐ"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 52,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/conda/lib/python3.10/site-packages/huggingface_hub/file_download.py:1142: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
+ " warnings.warn(\n",
+ "loading configuration file config.json from cache at /data/ephemeral/home/.cache/huggingface/hub/models--ssunbear--klue_roberta_large_finetuned_korquad_v1/snapshots/0ebea4e740f7702b26666664e61deaca4f7cb0dc/config.json\n",
+ "Model config RobertaConfig {\n",
+ " \"_name_or_path\": \"ssunbear/klue_roberta_large_finetuned_korquad_v1\",\n",
+ " \"architectures\": [\n",
+ " \"RobertaForQuestionAnswering\"\n",
+ " ],\n",
+ " \"attention_probs_dropout_prob\": 0.1,\n",
+ " \"bos_token_id\": 0,\n",
+ " \"classifier_dropout\": null,\n",
+ " \"eos_token_id\": 2,\n",
+ " \"gradient_checkpointing\": false,\n",
+ " \"hidden_act\": \"gelu\",\n",
+ " \"hidden_dropout_prob\": 0.1,\n",
+ " \"hidden_size\": 1024,\n",
+ " \"initializer_range\": 0.02,\n",
+ " \"intermediate_size\": 4096,\n",
+ " \"layer_norm_eps\": 1e-05,\n",
+ " \"max_position_embeddings\": 514,\n",
+ " \"model_type\": \"roberta\",\n",
+ " \"num_attention_heads\": 16,\n",
+ " \"num_hidden_layers\": 24,\n",
+ " \"pad_token_id\": 1,\n",
+ " \"position_embedding_type\": \"absolute\",\n",
+ " \"tokenizer_class\": \"BertTokenizer\",\n",
+ " \"torch_dtype\": \"float32\",\n",
+ " \"transformers_version\": \"4.24.0\",\n",
+ " \"type_vocab_size\": 1,\n",
+ " \"use_cache\": true,\n",
+ " \"vocab_size\": 32000\n",
+ "}\n",
+ "\n",
+ "loading file vocab.txt from cache at /data/ephemeral/home/.cache/huggingface/hub/models--ssunbear--klue_roberta_large_finetuned_korquad_v1/snapshots/0ebea4e740f7702b26666664e61deaca4f7cb0dc/vocab.txt\n",
+ "loading file tokenizer.json from cache at /data/ephemeral/home/.cache/huggingface/hub/models--ssunbear--klue_roberta_large_finetuned_korquad_v1/snapshots/0ebea4e740f7702b26666664e61deaca4f7cb0dc/tokenizer.json\n",
+ "loading file added_tokens.json from cache at None\n",
+ "loading file special_tokens_map.json from cache at /data/ephemeral/home/.cache/huggingface/hub/models--ssunbear--klue_roberta_large_finetuned_korquad_v1/snapshots/0ebea4e740f7702b26666664e61deaca4f7cb0dc/special_tokens_map.json\n",
+ "loading file tokenizer_config.json from cache at /data/ephemeral/home/.cache/huggingface/hub/models--ssunbear--klue_roberta_large_finetuned_korquad_v1/snapshots/0ebea4e740f7702b26666664e61deaca4f7cb0dc/tokenizer_config.json\n",
+ "loading weights file pytorch_model.bin from cache at /data/ephemeral/home/.cache/huggingface/hub/models--ssunbear--klue_roberta_large_finetuned_korquad_v1/snapshots/0ebea4e740f7702b26666664e61deaca4f7cb0dc/pytorch_model.bin\n",
+ "All model checkpoint weights were used when initializing RobertaForQuestionAnswering.\n",
+ "\n",
+ "All the weights of RobertaForQuestionAnswering were initialized from the model checkpoint at ssunbear/klue_roberta_large_finetuned_korquad_v1.\n",
+ "If your task is similar to the task the model of the checkpoint was trained on, you can already use RobertaForQuestionAnswering for predictions without further training.\n"
+ ]
+ }
+ ],
+ "source": [
+ "from transformers import (\n",
+ " AutoConfig,\n",
+ " AutoModelForQuestionAnswering,\n",
+ " AutoTokenizer\n",
+ ")\n",
+ "\n",
+ "model_name = \"ssunbear/klue_roberta_large_finetuned_korquad_v1\"\n",
+ "\n",
+ "config = AutoConfig.from_pretrained(\n",
+ " model_name\n",
+ ")\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\n",
+ " model_name,\n",
+ " use_fast=True\n",
+ ")\n",
+ "model = AutoModelForQuestionAnswering.from_pretrained(\n",
+ " model_name,\n",
+ " config=config)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "vscode": {
+ "languageId": "bat"
+ }
+ },
+ "source": [
+ "## Train ๋ฐ์ดํฐ์
์ ์ฒ๋ฆฌ"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 53,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "max_seq_length = 512 # ์ง๋ฌธ๊ณผ ์ปจํ
์คํธ, special token์ ํฉํ ๋ฌธ์์ด์ ์ต๋ ๊ธธ์ด (์ผ์ ๊ฐ์๊ฐ ๋์ด๊ฐ์ง ์๋๋ก!)\n",
+ "pad_to_max_length = False\n",
+ "doc_stride = 128 # ์ปจํ
์คํธ๊ฐ ๋๋ฌด ๊ธธ์ด์ ๋๋ด์ ๋ ์ค๋ฒ๋ฉ๋๋ ์ํ์ค ๊ธธ์ด, ๋ฌธ์ 2๊ฐ๋ก ์ชผ๊ฐ๊ณ , 128๊ฐ ์ํ์ค๊ฐ ๊ฒน์น๋๋ก\n",
+ "preprocessing_num_workers = None\n",
+ "batch_size = 16\n",
+ "num_train_epochs = 4\n",
+ "n_best_size = 20\n",
+ "max_answer_length = 30\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 54,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def prepare_train_features(examples): # examples: ๋ฐ์ดํฐ์
row..\n",
+ " # ์ฃผ์ด์ง ํ
์คํธ๋ฅผ ํ ํฌ๋์ด์ง ํ๋ค. ์ด ๋ ํ
์คํธ์ ๊ธธ์ด๊ฐ max_seq_length๋ฅผ ๋์ผ๋ฉด stride๋งํผ ์ฌ๋ผ์ด๋ฉํ๋ฉฐ ์ฌ๋ฌ ๊ฐ๋ก ์ชผ๊ฐฌ.\n",
+ " # ์ฆ, ํ๋์ example์์ ์ผ๋ถ๋ถ์ด ๊ฒน์น๋ ์ฌ๋ฌ sequence(feature)๊ฐ ์๊ธธ ์ ์์.\n",
+ " \n",
+ " question_column_name = \"question\" if \"question\" in column_names else column_names[0]\n",
+ " context_column_name = \"context\" if \"context\" in column_names else column_names[1]\n",
+ " answer_column_name = \"answers\" if \"answers\" in column_names else column_names[2]\n",
+ "\n",
+ " # Padding์ ๋ํ ์ต์
์ ์ค์ ํฉ๋๋ค.\n",
+ " tokenized_examples = tokenizer(\n",
+ " examples[\"question\"],\n",
+ " examples[\"context\"],\n",
+ " truncation=\"only_second\", # max_seq_length๊น์ง truncateํ๋ค. pair์ ๋๋ฒ์งธ ํํธ(context)๋ง ์๋ผ๋.\n",
+ " max_length=max_seq_length,\n",
+ " stride=doc_stride,\n",
+ " return_overflowing_tokens=True, # ๊ธธ์ด๋ฅผ ๋์ด๊ฐ๋ ํ ํฐ๋ค์ ๋ฐํํ ๊ฒ์ธ์ง\n",
+ " return_offsets_mapping=True, # ๊ฐ ํ ํฐ์ ๋ํด (char_start, char_end) ์ ๋ณด๋ฅผ ๋ฐํํ ๊ฒ์ธ์ง\n",
+ " padding=\"max_length\", return_token_type_ids=False\n",
+ " )\n",
+ "\n",
+ " # example ํ๋๊ฐ ์ฌ๋ฌ sequence์ ๋์ํ๋ ๊ฒฝ์ฐ๋ฅผ ์ํด ๋งคํ์ด ํ์ํจ.\n",
+ " overflow_to_sample_mapping = tokenized_examples.pop(\"overflow_to_sample_mapping\")\n",
+ " # offset_mappings์ผ๋ก ํ ํฐ์ด ์๋ณธ context ๋ด ๋ช๋ฒ์งธ ๊ธ์๋ถํฐ ๋ช๋ฒ์งธ ๊ธ์๊น์ง ํด๋นํ๋์ง ์ ์ ์์.\n",
+ " offset_mapping = tokenized_examples.pop(\"offset_mapping\")\n",
+ "\n",
+ " # ์ ๋ต์ง๋ฅผ ๋ง๋ค๊ธฐ ์ํ ๋ฆฌ์คํธ\n",
+ " tokenized_examples[\"start_positions\"] = []\n",
+ " tokenized_examples[\"end_positions\"] = []\n",
+ "\n",
+ " for i, offsets in enumerate(offset_mapping):\n",
+ " input_ids = tokenized_examples[\"input_ids\"][i]\n",
+ " cls_index = input_ids.index(tokenizer.cls_token_id)\n",
+ "\n",
+ " # ํด๋น example์ ํด๋นํ๋ sequence๋ฅผ ์ฐพ์.\n",
+ " sequence_ids = tokenized_examples.sequence_ids(i)\n",
+ "\n",
+ " # sequence๊ฐ ์ํ๋ example์ ์ฐพ๋๋ค.\n",
+ " example_index = overflow_to_sample_mapping[i]\n",
+ " answers = examples[\"answers\"][example_index]\n",
+ "\n",
+ " # ํ
์คํธ์์ answer์ ์์์ , ๋์ \n",
+ " answer_start_offset = answers[\"answer_start\"][0]\n",
+ " answer_end_offset = answer_start_offset + len(answers[\"text\"][0])\n",
+ "\n",
+ " # ํ
์คํธ์์ ํ์ฌ span์ ์์ ํ ํฐ ์ธ๋ฑ์ค\n",
+ " token_start_index = 0\n",
+ " while sequence_ids[token_start_index] != 1:\n",
+ " token_start_index += 1\n",
+ "\n",
+ " # ํ
์คํธ์์ ํ์ฌ span ๋ ํ ํฐ ์ธ๋ฑ์ค\n",
+ " token_end_index = len(input_ids) - 1\n",
+ " while sequence_ids[token_end_index] != 1:\n",
+ " token_end_index -= 1\n",
+ "\n",
+ " # answer๊ฐ ํ์ฌ span์ ๋ฒ์ด๋ฌ๋์ง ์ฒดํฌ\n",
+ " if not (offsets[token_start_index][0] <= answer_start_offset and offsets[token_end_index][1] >= answer_end_offset):\n",
+ " tokenized_examples[\"start_positions\"].append(cls_index)\n",
+ " tokenized_examples[\"end_positions\"].append(cls_index)\n",
+ " else:\n",
+ " # token_start_index์ token_end_index๋ฅผ answer์ ์์์ ๊ณผ ๋์ ์ผ๋ก ์ฎ๊น\n",
+ " while token_start_index < len(offsets) and offsets[token_start_index][0] <= answer_start_offset:\n",
+ " token_start_index += 1\n",
+ " tokenized_examples[\"start_positions\"].append(token_start_index - 1)\n",
+ " while offsets[token_end_index][1] >= answer_end_offset:\n",
+ " token_end_index -= 1\n",
+ " tokenized_examples[\"end_positions\"].append(token_end_index + 1)\n",
+ "\n",
+ " return tokenized_examples"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 55,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Map: 100%|โโโโโโโโโโ| 3952/3952 [00:02<00:00, 1426.98 examples/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "column_names = mrc_train_dataset.column_names\n",
+ "train_dataset = mrc_train_dataset.map(\n",
+ " prepare_train_features,\n",
+ " batched=True,\n",
+ " num_proc=preprocessing_num_workers,\n",
+ " remove_columns=column_names,\n",
+ " load_from_cache_file=True,\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 82,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Dataset({\n",
+ " features: ['title', 'context', 'question', 'id', 'answers', 'document_id', '__index_level_0__'],\n",
+ " num_rows: 240\n",
+ "})"
+ ]
+ },
+ "execution_count": 82,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from datasets import load_from_disk\n",
+ "\n",
+ "# ๋ฐ์ดํฐ์
๋ก๋\n",
+ "mrc_validation_dataset_path = \"/data/ephemeral/home/level2-mrc-nlp-15/data/train_dataset/validation\" # ์ค์ ๋ฐ์ดํฐ์
๊ฒฝ๋ก๋ก ์์ \n",
+ "mrc_validation_dataset = load_from_disk(mrc_validation_dataset_path)\n",
+ "mrc_validation_dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 83,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# korquad ๋ฐ์ดํฐ์
์ด๋ ํ์ ๋๊ฐ์ด ๋ง๋ค์ด์ฃผ๊ธฐ\n",
+ "id_list0 = []\n",
+ "title_list0 = []\n",
+ "context_list0 = []\n",
+ "question_list0 = []\n",
+ "answers_list0 = []"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 84,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[]"
+ ]
+ },
+ "execution_count": 84,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "id_list0"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 85,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "for index, row in pd.DataFrame(mrc_validation_dataset).iterrows():\n",
+ " id_list0.append(row['id'])\n",
+ " title_list0.append(str(row['title']))\n",
+ " context_list0.append(str(row['context']))\n",
+ " question_list0.append(str(row['question']))\n",
+ " answers_list0.append(row['answers'])\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 88,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mrc_validation_dataset = {\n",
+ " \"id\" : id_list0,\n",
+ " \"title\" : title_list0,\n",
+ " \"context\" : context_list0,\n",
+ " \"question\" : question_list0,\n",
+ " \"answers\" : answers_list0,}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 91,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Dataset({\n",
+ " features: ['id', 'title', 'context', 'question', 'answers'],\n",
+ " num_rows: 240\n",
+ "})"
+ ]
+ },
+ "execution_count": 91,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from datasets import Dataset\n",
+ "\n",
+ "mrc_validation_dataset= Dataset.from_dict(mrc_validation_dataset)\n",
+ "mrc_validation_dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 56,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def prepare_validation_features(examples):\n",
+ " tokenized_examples = tokenizer(\n",
+ " examples['question'],\n",
+ " examples['context'],\n",
+ " truncation=\"only_second\",\n",
+ " max_length=max_seq_length,\n",
+ " stride=doc_stride,\n",
+ " return_overflowing_tokens=True,\n",
+ " return_offsets_mapping=True,\n",
+ " padding=\"max_length\",\n",
+ " )\n",
+ "\n",
+ " sample_mapping = tokenized_examples.pop(\"overflow_to_sample_mapping\")\n",
+ "\n",
+ " tokenized_examples[\"example_id\"] = []\n",
+ "\n",
+ " for i in range(len(tokenized_examples[\"input_ids\"])):\n",
+ " sequence_ids = tokenized_examples.sequence_ids(i)\n",
+ " context_index = 1\n",
+ "\n",
+ " sample_index = sample_mapping[i]\n",
+ " tokenized_examples[\"example_id\"].append(examples[\"id\"][sample_index])\n",
+ "\n",
+ " tokenized_examples[\"offset_mapping\"][i] = [\n",
+ " (o if sequence_ids[k] == context_index else None)\n",
+ " for k, o in enumerate(tokenized_examples[\"offset_mapping\"][i])\n",
+ " ]\n",
+ "\n",
+ " return tokenized_examples"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 92,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Map: 100%|โโโโโโโโโโ| 240/240 [00:00<00:00, 846.63 examples/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "eval_dataset = mrc_validation_dataset.map(\n",
+ " prepare_validation_features,\n",
+ " batched=True,\n",
+ " num_proc=preprocessing_num_workers,\n",
+ " remove_columns=column_names,\n",
+ " load_from_cache_file=True,\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Question Answering Class ์ ์"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 93,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# default_data_collator: ์ฌ๋ฌ๊ฐ example๋ค์ collatorํด์ฃผ๋ ์ญํ ,\n",
+ "# TrainingArguments : ํ๋ฒ์ training arguments๋ค์ ํฉ์ณ์ ์ฃผ๋..!\n",
+ "from transformers import default_data_collator, TrainingArguments, EvalPrediction"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 94,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# coding=utf-8\n",
+ "# Copyright 2020 The HuggingFace Team All rights reserved.\n",
+ "#\n",
+ "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# http://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License.\n",
+ "\"\"\"\n",
+ "Question-Answering task์ ๊ด๋ จ๋ 'Trainer'์ subclass ์ฝ๋ ์
๋๋ค.\n",
+ "\"\"\"\n",
+ "\n",
+ "from transformers import Trainer, is_datasets_available, is_torch_tpu_available\n",
+ "from transformers.trainer_utils import PredictionOutput\n",
+ "\n",
+ "if is_datasets_available():\n",
+ " import datasets\n",
+ "\n",
+ "# Huggingface์ Trainer๋ฅผ ์์๋ฐ์ QuestionAnswering์ ์ํ Trainer๋ฅผ ์์ฑํฉ๋๋ค.\n",
+ "class QuestionAnsweringTrainer(Trainer):\n",
+ " def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs):\n",
+ " super().__init__(*args, **kwargs)\n",
+ " self.eval_examples = eval_examples\n",
+ " self.post_process_function = post_process_function\n",
+ "\n",
+ " def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None):\n",
+ " eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset\n",
+ " eval_dataloader = self.get_eval_dataloader(eval_dataset)\n",
+ " eval_examples = self.eval_examples if eval_examples is None else eval_examples\n",
+ "\n",
+ " # ์ผ์์ ์ผ๋ก metric computation๋ฅผ ๋ถ๊ฐ๋ฅํ๊ฒ ํ ์ํ์ด๋ฉฐ, ํด๋น ์ฝ๋์์๋ loop ๋ด์์ metric ๊ณ์ฐ์ ์ํํฉ๋๋ค.\n",
+ " compute_metrics = self.compute_metrics\n",
+ " self.compute_metrics = None\n",
+ " try:\n",
+ " output = self.prediction_loop(\n",
+ " eval_dataloader,\n",
+ " description=\"Evaluation\",\n",
+ " # metric์ด ์์ผ๋ฉด ์์ธก๊ฐ์ ๋ชจ์ผ๋ ์ด์ ๊ฐ ์์ผ๋ฏ๋ก ์๋์ ์ฝ๋๋ฅผ ๋ฐ๋ฅด๊ฒ ๋ฉ๋๋ค.\n",
+ " # self.args.prediction_loss_only\n",
+ " prediction_loss_only=True if compute_metrics is None else None,\n",
+ " ignore_keys=ignore_keys,\n",
+ " )\n",
+ " finally:\n",
+ " self.compute_metrics = compute_metrics\n",
+ "\n",
+ " if isinstance(eval_dataset, datasets.Dataset):\n",
+ " eval_dataset.set_format(\n",
+ " type=eval_dataset.format[\"type\"],\n",
+ " columns=list(eval_dataset.features.keys()),\n",
+ " )\n",
+ "\n",
+ " if self.post_process_function is not None and self.compute_metrics is not None:\n",
+ " eval_preds = self.post_process_function(\n",
+ " eval_examples, eval_dataset, output.predictions, self.args\n",
+ " )\n",
+ " metrics = self.compute_metrics(eval_preds)\n",
+ "\n",
+ " self.log(metrics)\n",
+ " else:\n",
+ " metrics = {}\n",
+ "\n",
+ " self.control = self.callback_handler.on_evaluate(\n",
+ " self.args, self.state, self.control, metrics\n",
+ " )\n",
+ " return metrics\n",
+ "\n",
+ " def predict(self, test_dataset, test_examples, ignore_keys=None):\n",
+ " test_dataloader = self.get_test_dataloader(test_dataset)\n",
+ "\n",
+ " # ์ผ์์ ์ผ๋ก metric computation๋ฅผ ๋ถ๊ฐ๋ฅํ๊ฒ ํ ์ํ์ด๋ฉฐ, ํด๋น ์ฝ๋์์๋ loop ๋ด์์ metric ๊ณ์ฐ์ ์ํํฉ๋๋ค.\n",
+ " # evaluate ํจ์์ ๋์ผํ๊ฒ ๊ตฌ์ฑ๋์ด์์ต๋๋ค\n",
+ " compute_metrics = self.compute_metrics\n",
+ " self.compute_metrics = None\n",
+ " try:\n",
+ " output = self.prediction_loop(\n",
+ " test_dataloader,\n",
+ " description=\"Evaluation\",\n",
+ " # metric์ด ์์ผ๋ฉด ์์ธก๊ฐ์ ๋ชจ์ผ๋ ์ด์ ๊ฐ ์์ผ๋ฏ๋ก ์๋์ ์ฝ๋๋ฅผ ๋ฐ๋ฅด๊ฒ ๋ฉ๋๋ค.\n",
+ " # self.args.prediction_loss_only\n",
+ " prediction_loss_only=True if compute_metrics is None else None,\n",
+ " ignore_keys=ignore_keys,\n",
+ " )\n",
+ " finally:\n",
+ " self.compute_metrics = compute_metrics\n",
+ "\n",
+ " if self.post_process_function is None or self.compute_metrics is None:\n",
+ " return output\n",
+ "\n",
+ " if isinstance(test_dataset, datasets.Dataset):\n",
+ " test_dataset.set_format(\n",
+ " type=test_dataset.format[\"type\"],\n",
+ " columns=list(test_dataset.features.keys()),\n",
+ " )\n",
+ "\n",
+ " predictions = self.post_process_function(\n",
+ " test_examples, test_dataset, output.predictions, self.args\n",
+ " )\n",
+ " return predictions\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## ํ์ฒ๋ฆฌ ํด๋์ค ์ ์"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 95,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# coding=utf-8\n",
+ "# Copyright 2020 The HuggingFace Team All rights reserved.\n",
+ "#\n",
+ "# Licensed under the Apache License, Version 2.0 (the 'License');\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# http://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an 'AS IS' BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License.\n",
+ "\"\"\"\n",
+ "Pre-processing\n",
+ "Post-processing utilities for question answering.\n",
+ "\"\"\"\n",
+ "import collections\n",
+ "import json\n",
+ "import logging\n",
+ "import os\n",
+ "import random\n",
+ "from typing import Any, Optional, Tuple\n",
+ "\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "from arguments import DataTrainingArguments, ModelArguments\n",
+ "from datasets import DatasetDict\n",
+ "from tqdm.auto import tqdm\n",
+ "from transformers import PreTrainedTokenizerFast, TrainingArguments, is_torch_available\n",
+ "from transformers.trainer_utils import get_last_checkpoint\n",
+ "\n",
+ "#from utils.datetime_helper import get_seoul_datetime_str\n",
+ "\n",
+ "logger = logging.getLogger(__name__)\n",
+ "\n",
+ "\n",
+ "def set_seed(seed: int = 2024):\n",
+ " \"\"\"\n",
+ " seed ๊ณ ์ ํ๋ ํจ์ (random, numpy, torch)\n",
+ "\n",
+ " Args:\n",
+ " seed (:obj:`int`): The seed to set.\n",
+ " \"\"\"\n",
+ " random.seed(seed)\n",
+ " np.random.seed(seed)\n",
+ " if is_torch_available():\n",
+ " torch.manual_seed(seed)\n",
+ " torch.cuda.manual_seed(seed)\n",
+ " torch.cuda.manual_seed_all(seed) # if use multi-GPU\n",
+ " torch.backends.cudnn.deterministic = True\n",
+ " torch.backends.cudnn.benchmark = False\n",
+ "\n",
+ "\n",
+ "def postprocess_qa_predictions(\n",
+ " examples,\n",
+ " features,\n",
+ " predictions: Tuple[np.ndarray, np.ndarray],\n",
+ " version_2_with_negative: bool = False,\n",
+ " n_best_size: int = 20,\n",
+ " max_answer_length: int = 30,\n",
+ " null_score_diff_threshold: float = 0.0,\n",
+ " output_dir: Optional[str] = None,\n",
+ " prefix: Optional[str] = None,\n",
+ " is_world_process_zero: bool = True,\n",
+ "):\n",
+ " \"\"\"\n",
+ " Post-processes : qa model์ prediction ๊ฐ์ ํ์ฒ๋ฆฌํ๋ ํจ์\n",
+ " ๋ชจ๋ธ์ start logit๊ณผ end logit์ ๋ฐํํ๊ธฐ ๋๋ฌธ์, ์ด๋ฅผ ๊ธฐ๋ฐ์ผ๋ก original text๋ก ๋ณ๊ฒฝํ๋ ํ์ฒ๋ฆฌ๊ฐ ํ์ํจ\n",
+ "\n",
+ " Args:\n",
+ " examples: ์ ์ฒ๋ฆฌ ๋์ง ์์ ๋ฐ์ดํฐ์
(see the main script for more information).\n",
+ " features: ์ ์ฒ๋ฆฌ๊ฐ ์งํ๋ ๋ฐ์ดํฐ์
(see the main script for more information).\n",
+ " predictions (:obj:`Tuple[np.ndarray, np.ndarray]`):\n",
+ " ๋ชจ๋ธ์ ์์ธก๊ฐ :start logits๊ณผ the end logits์ ๋ํ๋ด๋ two arrays ์ฒซ๋ฒ์งธ ์ฐจ์์ :obj:`features`์ element์ ๊ฐฏ์๊ฐ ๋ง์์ผํจ.\n",
+ " version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`):\n",
+ " ์ ๋ต์ด ์๋ ๋ฐ์ดํฐ์
์ด ํฌํจ๋์ด์๋์ง ์ฌ๋ถ๋ฅผ ๋ํ๋\n",
+ " n_best_size (:obj:`int`, `optional`, defaults to 20):\n",
+ " ๋ต๋ณ์ ์ฐพ์ ๋ ์์ฑํ n-best prediction ์ด ๊ฐ์\n",
+ " max_answer_length (:obj:`int`, `optional`, defaults to 30):\n",
+ " ์์ฑํ ์ ์๋ ๋ต๋ณ์ ์ต๋ ๊ธธ์ด\n",
+ " null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0):\n",
+ " null ๋ต๋ณ์ ์ ํํ๋ ๋ฐ ์ฌ์ฉ๋๋ threshold\n",
+ " : if the best answer has a score that is less than the score of\n",
+ " the null answer minus this threshold, the null answer is selected for this example (note that the score of\n",
+ " the null answer for an example giving several features is the minimum of the scores for the null answer on\n",
+ " each feature: all features must be aligned on the fact they `want` to predict a null answer).\n",
+ " Only useful when :obj:`version_2_with_negative` is :obj:`True`.\n",
+ " output_dir (:obj:`str`, `optional`):\n",
+ " ์๋์ ๊ฐ์ด ์ ์ฅ๋๋ ๊ฒฝ๋ก\n",
+ " dictionary : predictions, n_best predictions (with their scores and logits) if:obj:`version_2_with_negative=True`,\n",
+ " dictionary : the scores differences between best and null answers\n",
+ " prefix (:obj:`str`, `optional`):\n",
+ " dictionary์ `prefix`๊ฐ ํฌํจ๋์ด ์ ์ฅ๋จ\n",
+ " is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):\n",
+ " ์ด ํ๋ก์ธ์ค๊ฐ main process์ธ์ง ์ฌ๋ถ(logging/save๋ฅผ ์ํํด์ผ ํ๋์ง ์ฌ๋ถ๋ฅผ ๊ฒฐ์ ํ๋ ๋ฐ ์ฌ์ฉ๋จ)\n",
+ " \"\"\"\n",
+ " assert (\n",
+ " len(predictions) == 2\n",
+ " ), \"`predictions` should be a tuple with two elements (start_logits, end_logits).\"\n",
+ " all_start_logits, all_end_logits = predictions\n",
+ "\n",
+ " assert len(predictions[0]) == len(\n",
+ " features\n",
+ " ), f\"Got {len(predictions[0])} predictions and {len(features)} features.\"\n",
+ "\n",
+ " # example๊ณผ mapping๋๋ feature ์์ฑ\n",
+ " example_id_to_index = {k: i for i, k in enumerate(examples[\"id\"])}\n",
+ " features_per_example = collections.defaultdict(list)\n",
+ " for i, feature in enumerate(features):\n",
+ " features_per_example[example_id_to_index[feature[\"example_id\"]]].append(i)\n",
+ "\n",
+ " # prediction, nbest์ ํด๋นํ๋ OrderedDict ์์ฑํฉ๋๋ค.\n",
+ " all_predictions = collections.OrderedDict()\n",
+ " all_nbest_json = collections.OrderedDict()\n",
+ " if version_2_with_negative:\n",
+ " scores_diff_json = collections.OrderedDict()\n",
+ "\n",
+ " # Logging.\n",
+ " logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN)\n",
+ " logger.info(\n",
+ " f\"Post-processing {len(examples)} example predictions split into {len(features)} features.\"\n",
+ " )\n",
+ "\n",
+ " # ์ ์ฒด example๋ค์ ๋ํ main Loop\n",
+ " for example_index, example in enumerate(tqdm(examples)):\n",
+ " # ํด๋นํ๋ ํ์ฌ example index\n",
+ " feature_indices = features_per_example[example_index]\n",
+ "\n",
+ " min_null_prediction = None\n",
+ " prelim_predictions = []\n",
+ "\n",
+ " # ํ์ฌ example์ ๋ํ ๋ชจ๋ feature ์์ฑํฉ๋๋ค.\n",
+ " for feature_index in feature_indices:\n",
+ " # ๊ฐ featureure์ ๋ํ ๋ชจ๋ prediction์ ๊ฐ์ ธ์ต๋๋ค.\n",
+ " start_logits = all_start_logits[feature_index]\n",
+ " end_logits = all_end_logits[feature_index]\n",
+ " # logit๊ณผ original context์ logit์ mappingํฉ๋๋ค.\n",
+ " offset_mapping = features[feature_index][\"offset_mapping\"]\n",
+ " # Optional : `token_is_max_context`, ์ ๊ณต๋๋ ๊ฒฝ์ฐ ํ์ฌ ๊ธฐ๋ฅ์์ ์ฌ์ฉํ ์ ์๋ max context๊ฐ ์๋ answer๋ฅผ ์ ๊ฑฐํฉ๋๋ค\n",
+ " token_is_max_context = features[feature_index].get(\n",
+ " \"token_is_max_context\", None\n",
+ " )\n",
+ "\n",
+ " # minimum null prediction์ ์
๋ฐ์ดํธ ํฉ๋๋ค.\n",
+ " feature_null_score = start_logits[0] + end_logits[0]\n",
+ " if (\n",
+ " min_null_prediction is None\n",
+ " or min_null_prediction[\"score\"] > feature_null_score\n",
+ " ):\n",
+ " min_null_prediction = {\n",
+ " \"offsets\": (0, 0),\n",
+ " \"score\": feature_null_score,\n",
+ " \"start_logit\": start_logits[0],\n",
+ " \"end_logit\": end_logits[0],\n",
+ " }\n",
+ "\n",
+ " # `n_best_size`๋ณด๋ค ํฐ start and end logits์ ์ดํด๋ด
๋๋ค.\n",
+ " start_indexes = np.argsort(start_logits)[\n",
+ " -1 : -n_best_size - 1 : -1\n",
+ " ].tolist()\n",
+ "\n",
+ " end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()\n",
+ "\n",
+ " for start_index in start_indexes:\n",
+ " for end_index in end_indexes:\n",
+ " # out-of-scope answers๋ ๊ณ ๋ คํ์ง ์์ต๋๋ค.\n",
+ " if (\n",
+ " start_index >= len(offset_mapping)\n",
+ " or end_index >= len(offset_mapping)\n",
+ " or offset_mapping[start_index] is None\n",
+ " or offset_mapping[end_index] is None\n",
+ " ):\n",
+ " continue\n",
+ " # ๊ธธ์ด๊ฐ < 0 ๋๋ > max_answer_length์ธ answer๋ ๊ณ ๋ คํ์ง ์์ต๋๋ค.\n",
+ " if (\n",
+ " end_index < start_index\n",
+ " or end_index - start_index + 1 > max_answer_length\n",
+ " ):\n",
+ " continue\n",
+ " # ์ต๋ context๊ฐ ์๋ answer๋ ๊ณ ๋ คํ์ง ์์ต๋๋ค.\n",
+ " if (\n",
+ " token_is_max_context is not None\n",
+ " and not token_is_max_context.get(str(start_index), False)\n",
+ " ):\n",
+ " continue\n",
+ " prelim_predictions.append(\n",
+ " {\n",
+ " \"offsets\": (\n",
+ " offset_mapping[start_index][0],\n",
+ " offset_mapping[end_index][1],\n",
+ " ),\n",
+ " \"score\": start_logits[start_index] + end_logits[end_index],\n",
+ " \"start_logit\": start_logits[start_index],\n",
+ " \"end_logit\": end_logits[end_index],\n",
+ " }\n",
+ " )\n",
+ "\n",
+ " if version_2_with_negative:\n",
+ " # minimum null prediction์ ์ถ๊ฐํฉ๋๋ค.\n",
+ " prelim_predictions.append(min_null_prediction)\n",
+ " null_score = min_null_prediction[\"score\"]\n",
+ "\n",
+ " # ๊ฐ์ฅ ์ข์ `n_best_size` predictions๋ง ์ ์งํฉ๋๋ค.\n",
+ " predictions = sorted(\n",
+ " prelim_predictions, key=lambda x: x[\"score\"], reverse=True\n",
+ " )[:n_best_size]\n",
+ "\n",
+ " # ๋ฎ์ ์ ์๋ก ์ธํด ์ ๊ฑฐ๋ ๊ฒฝ์ฐ minimum null prediction์ ๋ค์ ์ถ๊ฐํฉ๋๋ค.\n",
+ " if version_2_with_negative and not any(\n",
+ " p[\"offsets\"] == (0, 0) for p in predictions\n",
+ " ):\n",
+ " predictions.append(min_null_prediction)\n",
+ "\n",
+ " # offset์ ์ฌ์ฉํ์ฌ original context์์ answer text๋ฅผ ์์งํฉ๋๋ค.\n",
+ " context = example[\"context\"]\n",
+ " for pred in predictions:\n",
+ " offsets = pred.pop(\"offsets\")\n",
+ " pred[\"text\"] = context[offsets[0] : offsets[1]]\n",
+ "\n",
+ " # rare edge case์๋ null์ด ์๋ ์์ธก์ด ํ๋๋ ์์ผ๋ฉฐ failure๋ฅผ ํผํ๊ธฐ ์ํด fake prediction์ ๋ง๋ญ๋๋ค.\n",
+ " if len(predictions) == 0 or (\n",
+ " len(predictions) == 1 and predictions[0][\"text\"] == \"\"\n",
+ " ):\n",
+ "\n",
+ " predictions.insert(\n",
+ " 0, {\"text\": \"empty\", \"start_logit\": 0.0, \"end_logit\": 0.0, \"score\": 0.0}\n",
+ " )\n",
+ "\n",
+ " # ๋ชจ๋ ์ ์์ ์ํํธ๋งฅ์ค๋ฅผ ๊ณ์ฐํฉ๋๋ค(we do it with numpy to stay independent from torch/tf in this file, using the LogSumExp trick).\n",
+ " scores = np.array([pred.pop(\"score\") for pred in predictions])\n",
+ " exp_scores = np.exp(scores - np.max(scores))\n",
+ " probs = exp_scores / exp_scores.sum()\n",
+ "\n",
+ " # ์์ธก๊ฐ์ ํ๋ฅ ์ ํฌํจํฉ๋๋ค.\n",
+ " for prob, pred in zip(probs, predictions):\n",
+ " pred[\"probability\"] = prob\n",
+ "\n",
+ " # best prediction์ ์ ํํฉ๋๋ค.\n",
+ " if not version_2_with_negative:\n",
+ " all_predictions[example[\"id\"]] = predictions[0][\"text\"]\n",
+ " else:\n",
+ " # else case : ๋จผ์ ๋น์ด ์์ง ์์ ์ต์์ ์์ธก์ ์ฐพ์์ผ ํฉ๋๋ค\n",
+ " i = 0\n",
+ " while predictions[i][\"text\"] == \"\":\n",
+ " i += 1\n",
+ " best_non_null_pred = predictions[i]\n",
+ "\n",
+ " # threshold๋ฅผ ์ฌ์ฉํด์ null prediction์ ๋น๊ตํฉ๋๋ค.\n",
+ " score_diff = (\n",
+ " null_score\n",
+ " - best_non_null_pred[\"start_logit\"]\n",
+ " - best_non_null_pred[\"end_logit\"]\n",
+ " )\n",
+ " scores_diff_json[example[\"id\"]] = float(score_diff) # JSON-serializable ๊ฐ๋ฅ\n",
+ " if score_diff > null_score_diff_threshold:\n",
+ " all_predictions[example[\"id\"]] = \"\"\n",
+ " else:\n",
+ " all_predictions[example[\"id\"]] = best_non_null_pred[\"text\"]\n",
+ "\n",
+ " # np.float๋ฅผ ๋ค์ float๋ก casting -> `predictions`์ JSON-serializable ๊ฐ๋ฅ\n",
+ " all_nbest_json[example[\"id\"]] = [\n",
+ " {\n",
+ " k: (\n",
+ " float(v)\n",
+ " if isinstance(v, (np.float16, np.float32, np.float64))\n",
+ " else v\n",
+ " )\n",
+ " for k, v in pred.items()\n",
+ " }\n",
+ " for pred in predictions\n",
+ " ]\n",
+ "\n",
+ " # output_dir์ด ์์ผ๋ฉด ๋ชจ๋ dicts๋ฅผ ์ ์ฅํฉ๋๋ค.\n",
+ " if output_dir is not None:\n",
+ " assert os.path.isdir(output_dir), f\"{output_dir} is not a directory.\"\n",
+ "\n",
+ " prediction_file = os.path.join(\n",
+ " output_dir,\n",
+ " \"predictions.json\" if prefix is None else f\"predictions_{prefix}.json\",\n",
+ " )\n",
+ " nbest_file = os.path.join(\n",
+ " output_dir,\n",
+ " \"nbest_predictions.json\"\n",
+ " if prefix is None\n",
+ " else f\"nbest_predictions_{prefix}.json\",\n",
+ " )\n",
+ " if version_2_with_negative:\n",
+ " null_odds_file = os.path.join(\n",
+ " output_dir,\n",
+ " \"null_odds.json\" if prefix is None else f\"null_odds_{prefix}.json\",\n",
+ " )\n",
+ "\n",
+ " logger.info(f\"Saving predictions to {prediction_file}.\")\n",
+ " with open(prediction_file, \"w\", encoding=\"utf-8\") as writer:\n",
+ " writer.write(\n",
+ " json.dumps(all_predictions, indent=4, ensure_ascii=False) + \"\\n\"\n",
+ " )\n",
+ " logger.info(f\"Saving nbest_preds to {nbest_file}.\")\n",
+ " with open(nbest_file, \"w\", encoding=\"utf-8\") as writer:\n",
+ " writer.write(\n",
+ " json.dumps(all_nbest_json, indent=4, ensure_ascii=False) + \"\\n\"\n",
+ " )\n",
+ " if version_2_with_negative:\n",
+ " logger.info(f\"Saving null_odds to {null_odds_file}.\")\n",
+ " with open(null_odds_file, \"w\", encoding=\"utf-8\") as writer:\n",
+ " writer.write(\n",
+ " json.dumps(scores_diff_json, indent=4, ensure_ascii=False) + \"\\n\"\n",
+ " )\n",
+ "\n",
+ " return all_predictions\n",
+ "\n",
+ "\n",
+ "def check_no_error(\n",
+ " data_args: DataTrainingArguments,\n",
+ " training_args: TrainingArguments,\n",
+ " datasets: DatasetDict,\n",
+ " tokenizer,\n",
+ ") -> Tuple[Any, int]:\n",
+ "\n",
+ " # last checkpoint ์ฐพ๊ธฐ.\n",
+ " last_checkpoint = None\n",
+ " if (\n",
+ " os.path.isdir(training_args.output_dir)\n",
+ " and training_args.do_train\n",
+ " and not training_args.overwrite_output_dir\n",
+ " ):\n",
+ " last_checkpoint = get_last_checkpoint(training_args.output_dir)\n",
+ " if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:\n",
+ " raise ValueError(\n",
+ " f\"Output directory ({training_args.output_dir}) already exists and is not empty. \"\n",
+ " \"Use --overwrite_output_dir to overcome.\"\n",
+ " )\n",
+ " elif last_checkpoint is not None:\n",
+ " logger.info(\n",
+ " f\"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change \"\n",
+ " \"the `--output_dir` or add `--overwrite_output_dir` to train from scratch.\"\n",
+ " )\n",
+ "\n",
+ " # Tokenizer check: ํด๋น script๋ Fast tokenizer๋ฅผ ํ์๋กํฉ๋๋ค.\n",
+ " if not isinstance(tokenizer, PreTrainedTokenizerFast):\n",
+ " raise ValueError(\n",
+ " \"This example script only works for models that have a fast tokenizer. Checkout the big table of models \"\n",
+ " \"at https://huggingface.co/transformers/index.html#bigtable to find the model types that meet this \"\n",
+ " \"requirement\"\n",
+ " )\n",
+ "\n",
+ " if data_args.max_seq_length > tokenizer.model_max_length:\n",
+ " logger.warn(\n",
+ " f\"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the\"\n",
+ " f\"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}.\"\n",
+ " )\n",
+ " max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)\n",
+ "\n",
+ " if \"validation\" not in datasets:\n",
+ " raise ValueError(\"--do_eval requires a validation dataset\")\n",
+ " return last_checkpoint, max_seq_length\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### ํ์ฒ๋ฆฌ ํจ์ ์ ์"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 96,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ๋ชจ๋ธ์ด ์ดํดํ๋ ํํ์์ ์ฌ๋์ด ์ดํดํ๋ ํํ๋ก ๋ต๋ณ ๋งค์นญ\n",
+ "def post_processing_function(examples, features, predictions):\n",
+ " # Post-processing: we match the start logits and end logits to answers in the original context.\n",
+ " predictions = postprocess_qa_predictions(\n",
+ " examples=examples,\n",
+ " features=features,\n",
+ " predictions=predictions,\n",
+ " version_2_with_negative=False,\n",
+ " n_best_size=n_best_size,\n",
+ " max_answer_length=max_answer_length,\n",
+ " null_score_diff_threshold=0.0,\n",
+ " output_dir=training_args.output_dir,\n",
+ " is_world_process_zero=trainer.is_world_process_zero(),\n",
+ " )\n",
+ "\n",
+ " # Format the result to the format the metric expects.\n",
+ " formatted_predictions = [{\"id\": k, \"prediction_text\": v} for k, v in predictions.items()]\n",
+ " references = [{\"id\": ex[\"id\"], \"answers\": ex[\"answers\"]} for ex in datasets[\"validation\"]]\n",
+ " return EvalPrediction(predictions=formatted_predictions, label_ids=references)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 97,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def compute_metrics(p: EvalPrediction):\n",
+ " return metric.compute(predictions=p.predictions, references=p.label_ids)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Train 2์ฐจ"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 98,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "Tracking run with wandb version 0.18.3"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Run data is saved locally in /data/ephemeral/home/level2-mrc-nlp-15/src/wandb/run-20241015_184122-0y11f1su"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Syncing run electric-serenity-822 to Weights & Biases (docs)
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " View project at https://wandb.ai/nlp15/odqa"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " View run at https://wandb.ai/nlp15/odqa/runs/0y11f1su"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 98,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import wandb\n",
+ "\n",
+ "wandb.login()\n",
+ "wandb.init(project='odqa', # ์คํ๊ธฐ๋ก์ ๊ด๋ฆฌํ ํ๋ก์ ํธ ์ด๋ฆ\n",
+ " entity='nlp15', # ์ฌ์ฉ์๋ช
๋๋ ํ ์ด๋ฆ \n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 99,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 99,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "\n",
+ "#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "torch.cuda.is_available()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 100,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "PyTorch: setting up devices\n",
+ "The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).\n"
+ ]
+ }
+ ],
+ "source": [
+ "training_args = TrainingArguments(\n",
+ " output_dir=\"outputs\",\n",
+ " do_train=True,\n",
+ " do_eval=True,\n",
+ " per_device_train_batch_size=batch_size,\n",
+ " per_device_eval_batch_size=batch_size,\n",
+ " num_train_epochs=num_train_epochs,\n",
+ " save_total_limit=1, # ์ ์ฅํ ์ฒดํฌํฌ์ธํธ์ ์ต๋ ์\n",
+ " #evaluation_strategy=\"steps\",\n",
+ " #eval_steps=500, # ๋ช ์คํ
๋ง๋ค ํ๊ฐํ ์ง ์ค์ \n",
+ " #logging_steps=500, # ๋ช ์คํ
๋ง๋ค ๋ก๊น
ํ ์ง ์ค์ ,\n",
+ " #load_best_model_at_end=True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 101,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer = QuestionAnsweringTrainer(\n",
+ " model=model,\n",
+ " args=training_args,\n",
+ " train_dataset=train_dataset,\n",
+ " eval_dataset=eval_dataset,\n",
+ " eval_examples=val_dataset,\n",
+ " tokenizer=tokenizer,\n",
+ " data_collator=default_data_collator, # ๋ณดํต default\n",
+ " post_process_function=post_processing_function, # function์ input์ผ๋ก ๋ฐ์!\n",
+ " compute_metrics=compute_metrics\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 102,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/conda/lib/python3.10/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
+ " warnings.warn(\n",
+ "***** Running training *****\n",
+ " Num examples = 5769\n",
+ " Num Epochs = 4\n",
+ " Instantaneous batch size per device = 16\n",
+ " Total train batch size (w. parallel, distributed & accumulation) = 16\n",
+ " Gradient Accumulation steps = 1\n",
+ " Total optimization steps = 1444\n",
+ " Number of trainable parameters = 335608834\n",
+ "Automatic Weights & Biases logging enabled, to disable set os.environ[\"WANDB_DISABLED\"] = \"true\"\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [1444/1444 35:59, Epoch 4/4]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | Step | \n",
+ " Training Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 500 | \n",
+ " 0.790100 | \n",
+ "
\n",
+ " \n",
+ " | 1000 | \n",
+ " 0.258400 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Saving model checkpoint to outputs/checkpoint-500\n",
+ "Configuration saved in outputs/checkpoint-500/config.json\n",
+ "Model weights saved in outputs/checkpoint-500/pytorch_model.bin\n",
+ "tokenizer config file saved in outputs/checkpoint-500/tokenizer_config.json\n",
+ "Special tokens file saved in outputs/checkpoint-500/special_tokens_map.json\n",
+ "Deleting older checkpoint [outputs/checkpoint-3500] due to args.save_total_limit\n",
+ "Saving model checkpoint to outputs/checkpoint-1000\n",
+ "Configuration saved in outputs/checkpoint-1000/config.json\n",
+ "Model weights saved in outputs/checkpoint-1000/pytorch_model.bin\n",
+ "tokenizer config file saved in outputs/checkpoint-1000/tokenizer_config.json\n",
+ "Special tokens file saved in outputs/checkpoint-1000/special_tokens_map.json\n",
+ "Deleting older checkpoint [outputs/checkpoint-500] due to args.save_total_limit\n",
+ "\n",
+ "\n",
+ "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
+ "\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "train_result = trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 103,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "Run history:
| train/epoch | โโ
โ |
| train/global_step | โโ
โ |
| train/learning_rate | โโ |
| train/loss | โโ |
| train/total_flos | โ |
| train/train_loss | โ |
| train/train_runtime | โ |
| train/train_samples_per_second | โ |
| train/train_steps_per_second | โ |
Run summary:
| train/epoch | 4 |
| train/global_step | 1444 |
| train/learning_rate | 2e-05 |
| train/loss | 0.2584 |
| train/total_flos | 2.143084255034573e+16 |
| train/train_loss | 0.38528 |
| train/train_runtime | 2161.033 |
| train/train_samples_per_second | 10.678 |
| train/train_steps_per_second | 0.668 |
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " View run electric-serenity-822 at: https://wandb.ai/nlp15/odqa/runs/0y11f1su
View project at: https://wandb.ai/nlp15/odqa
Synced 5 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Find logs at: ./wandb/run-20241015_184122-0y11f1su/logs"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "wandb.finish()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 104,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "RobertaForQuestionAnswering(\n",
+ " (roberta): RobertaModel(\n",
+ " (embeddings): RobertaEmbeddings(\n",
+ " (word_embeddings): Embedding(32000, 1024, padding_idx=1)\n",
+ " (position_embeddings): Embedding(514, 1024, padding_idx=1)\n",
+ " (token_type_embeddings): Embedding(1, 1024)\n",
+ " (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (encoder): RobertaEncoder(\n",
+ " (layer): ModuleList(\n",
+ " (0-23): 24 x RobertaLayer(\n",
+ " (attention): RobertaAttention(\n",
+ " (self): RobertaSelfAttention(\n",
+ " (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): RobertaSelfOutput(\n",
+ " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): RobertaIntermediate(\n",
+ " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
+ " (intermediate_act_fn): GELUActivation()\n",
+ " )\n",
+ " (output): RobertaOutput(\n",
+ " (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
+ " (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (qa_outputs): Linear(in_features=1024, out_features=2, bias=True)\n",
+ ")"
+ ]
+ },
+ "execution_count": 104,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## ํ๊น
ํ์ด์ค์ ๋ชจ๋ธ ์
๋ก๋"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 105,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+ "To disable this warning, you can either:\n",
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
+ "/bin/bash: sudo: command not found\n"
+ ]
+ }
+ ],
+ "source": [
+ "!sudo apt-get install git-lfs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 106,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Configuration saved in /tmp/tmpvxgja89o/config.json\n",
+ "Model weights saved in /tmp/tmpvxgja89o/pytorch_model.bin\n",
+ "Uploading the following files to ssunbear/klue_roberta_large_finetuned_korquad_v2: pytorch_model.bin,config.json\n",
+ "pytorch_model.bin: 100%|โโโโโโโโโโ| 1.34G/1.34G [00:56<00:00, 23.7MB/s] \n",
+ "tokenizer config file saved in /tmp/tmp1q6nsung/tokenizer_config.json\n",
+ "Special tokens file saved in /tmp/tmp1q6nsung/special_tokens_map.json\n",
+ "Uploading the following files to ssunbear/klue_roberta_large_finetuned_korquad_v2: tokenizer_config.json,vocab.txt,tokenizer.json,special_tokens_map.json\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "CommitInfo(commit_url='https://huggingface.co/ssunbear/klue_roberta_large_finetuned_korquad_v2/commit/32b49af3b113769f65b0ab2392cfb81d4c159962', commit_message='Upload tokenizer', commit_description='', oid='32b49af3b113769f65b0ab2392cfb81d4c159962', pr_url=None, repo_url=RepoUrl('https://huggingface.co/ssunbear/klue_roberta_large_finetuned_korquad_v2', endpoint='https://huggingface.co', repo_type='model', repo_id='ssunbear/klue_roberta_large_finetuned_korquad_v2'), pr_revision=None, pr_num=None)"
+ ]
+ },
+ "execution_count": 106,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from transformers import AutoModel\n",
+ "from transformers import AutoTokenizer\n",
+ "\n",
+ "\n",
+ "# Huggingface Access Token\n",
+ "ACCESS_TOKEN = #ํ ํฐ์์ด๋ ์
๋ ฅํ์๋ฉด ๋ฉ๋๋ค.\n",
+ "\n",
+ "# Upload to Huggingface\n",
+ "model.push_to_hub('klue_roberta_large_finetuned_korquad_v2', use_temp_dir=True, use_auth_token=ACCESS_TOKEN)\n",
+ "tokenizer.push_to_hub('klue_roberta_large_finetuned_korquad_v2', use_temp_dir=True, use_auth_token=ACCESS_TOKEN)\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import logging\n",
+ "import os\n",
+ "import sys\n",
+ "import wandb\n",
+ "\n",
+ "from datasets import DatasetDict\n",
+ "import evaluate\n",
+ "import argparse\n",
+ "from trainer_qa import QuestionAnsweringTrainer\n",
+ "from transformers import (\n",
+ " AutoConfig,\n",
+ " AutoModelForQuestionAnswering,\n",
+ " AutoTokenizer,\n",
+ " DataCollatorWithPadding,\n",
+ " EvalPrediction,\n",
+ " TrainingArguments,\n",
+ ")\n",
+ "from utils_qa import set_seed, check_no_error, postprocess_qa_predictions\n",
+ "from omegaconf import OmegaConf\n",
+ "from omegaconf import DictConfig\n",
+ "from utils.naming import wandb_naming\n",
+ "from prepare_dataset import prepare_dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## FINISH\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/conda/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n",
+ "/opt/conda/lib/python3.10/site-packages/huggingface_hub/file_download.py:1142: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/huggingface_hub/file_download.py:1142: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
+ " warnings.warn(\n"
+ ]
+ }
+ ],
+ "source": [
+ "from CNN_layer_model import CNN_RobertaForQuestionAnswering\n",
+ "from transformers import (\n",
+ " AutoConfig,\n",
+ " AutoModelForQuestionAnswering,\n",
+ " AutoTokenizer\n",
+ ")\n",
+ "\n",
+ "model_name = \"CurtisJeon/klue-roberta-large-korquad_v1_qa\"\n",
+ "\n",
+ "config = AutoConfig.from_pretrained(\n",
+ " model_name\n",
+ ")\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\n",
+ " model_name,\n",
+ " use_fast=True\n",
+ ")\n",
+ "model = CNN_RobertaForQuestionAnswering.from_pretrained(\n",
+ " model_name,\n",
+ " config=config)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import (\n",
+ " AutoConfig,\n",
+ " AutoTokenizer\n",
+ ")\n",
+ "from custom_model_copy import CNN_RobertaForQuestionAnswering\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/conda/lib/python3.10/site-packages/huggingface_hub/file_download.py:1142: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
+ " warnings.warn(\n",
+ "Some weights of CNN_RobertaForQuestionAnswering were not initialized from the model checkpoint at CurtisJeon/klue-roberta-large-korquad_v1_qa and are newly initialized: ['cnn_block2.layer_norm.bias', 'cnn_block4.conv2.bias', 'cnn_block5.conv2.weight', 'cnn_block2.conv1.weight', 'cnn_block2.conv1.bias', 'cnn_block4.conv2.weight', 'cnn_block4.layer_norm.bias', 'cnn_block1.layer_norm.bias', 'cnn_block2.layer_norm.weight', 'cnn_block4.layer_norm.weight', 'cnn_block4.conv1.weight', 'cnn_block5.layer_norm.bias', 'cnn_block3.layer_norm.bias', 'cnn_block4.conv1.bias', 'cnn_block1.conv2.weight', 'cnn_block1.layer_norm.weight', 'cnn_block5.layer_norm.weight', 'cnn_block3.conv2.weight', 'cnn_block3.conv2.bias', 'cnn_block5.conv1.bias', 'cnn_block3.conv1.weight', 'cnn_block2.conv2.bias', 'cnn_block3.layer_norm.weight', 'cnn_block1.conv1.weight', 'cnn_block1.conv1.bias', 'cnn_block2.conv2.weight', 'cnn_block5.conv1.weight', 'cnn_block3.conv1.bias', 'cnn_block1.conv2.bias', 'cnn_block5.conv2.bias']\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+ ]
+ }
+ ],
+ "source": [
+ "model_name = \"CurtisJeon/klue-roberta-large-korquad_v1_qa\"\n",
+ "\n",
+ "config = AutoConfig.from_pretrained(\n",
+ " model_name\n",
+ ")\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\n",
+ " model_name,\n",
+ " use_fast=True\n",
+ ")\n",
+ "model = CNN_RobertaForQuestionAnswering.from_pretrained(\n",
+ " model_name)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "CNN_RobertaForQuestionAnswering(\n",
+ " (roberta): RobertaModel(\n",
+ " (embeddings): RobertaEmbeddings(\n",
+ " (word_embeddings): Embedding(32000, 1024, padding_idx=1)\n",
+ " (position_embeddings): Embedding(514, 1024, padding_idx=1)\n",
+ " (token_type_embeddings): Embedding(1, 1024)\n",
+ " (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (encoder): RobertaEncoder(\n",
+ " (layer): ModuleList(\n",
+ " (0-23): 24 x RobertaLayer(\n",
+ " (attention): RobertaAttention(\n",
+ " (self): RobertaSelfAttention(\n",
+ " (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): RobertaSelfOutput(\n",
+ " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): RobertaIntermediate(\n",
+ " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
+ " (intermediate_act_fn): GELUActivation()\n",
+ " )\n",
+ " (output): RobertaOutput(\n",
+ " (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
+ " (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (cnn_block1): CNN_block(\n",
+ " (conv1): Conv1d(514, 1028, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+ " (conv2): Conv1d(1028, 514, kernel_size=(1,), stride=(1,))\n",
+ " (relu): ReLU()\n",
+ " (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ " (cnn_block2): CNN_block(\n",
+ " (conv1): Conv1d(514, 1028, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+ " (conv2): Conv1d(1028, 514, kernel_size=(1,), stride=(1,))\n",
+ " (relu): ReLU()\n",
+ " (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ " (cnn_block3): CNN_block(\n",
+ " (conv1): Conv1d(514, 1028, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+ " (conv2): Conv1d(1028, 514, kernel_size=(1,), stride=(1,))\n",
+ " (relu): ReLU()\n",
+ " (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ " (cnn_block4): CNN_block(\n",
+ " (conv1): Conv1d(514, 1028, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+ " (conv2): Conv1d(1028, 514, kernel_size=(1,), stride=(1,))\n",
+ " (relu): ReLU()\n",
+ " (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ " (cnn_block5): CNN_block(\n",
+ " (conv1): Conv1d(514, 1028, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+ " (conv2): Conv1d(1028, 514, kernel_size=(1,), stride=(1,))\n",
+ " (relu): ReLU()\n",
+ " (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ " (qa_outputs): Linear(in_features=1024, out_features=2, bias=True)\n",
+ ")"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "base",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/src/main.py b/src/main.py
new file mode 100644
index 0000000..a19f56f
--- /dev/null
+++ b/src/main.py
@@ -0,0 +1,434 @@
+import datetime
+import logging
+import os
+import sys
+import numpy as np
+from typing import Callable, List, NoReturn, Tuple
+
+from arguments import DataTrainingArguments, ModelArguments
+from datasets import (
+ Dataset,
+ DatasetDict,
+ Features,
+ Value,
+ Sequence,
+ load_from_disk,
+ load_metric
+)
+from qa_trainer import QATrainer
+from retrieval_BM25 import BM25SparseRetrieval
+from retrieval_hybridsearch import HybridSearch
+from retrieval_Dense import DenseRetrieval
+from retrieval_2s_rerank import TwoStageReranker
+
+from transformers import (
+ AutoConfig,
+    # AutoModelForQuestionAnswering,
+ AutoTokenizer,
+ DataCollatorWithPadding,
+ EvalPrediction,
+ HfArgumentParser,
+ TrainingArguments,
+)
+from utils import set_seed, check_no_error, postprocess_qa_predictions
+import wandb
+from CNN_layer_model import CNN_RobertaForQuestionAnswering
+
+logger = logging.getLogger(__name__)
+
+
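+# Pipeline overview: parse HF arguments, start a W&B run, load the dataset from
+# disk, build the tokenizer/config and the CNN-augmented QA model, optionally run
+# passage retrieval for prediction, then hand everything off to run_mrc().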
+def main():
+ parser = HfArgumentParser(
+ (ModelArguments, DataTrainingArguments, TrainingArguments)
+ )
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ training_args.save_steps = 0
+ training_args.logging_steps = 10
+
+ project_prefix = "[train]" if training_args.do_train else "[eval]" if training_args.do_eval else "[pred]"
+ wandb.init(
+ project="odqa",
+ entity="nlp15",
+ name=f"{project_prefix} {model_args.model_name_or_path.split('/')[0]}_{(datetime.datetime.now() + datetime.timedelta(hours=9)).strftime('%Y%m%d_%H%M%S')}",
+ save_code=True,
+ )
+
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ handlers=[logging.StreamHandler(sys.stdout)],
+ )
+
+ logging.info(f"model is from {model_args.model_name_or_path}")
+ logging.info(f"data is from {data_args.dataset_name}")
+
+ logger.info("Training/evaluation parameters %s", training_args)
+
+ set_seed(training_args.seed)
+ print(">>> seed:", training_args.seed)
+
+
+ datasets = load_from_disk(data_args.dataset_name)
+ print(datasets)
+
+ config = AutoConfig.from_pretrained(
+ model_args.config_name
+ if model_args.config_name is not None
+ else model_args.model_name_or_path,
+ )
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.tokenizer_name
+ if model_args.tokenizer_name is not None
+ else model_args.model_name_or_path,
+ use_fast=True,
+ )
+    # AutoModelForQuestionAnswering -> CNN_RobertaForQuestionAnswering: adds CNN layers on top of RoBERTa to improve performance
+ model = CNN_RobertaForQuestionAnswering.from_pretrained(
+ model_args.model_name_or_path,
+ from_tf=bool(".ckpt" in model_args.model_name_or_path),
+ config=config,
+ )
+
+ if training_args.do_predict and data_args.eval_retrieval:
+ datasets = run_sparse_retrieval(
+ tokenizer.tokenize, datasets, training_args, data_args,
+ )
+
+ run_mrc(data_args, training_args, model_args, datasets, tokenizer, model)
+
+
+def prepare_train_features(examples, tokenizer, question_column_name, pad_on_right, context_column_name, max_seq_length, data_args, answer_column_name):
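+    """Tokenize question/context pairs with a sliding window and label answer spans.
+
+    Long contexts are split into overlapping features (stride = data_args.doc_stride);
+    overflow_to_sample_mapping links each feature back to its source example, and
+    features whose window does not contain the answer are labeled with the CLS index.
+    """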
+ tokenized_examples = tokenizer(
+ examples[question_column_name if pad_on_right else context_column_name],
+ examples[context_column_name if pad_on_right else question_column_name],
+ truncation="only_second" if pad_on_right else "only_first",
+ max_length=max_seq_length,
+ stride=data_args.doc_stride,
+ return_overflowing_tokens=True,
+ return_offsets_mapping=True,
+ return_token_type_ids=False, # True if bert, False if roberta
+ padding="max_length" if data_args.pad_to_max_length else False,
+ )
+
+ sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
+ offset_mapping = tokenized_examples.pop("offset_mapping")
+
+ tokenized_examples["start_positions"] = []
+ tokenized_examples["end_positions"] = []
+
+ for i, offsets in enumerate(offset_mapping):
+ input_ids = tokenized_examples["input_ids"][i]
+ cls_index = input_ids.index(tokenizer.cls_token_id)
+
+ sequence_ids = tokenized_examples.sequence_ids(i)
+
+ sample_index = sample_mapping[i]
+ answers = examples[answer_column_name][sample_index]
+
+ if len(answers["answer_start"]) == 0:
+ tokenized_examples["start_positions"].append(cls_index)
+ tokenized_examples["end_positions"].append(cls_index)
+ else:
+ start_char = answers["answer_start"][0]
+ end_char = start_char + len(answers["text"][0])
+
+ token_start_index = 0
+ while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
+ token_start_index += 1
+
+ token_end_index = len(input_ids) - 1
+ while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
+ token_end_index -= 1
+
+ if not (
+ offsets[token_start_index][0] <= start_char
+ and offsets[token_end_index][1] >= end_char
+ ):
+ tokenized_examples["start_positions"].append(cls_index)
+ tokenized_examples["end_positions"].append(cls_index)
+ else:
+ while (
+ token_start_index < len(offsets)
+ and offsets[token_start_index][0] <= start_char
+ ):
+ token_start_index += 1
+ tokenized_examples["start_positions"].append(token_start_index - 1)
+ while offsets[token_end_index][1] >= end_char:
+ token_end_index -= 1
+ tokenized_examples["end_positions"].append(token_end_index + 1)
+
+ return tokenized_examples
+
+
+def run_sparse_retrieval(
+ tokenize_fn: Callable[[str], List[str]],
+ datasets: DatasetDict,
+ training_args: TrainingArguments,
+ data_args: DataTrainingArguments,
+ data_path: str = "../data",
+ context_path: str = "wikipedia_documents.json",
+) -> DatasetDict:
+
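+    # First-stage retrieval for inference. HybridSearch mixes BM25 sparse scores with
+    # dense embedding similarity through the `alpha` weight (the exact mixing convention
+    # lives in retrieval_hybridsearch); the alternatives below are kept as comments.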
+ retriever = HybridSearch(
+ tokenize_fn=tokenize_fn,
+        # args=data_args,  # pass the data arguments through
+ data_path=data_path,
+ context_path=context_path
+ )
+ retriever.get_sparse_embedding()
+ retriever.get_dense_embedding()
+
+ # retriever = BM25SparseRetrieval(
+ # tokenize_fn=tokenize_fn,
+ # data_path=data_path,
+ # context_path=context_path
+ # )
+ # retriever.get_sparse_embedding()
+
+ # retriever = TwoStageReranker(
+ # tokenize_fn=tokenize_fn,
+    #     args=data_args,  # pass the data arguments through
+ # data_path=data_path,
+ # context_path=context_path
+ # )
+
+ # retriever = DenseRetrieval(
+ # data_path=data_path,
+ # context_path=context_path
+ # )
+ # retriever.get_dense_embedding()
+
+ # if data_args.use_faiss:
+ # retriever.build_faiss(num_clusters=data_args.num_clusters)
+ # df = retriever.retrieve_faiss(
+ # datasets["validation"], topk=data_args.top_k_retrieval
+ # )
+ # else:
+
+ # df = retriever.retrieve(datasets["validation"], topk=data_args.top_k_retrieval)
+ # df = retriever.retrieve(datasets["validation"], topk=data_args.top_k_retrieval, alpha=data_args.alpha_retrieval)
+ df = retriever.retrieve(datasets["validation"], topk=10, alpha=0.7)
+
+
+ if training_args.do_predict:
+ f = Features(
+ {
+ "context": Value(dtype="string", id=None),
+ "id": Value(dtype="string", id=None),
+ "question": Value(dtype="string", id=None),
+ }
+ )
+
+ elif training_args.do_eval:
+ f = Features(
+ {
+ "answers": Sequence(
+ feature={
+ "text": Value(dtype="string", id=None),
+ "answer_start": Value(dtype="int32", id=None),
+ },
+ length=-1,
+ id=None,
+ ),
+ "context": Value(dtype="string", id=None),
+ "id": Value(dtype="string", id=None),
+ "question": Value(dtype="string", id=None),
+ }
+ )
+ datasets = DatasetDict({"validation": Dataset.from_pandas(df, features=f)})
+ return datasets
+
+
+def prepare_validation_features(examples, tokenizer, question_column_name, pad_on_right, context_column_name, max_seq_length, data_args, answer_column_name):
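+    """Tokenize validation examples with the same sliding window as training.
+
+    Instead of span labels, each feature keeps its source example_id and an
+    offset_mapping restricted to context tokens, so that span predictions can be
+    mapped back to answer text during post-processing.
+    """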
+ tokenized_examples = tokenizer(
+ examples[question_column_name if pad_on_right else context_column_name],
+ examples[context_column_name if pad_on_right else question_column_name],
+ truncation="only_second" if pad_on_right else "only_first",
+ max_length=max_seq_length,
+ stride=data_args.doc_stride,
+ return_overflowing_tokens=True,
+ return_offsets_mapping=True,
+ return_token_type_ids=False, # True if bert, False if roberta
+ padding="max_length" if data_args.pad_to_max_length else False,
+ )
+
+ sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
+
+ tokenized_examples["example_id"] = []
+
+ for i in range(len(tokenized_examples["input_ids"])):
+ sequence_ids = tokenized_examples.sequence_ids(i)
+ context_index = 1 if pad_on_right else 0
+
+ sample_index = sample_mapping[i]
+ tokenized_examples["example_id"].append(examples["id"][sample_index])
+
+ tokenized_examples["offset_mapping"][i] = [
+ (o if sequence_ids[k] == context_index else None)
+ for k, o in enumerate(tokenized_examples["offset_mapping"][i])
+ ]
+ return tokenized_examples
+
+
+def run_mrc(
+ data_args: DataTrainingArguments,
+ training_args: TrainingArguments,
+ model_args: ModelArguments,
+ datasets: DatasetDict,
+ tokenizer,
+ model,
+) -> NoReturn:
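+    # End-to-end MRC: tokenize the requested splits, wire up the QA trainer with
+    # SQuAD metrics and span post-processing, then train / evaluate / predict.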
+ if training_args.do_train:
+ column_names = datasets["train"].column_names
+ else:
+ column_names = datasets["validation"].column_names
+
+ question_column_name = "question" if "question" in column_names else column_names[0]
+ context_column_name = "context" if "context" in column_names else column_names[1]
+ answer_column_name = "answers" if "answers" in column_names else column_names[2]
+
+ pad_on_right = tokenizer.padding_side == "right"
+
+ last_checkpoint, max_seq_length = check_no_error(
+ data_args, training_args, datasets, tokenizer
+ )
+
+ if training_args.do_train:
+ if "train" not in datasets:
+ raise ValueError("--do_train requires a train dataset")
+ train_dataset = datasets["train"]
+
+ train_dataset = train_dataset.map(
+ prepare_train_features,
+ fn_kwargs={
+ 'tokenizer': tokenizer,
+ 'pad_on_right': pad_on_right,
+ 'max_seq_length': max_seq_length,
+ 'data_args': data_args,
+ 'question_column_name': question_column_name,
+ 'context_column_name': context_column_name,
+ 'answer_column_name': answer_column_name
+ },
+ batched=True,
+ num_proc=data_args.preprocessing_num_workers,
+ remove_columns=column_names,
+ load_from_cache_file=not data_args.overwrite_cache,
+ )
+
+ if training_args.do_eval or training_args.do_predict:
+ eval_dataset = datasets["validation"]
+
+ eval_dataset = eval_dataset.map(
+ prepare_validation_features,
+ fn_kwargs={
+ 'tokenizer': tokenizer,
+ 'pad_on_right': pad_on_right,
+ 'max_seq_length': max_seq_length,
+ 'data_args': data_args,
+ 'question_column_name': question_column_name,
+ 'context_column_name': context_column_name,
+ 'answer_column_name': answer_column_name
+ },
+ batched=True,
+ num_proc=data_args.preprocessing_num_workers,
+ remove_columns=column_names,
+ load_from_cache_file=not data_args.overwrite_cache,
+ )
+
+ data_collator = DataCollatorWithPadding(
+ tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None
+ )
+
+ metric = load_metric("squad")
+
+ def post_processing_function(
+ examples,
+ features,
+ predictions: Tuple[np.ndarray, np.ndarray],
+ training_args: TrainingArguments
+ ) -> EvalPrediction:
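+        # Turn raw start/end logits into answer strings, then wrap them in the
+        # format the SQuAD metric expects (or return bare predictions for inference).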
+ predictions = postprocess_qa_predictions(
+ examples=examples,
+ features=features,
+ predictions=predictions,
+ max_answer_length=data_args.max_answer_length,
+ output_dir=training_args.output_dir,
+ )
+ formatted_predictions = [
+ {"id": k, "prediction_text": v} for k, v in predictions.items()
+ ]
+
+ if training_args.do_predict:
+ return formatted_predictions
+
+ elif training_args.do_eval:
+ references = [
+ {"id": ex["id"], "answers": ex[answer_column_name]}
+ for ex in datasets["validation"]
+ ]
+ return EvalPrediction(
+ predictions=formatted_predictions, label_ids=references
+ )
+
+ trainer = QATrainer(
+ model=model,
+ args=training_args,
+ train_dataset=train_dataset if training_args.do_train else None,
+ eval_dataset=eval_dataset if training_args.do_eval else None,
+ eval_examples=datasets["validation"] if training_args.do_eval else None,
+ tokenizer=tokenizer,
+ data_collator=data_collator,
+ post_process_function=post_processing_function,
+ compute_metrics=lambda x: metric.compute(predictions=x.predictions, references=x.label_ids),
+ )
+
+ if training_args.do_train:
+ if last_checkpoint is not None:
+ checkpoint = last_checkpoint
+ elif os.path.isdir(model_args.model_name_or_path):
+ checkpoint = model_args.model_name_or_path
+ else:
+ checkpoint = None
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
+ trainer.save_model()
+
+ metrics = train_result.metrics
+ metrics["train_samples"] = len(train_dataset)
+
+ trainer.log_metrics("train", metrics)
+ trainer.save_metrics("train", metrics)
+ trainer.save_state()
+
+ output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
+
+ with open(output_train_file, "w") as writer:
+ logger.info("***** Train results *****")
+ for key, value in sorted(train_result.metrics.items()):
+ logger.info(f" {key} = {value}")
+ writer.write(f"{key} = {value}\n")
+
+ trainer.state.save_to_json(
+ os.path.join(training_args.output_dir, "trainer_state.json")
+ )
+
+ if training_args.do_eval:
+ logger.info("*** Evaluate ***")
+ metrics = trainer.evaluate()
+
+ metrics["eval_samples"] = len(eval_dataset)
+
+ trainer.log_metrics("eval", metrics)
+ trainer.save_metrics("eval", metrics)
+
+ if training_args.do_predict:
+ predictions = trainer.predict(
+ test_dataset=eval_dataset, test_examples=datasets["validation"]
+ )
+ logger.info("No metric can be presented because there is no correct answer given. Job done!")
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/src/optimize_retriever.py b/src/optimize_retriever.py
new file mode 100644
index 0000000..9889a84
--- /dev/null
+++ b/src/optimize_retriever.py
@@ -0,0 +1,124 @@
+from rank_bm25 import BM25Plus
+import optuna
+from sklearn.metrics import ndcg_score
+import numpy as np
+from datasets import load_from_disk
+from transformers import AutoTokenizer, AutoModel
+from torch.nn.functional import normalize
+import torch
+from tqdm import tqdm
+
+import logging
+import datetime
+import wandb
+
+from retrieval_hybridsearch import HybridSearch  # import from the module where the HybridSearch class is defined
+from retrieval_2s_rerank import TwoStageReranker
+
+# Logging and experiment-tracking setup
+logger = logging.getLogger(__name__)
+wandb.init(project="odqa",
+ name="run_" + (datetime.datetime.now() + datetime.timedelta(hours=9)).strftime("%Y%m%d_%H%M%S"),
+ entity="nlp15"
+ )
+
+datasets = load_from_disk("../data/train_dataset")
+
+documents = datasets['train']['context']
+queries = datasets['train']['question']
+
+tokenizer = AutoTokenizer.from_pretrained("HANTAEK/klue-roberta-large-korquad-v1-qa-finetuned")
+dense_tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-large-instruct')
+dense_embeder = AutoModel.from_pretrained(
+ 'intfloat/multilingual-e5-large-instruct'
+ )
+
+dense_embeds = []
+batch_size = 64
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+dense_embeder.to(device)
+
+def mean_pooling(model_output, attention_mask):
+ token_embeddings = model_output[0] #First element of model_output contains all token embeddings
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
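+# Encode every passage in batches: mean-pool the token embeddings under the attention
+# mask, then L2-normalize so a plain dot product between vectors is cosine similarity.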
+for i in tqdm(range(0, len(documents), batch_size), desc="Encoding passages"):
+ batch_contexts = documents[i:i+batch_size]
+ encoded_input = dense_tokenizer(
+ batch_contexts, padding=True, truncation=True, return_tensors='pt'
+ ).to(device)
+ with torch.no_grad():
+ model_output = dense_embeder(**encoded_input)
+ sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
+ sentence_embeddings = normalize(sentence_embeddings, p=2, dim=1)
+ dense_embeds.append(sentence_embeddings.cpu())
+ del encoded_input, model_output, sentence_embeddings
+ torch.cuda.empty_cache()
+
+dense_embeds = torch.cat(dense_embeds, dim=0)
+
+retriever = HybridSearch(
+ tokenize_fn=tokenizer.tokenize,
+ data_path="../data",
+ context_path="wikipedia_documents.json"
+)
+# retriever = TwoStageReranker(
+# tokenize_fn=tokenizer.tokenize,
+# data_path="../data",
+# context_path="wikipedia_documents.json"
+# )
+
+retriever.get_dense_embedding()
+retriever.get_sparse_embedding()
+
+
+retriever.dense_embeds = dense_embeds
+# retriever.dense_embeder.dense_embeds = dense_embeds
+
+def objective(trial):
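+    """Optuna objective: tune the hybrid mixing weight and the BM25+ parameters.
+
+    Each trial rebuilds the BM25+ index with the sampled (k1, b, delta), retrieves
+    the top-20 passages per training question with mixing weight alpha, and scores
+    nDCG under the convention that question i's gold passage is document i.
+    """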
+ alpha = trial.suggest_float("alpha", 0.4, 1.0)
+ k1 = trial.suggest_float("k1", 0.5, 2.0)
+ b = trial.suggest_float("b", 0.0, 1.0)
+ delta = trial.suggest_float("delta", 0.0, 1.0)
+
+ retriever.sparse_embeder = BM25Plus([tokenizer.tokenize(doc) for doc in documents], k1=k1, b=b, delta=delta)
+
+ all_scores = []
+ all_doc_indices = []
+
+ for idx, query in enumerate(queries):
+ scores, contexts, doc_indices = retriever.retrieve(query, topk=20, alpha=alpha)
+ all_scores.append(scores)
+ all_doc_indices.append(doc_indices)
+
+ true_relevance_scores = []
+ for idx, doc_indices in enumerate(all_doc_indices):
+ relevance = [1 if doc_idx == idx else 0 for doc_idx in doc_indices]
+ true_relevance_scores.append(relevance)
+
+ all_scores = np.array(all_scores)
+ true_relevance_scores = np.array(true_relevance_scores)
+
+ avg_ndcg = ndcg_score(true_relevance_scores, all_scores)
+
+ return avg_ndcg
+
+class TQDMProgressBar:
+ def __init__(self, n_trials):
+ self.pbar = tqdm(total=n_trials)
+
+ def __call__(self, study, trial):
+ self.pbar.update(1)
+
+n_trials = 30  # total number of Optuna trials
+progress_bar = TQDMProgressBar(n_trials)
+study = optuna.create_study(direction="maximize")
+study.optimize(objective, n_trials=n_trials, callbacks=[progress_bar])
+
+# Print the best hyperparameters found
+print("Best alpha:", study.best_params["alpha"])
+print("Best k1:", study.best_params["k1"])
+print("Best b:", study.best_params["b"])
+print("Best delta:", study.best_params["delta"])
\ No newline at end of file
diff --git a/src/preprocess_answer.ipynb b/src/preprocess_answer.ipynb
new file mode 100644
index 0000000..0251a41
--- /dev/null
+++ b/src/preprocess_answer.ipynb
@@ -0,0 +1,162 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import json\n",
+ "import torch\n",
+ "\n",
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
+ "from datasets import load_from_disk"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "base_dir = \"/\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ๋ฐ์ดํฐ์
๋ก๋\n",
+ "dataset_dict = load_from_disk(os.path.join(base_dir, \"data\", \"test_dataset\"))\n",
+ "\n",
+ "# 'test' ๋ฐ์ดํฐ์
์ ์ ํํ๊ณ Pandas ๋ฐ์ดํฐํ๋ ์์ผ๋ก ๋ณํ\n",
+ "test_dataset = dataset_dict['validation']\n",
+ "df = test_dataset.to_pandas()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_id = \"rtzr/ko-gemma-2-9b-it\"\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
+ "model = AutoModelForCausalLM.from_pretrained(\n",
+ " model_id,\n",
+ " torch_dtype=torch.bfloat16,\n",
+ " device_map=\"auto\",\n",
+ ")\n",
+ "\n",
+ "model.eval()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "instruction = \"\"\"๋ค์์ '์ ์ ๋ต๋ณ'์์ ํ์ํ ๋ต๋ง ๋์ผํ๊ฒ ์ถ์ถํ์ฌ ์ ์ํ์์ค. ์ถ๊ฐ์ ์ธ ์ ๋ณด๋ ์์ ์์ด, '์ ์ ๋ต๋ณ'์ ๋ด์ฉ์ ์ฌ์ฉํ์์ค.\n",
+ "\n",
+ "\n",
+ "๋ค์์ ๋ฐ๋์ ์งํค์์ค: ํ์์ ์กฐ์ฌ, ์ด๋ฏธ๋ฅผ ์ ๊ฑฐํ ์ ์์ผ๋, '์ ์ ๋ต๋ณ' ์ด์ธ์ ๋ด์ฉ์ ์์ฑ, ์ถ๊ฐํ๋ฉด ์ ๋ ์๋จ. ์ค์ ํต์ฌ ๋ด์ฉ์ ๋จ๊ธฐ๊ณ ์ ๊ฑฐํ ์ ์๋ค.\n",
+ "\n",
+ "### ์์ ###\n",
+ "์ง๋ฌธ: ์ 2์ฐจ ์ธ๊ณ ๋์ ์ ๋ช ๋
์ ๋ฐ๋ฐํ์๋๊ฐ? ์ ์ ๋ต๋ณ: 1939๋
9์ 1์ผ์ ๋ฐ๋ฐํ์๋ค.\n",
+ "์์ฑ ๋ต๋ณ: 1939๋
\n",
+ "\n",
+ "### ์ง๋ฌธ ###\n",
+ "{}\n",
+ "\n",
+ "### ์ ์ ๋ต๋ณ ###\n",
+ "{}\"\"\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(\"predictions.json\", \"r\") as f:\n",
+ " predictions = json.load(f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "new_pred = {}\n",
+ "\n",
+ "for i in range(len(df)):\n",
+ " question = df.loc[i, 'question']\n",
+ "\n",
+ " messages = [\n",
+ " {\"role\": \"user\", \"content\": instruction.format(question, predictions[df.loc[i, \"id\"]])},\n",
+ " ]\n",
+ "\n",
+ " input_ids = tokenizer.apply_chat_template(\n",
+ " messages,\n",
+ " add_generation_prompt=True,\n",
+ " return_tensors=\"pt\"\n",
+ " ).to(model.device)\n",
+ "\n",
+ " terminators = [\n",
+ " tokenizer.eos_token_id,\n",
+ " tokenizer.convert_tokens_to_ids(\"\")\n",
+ " ]\n",
+ "\n",
+ " outputs = model.generate(\n",
+ " input_ids,\n",
+ " max_new_tokens=2048,\n",
+ " eos_token_id=terminators,\n",
+ " do_sample=True,\n",
+ " temperature=0.1,\n",
+ " top_p=0.9,\n",
+ " )\n",
+ "\n",
+ " pred = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)\n",
+ "\n",
+ " new_pred[df.loc[i, \"id\"]] = pred.strip()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(os.path.join(base_dir, \"data\", \"predictions_fixed.json\"), \"w\") as f:\n",
+ " json.dump(new_pred, f, ensure_ascii=False, indent=4)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "base",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/src/qa_trainer.py b/src/qa_trainer.py
new file mode 100644
index 0000000..a74543f
--- /dev/null
+++ b/src/qa_trainer.py
@@ -0,0 +1,78 @@
+from transformers import Trainer, is_datasets_available
+
+if is_datasets_available():
+ import datasets
+
+
+class QATrainer(Trainer):
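+    """Trainer subclass for extractive QA.
+
+    evaluate() and predict() run the standard prediction loop, then apply
+    post_process_function to turn start/end logits into text answers before
+    compute_metrics (SQuAD EM/F1) is applied.
+    """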
+ def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.eval_examples = eval_examples
+ self.post_process_function = post_process_function
+
+ def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None):
+ eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
+ eval_dataloader = self.get_eval_dataloader(eval_dataset)
+ eval_examples = self.eval_examples if eval_examples is None else eval_examples
+
+ compute_metrics = self.compute_metrics
+ self.compute_metrics = None
+ try:
+ output = self.prediction_loop(
+ eval_dataloader,
+ description="Evaluation",
+ prediction_loss_only=True if compute_metrics is None else None,
+ ignore_keys=ignore_keys,
+ )
+ finally:
+ self.compute_metrics = compute_metrics
+
+ if isinstance(eval_dataset, datasets.Dataset):
+ eval_dataset.set_format(
+ type=eval_dataset.format["type"],
+ columns=list(eval_dataset.features.keys()),
+ )
+
+ if self.post_process_function is not None and self.compute_metrics is not None:
+ eval_preds = self.post_process_function(
+ eval_examples, eval_dataset, output.predictions, self.args
+ )
+ metrics = self.compute_metrics(eval_preds)
+
+ self.log(metrics)
+ else:
+ metrics = {}
+
+ self.control = self.callback_handler.on_evaluate(
+ self.args, self.state, self.control, metrics
+ )
+ return metrics
+
+ def predict(self, test_dataset, test_examples, ignore_keys=None):
+ test_dataloader = self.get_test_dataloader(test_dataset)
+
+ compute_metrics = self.compute_metrics
+ self.compute_metrics = None
+ try:
+ output = self.prediction_loop(
+ test_dataloader,
+                description="Prediction",
+ prediction_loss_only=True if compute_metrics is None else None,
+ ignore_keys=ignore_keys,
+ )
+ finally:
+ self.compute_metrics = compute_metrics
+
+ if self.post_process_function is None or self.compute_metrics is None:
+ return output
+
+ if isinstance(test_dataset, datasets.Dataset):
+ test_dataset.set_format(
+ type=test_dataset.format["type"],
+ columns=list(test_dataset.features.keys()),
+ )
+
+ predictions = self.post_process_function(
+ test_examples, test_dataset, output.predictions, self.args
+ )
+ return predictions
\ No newline at end of file
diff --git a/src/retrieval.py b/src/retrieval.py
new file mode 100644
index 0000000..9425feb
--- /dev/null
+++ b/src/retrieval.py
@@ -0,0 +1,86 @@
+import datetime
+import logging
+import os
+import sys
+import numpy as np
+from typing import Callable, List, NoReturn, Optional, Tuple
+
+from arguments import DataTrainingArguments, ModelArguments
+from datasets import (
+ Dataset,
+ DatasetDict,
+ Features,
+ Value,
+ Sequence,
+ load_from_disk,
+ load_metric
+)
+from qa_trainer import QATrainer
+from retrieval_BM25 import BM25SparseRetrieval
+from retrieval_hybridsearch import HybridSearch
+from retrieval_Dense import DenseRetrieval
+from retrieval_2s_rerank import TwoStageReranker
+
+from transformers import (
+ AutoConfig,
+ AutoModelForQuestionAnswering,
+ AutoTokenizer,
+ DataCollatorWithPadding,
+ EvalPrediction,
+ HfArgumentParser,
+ TrainingArguments,
+)
+from utils import set_seed, check_no_error, postprocess_qa_predictions
+
+
+class Retriever:
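+    """Thin factory that instantiates one retrieval backend by name
+    ("2s_rerank", "BM25", "Dense", "hybridsearch", "tfidf") and exposes a
+    uniform retrieve() interface over it."""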
+    def __init__(
+        self,
+        tokenize_fn,
+        args,
+        name: str,
+        data_path: Optional[str] = "../data/",
+        context_path: Optional[str] = "wikipedia_documents.json",
+    ):
+ self.retriever = None
+ if name == "2s_rerank":
+ self.retriever = TwoStageReranker(
+ tokenize_fn=tokenize_fn,
+                args=args,
+ data_path=data_path,
+ context_path=context_path
+ )
+ elif name == "BM25":
+ self.retriever = BM25SparseRetrieval(
+ tokenize_fn=tokenize_fn,
+                # args=args,
+ data_path=data_path,
+ context_path=context_path
+ )
+ self.retriever.get_sparse_embedding()
+ elif name == "Dense":
+ self.retriever = DenseRetrieval(
+ data_path=data_path,
+ context_path=context_path
+ )
+ self.retriever.get_dense_embedding()
+ elif name == "hybridsearch":
+ self.retriever = HybridSearch(
+ tokenize_fn=tokenize_fn,
+                # args=args,
+ data_path=data_path,
+ context_path=context_path
+ )
+ self.retriever.get_sparse_embedding()
+ self.retriever.get_dense_embedding()
+ elif name == "tfidf":
+ self.retriever = TFIDFRetrieval(
+ tokenize_fn=tokenize_fn,
+ args=data_args,
+ data_path=data_path,
+ context_path=context_path
+ )
+ self.retriever.get_sparse_embedding()
+
+ def retrieve(self, query_or_dataset):
+ return self.retriever.retrieve(query_or_dataset)
\ No newline at end of file
diff --git a/src/retrieval_2s_rerank.py b/src/retrieval_2s_rerank.py
new file mode 100644
index 0000000..834b32a
--- /dev/null
+++ b/src/retrieval_2s_rerank.py
@@ -0,0 +1,146 @@
+import json
+import os
+import pickle
+import time
+import torch
+import logging
+import scipy
+from contextlib import contextmanager
+from typing import List, Optional, Tuple, Union, NoReturn
+from tqdm.auto import tqdm
+
+import argparse
+import numpy as np
+import pandas as pd
+from datasets import Dataset, concatenate_datasets, load_from_disk
+from torch.nn.functional import normalize
+
+from retrieval_BM25 import BM25SparseRetrieval
+from retrieval_Dense import DenseRetrieval
+
+from sklearn.feature_extraction.text import TfidfVectorizer
+from rank_bm25 import BM25Okapi, BM25Plus
+from transformers import AutoTokenizer, AutoModel
+from utils import set_seed
+
+set_seed(42)
+logger = logging.getLogger(__name__)
+
+@contextmanager
+def timer(name):
+ t0 = time.time()
+ yield
+ logging.info(f"[{name}] done in {time.time() - t0:.3f} s")
+
+
+class TwoStageReranker:
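+    """Two-stage retrieval: a cheap BM25 pass first selects a candidate pool,
+    then a dense encoder re-ranks only those candidates, trading a little
+    first-stage recall for better precision at low k."""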
+ def __init__(
+ self,
+ tokenize_fn,
+ # args,
+ data_path: Optional[str] = "../data/",
+ context_path: Optional[str] = "wikipedia_documents.json",
+ ) -> NoReturn:
+ self.data_path = data_path
+ with open(os.path.join(data_path, context_path), "r", encoding="utf-8") as f:
+ wiki = json.load(f)
+
+ self.contexts = list(dict.fromkeys([v["text"] for v in wiki.values()]))
+ logging.info(f"Lengths of contexts : {len(self.contexts)}")
+ self.ids = list(range(len(self.contexts)))
+
+ self.sparse_embeder = BM25SparseRetrieval(
+ tokenize_fn=tokenize_fn,
+ # args=args,
+ data_path=data_path,
+ context_path=context_path
+ )
+ self.dense_embeder = DenseRetrieval(
+ data_path=data_path,
+ context_path=context_path
+ )
+ self.sparse_embeds_bool = False
+ self.dense_embeds_bool = False
+
+ def retrieve_first(self, queries, topk: Optional[int] = 1):
+ if self.sparse_embeds_bool == False:
+ self.sparse_embeder.get_sparse_embedding()
+ self.sparse_embeds_bool = True
+ f_df = self.sparse_embeder.retrieve(queries, topk=topk)
+ return f_df
+
+    def retrieve_second(self, queries, topk: Optional[int] = 1, contexts=None):
+ self.dense_embeder.get_dense_embedding(contexts=contexts)
+ s_df = self.dense_embeder.retrieve(queries, topk=topk)
+ # self.sparse_embeder.get_sparse_embedding(contexts=contexts)
+ # s_df = self.sparse_embeder.retrieve(queries, topk=topk)
+ return s_df
+
+ def retrieve(self, query_or_dataset, topk: Optional[int] = 1):
+ retrieved_contexts = []
+ if isinstance(query_or_dataset, str):
+ _, doc_indices = self.retrieve_first(query_or_dataset, topk)
+ retrieved_contexts = doc_indices
+ elif isinstance(query_or_dataset, Dataset):
+ for idx, example in enumerate(tqdm(query_or_dataset, desc="Sparse retrieval: ")):
+ _, doc_indices = self.retrieve_first(example['question'], topk)
+ retrieved_contexts.append(doc_indices)
+
+        # Keep roughly a third of the first-stage candidates after dense re-ranking
+        rerank_topk = max(1, topk // 3)
+
+        if isinstance(query_or_dataset, str):
+            second_df = self.retrieve_second(query_or_dataset, rerank_topk, contexts=retrieved_contexts)
+            return second_df
+        elif isinstance(query_or_dataset, Dataset):
+            second_df = []
+            for i, example in enumerate(query_or_dataset):
+                context = retrieved_contexts[i]
+                doc_scores, doc_indices = self.retrieve_second(example['question'], rerank_topk, contexts=context)
+ tmp = {
+ "question": example["question"],
+ "id": example["id"],
+ "context": " ".join(doc_indices),
+ }
+ second_df.append(tmp)
+ second_df = pd.DataFrame(second_df)
+ return second_df
+
+if __name__ == "__main__":
+ import argparse
+
+ logging.basicConfig(level=logging.INFO,
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S')
+
+ parser = argparse.ArgumentParser(description="")
+ parser.add_argument("--dataset_name", default="../data/train_dataset", type=str)
+ parser.add_argument("--data_path", default="../data", type=str)
+ parser.add_argument("--context_path", default="wikipedia_documents.json", type=str)
+
+ args = parser.parse_args()
+ logging.info(args.__dict__)
+
+ org_dataset = load_from_disk(args.dataset_name)
+ full_ds = concatenate_datasets(
+ [
+ org_dataset["train"].flatten_indices(),
+ org_dataset["validation"].flatten_indices(),
+ ]
+ )
+ logging.info("*" * 40 + " query dataset " + "*" * 40)
+ logging.info(f"Full dataset: {full_ds}")
+
+ tokenizer = AutoTokenizer.from_pretrained("HANTAEK/klue-roberta-large-korquad-v1-qa-finetuned")
+
+ # query = "๋ํต๋ น์ ํฌํจํ ๋ฏธ๊ตญ์ ํ์ ๋ถ ๊ฒฌ์ ๊ถ์ ๊ฐ๋ ๊ตญ๊ฐ ๊ธฐ๊ด์?"
+ query = "์ ๋ น์ ์ด๋ ํ์ฑ์์ ์ง๊ตฌ๋ก ์๋๊ฐ?"
+
+ retriever = TwoStageReranker(
+ tokenize_fn=tokenizer.tokenize,
+ data_path=args.data_path,
+ context_path=args.context_path,
+ )
+
+ with timer("single query by exhaustive search"):
+ doc_scores, doc_indices = retriever.retrieve(query, topk=5)
\ No newline at end of file
diff --git a/src/retrieval_BM25.py b/src/retrieval_BM25.py
new file mode 100644
index 0000000..7d64a4b
--- /dev/null
+++ b/src/retrieval_BM25.py
@@ -0,0 +1,181 @@
+import json
+import os
+import pickle
+import time
+from contextlib import contextmanager
+from typing import List, Optional, Tuple, Union
+
+import argparse
+import numpy as np
+import pandas as pd
+from datasets import Dataset, concatenate_datasets, load_from_disk
+from rank_bm25 import BM25Plus
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer
+
+from utils import set_seed
+
+set_seed(42)
+
+@contextmanager
+def timer(name):
+ t0 = time.time()
+ yield
+ print(f"[{name}] done in {time.time() - t0:.3f} s")
+
+
+class BM25SparseRetrieval:
+    def __init__(
+        self,
+        tokenize_fn,
+        data_path: Optional[str] = "../data/",
+        context_path: Optional[str] = "wikipedia_documents.json",
+        corpus: Optional[pd.DataFrame] = None,
+        args=None,
+    ) -> None:
+        set_seed(42)
+        self.tokenizer = tokenize_fn
+        self.data_path = data_path
+        self.args = args
+
+        # Load the Wikipedia documents
+ with open(os.path.join(data_path, context_path), "r", encoding="utf-8") as f:
+ wiki = json.load(f)
+
+ self.contexts = list(dict.fromkeys([v["text"] for v in wiki.values()]))
+ print(f"Lengths of unique contexts : {len(self.contexts)}")
+
+ self.bm25 = None
+
+
+ def get_sparse_embedding(self, contexts=None) -> None:
+ """BM25+๋ก Passage Embedding์ ๋ง๋ค๊ณ ์ด๊ธฐํํฉ๋๋ค."""
+ pickle_name = "bm25plus_sparse_embedding_optuna.bin"
+ emd_path = os.path.join(self.data_path, pickle_name)
+ if contexts is not None:
+ self.contexts = contexts
+ tokenized_corpus = [self.tokenizer(doc) for doc in self.contexts]
+ self.bm25 = BM25Plus(tokenized_corpus, k1=1.7595, b=0.9172, delta=1.1490)
+ else:
+ if os.path.isfile(emd_path):
+ with open(emd_path, "rb") as file:
+ self.bm25 = pickle.load(file)
+ print("Embedding pickle load.")
+ else:
+ print("Build passage embedding")
+ tokenized_corpus = [self.tokenizer(doc) for doc in self.contexts]
+                self.bm25 = BM25Plus(tokenized_corpus, k1=1.7595, b=0.9172, delta=1.1490)  # switched to BM25Plus; hyperparameters from the Optuna search
+ with open(emd_path, "wb") as file:
+ pickle.dump(self.bm25, file)
+ print("Embedding pickle saved.")
+
+
+ def retrieve(self, query_or_dataset: Union[str, Dataset], topk: Optional[int] = 1) -> Union[Tuple[List, List], pd.DataFrame]:
+        assert self.bm25 is not None, "You must run get_sparse_embedding() first."
+
+ if isinstance(query_or_dataset, str):
+ doc_scores, doc_indices = self.get_relevant_doc(query_or_dataset, k=topk)
+ # print("[Search query]\n", query_or_dataset, "\n")
+
+ # for i in range(topk):
+ # print(f"Top-{i+1} passage with score {doc_scores[0][i]:4f}")
+ # print(self.contexts[doc_indices[0][i]])
+
+ return (doc_scores, [self.contexts[doc_indices[0][i]] for i in range(topk)])
+
+ elif isinstance(query_or_dataset, Dataset):
+            # Return the retrieved passages as a pd.DataFrame.
+ with timer("query exhaustive search"):
+ doc_scores, doc_indices = self.get_relevant_doc_bulk(query_or_dataset["question"], k=topk)
+
+ total = []
+ for idx, example in enumerate(tqdm(query_or_dataset, desc="Sparse retrieval: ")):
+ tmp = {
+ "question": example["question"],
+ "id": example["id"],
+ "context": " ".join([self.contexts[pid] for pid in doc_indices[idx]]),
+ }
+ if "context" in example.keys() and "answers" in example.keys():
+ tmp["original_context"] = example["context"]
+ tmp["answers"] = example["answers"]
+ total.append(tmp)
+
+ return pd.DataFrame(total)
+
+
+ def get_relevant_doc(self, query: str, k: Optional[int] = 1) -> Tuple[List, List]:
+ """๊ฐ๋ณ ์ง์์ ๋ํ ์์ k๊ฐ์ Passage ๊ฒ์"""
+ tokenized_query = [self.tokenizer(query)]
+ result = np.array([self.bm25.get_scores(query) for query in tokenized_query])
+ doc_scores = []
+ doc_indices = []
+
+ for scores in result:
+ sorted_result = np.argsort(scores)[-k:][::-1]
+ doc_scores.append(scores[sorted_result].tolist())
+ doc_indices.append(sorted_result.tolist())
+
+ return doc_scores, doc_indices
+
+
+ def get_relevant_doc_bulk(self, queries: List, k: Optional[int] = 1) -> Tuple[List, List]:
+ """์ฌ๋ฌ ๊ฐ์ Query๋ฅผ ๋ฐ์ ์์ k๊ฐ์ Passage ๊ฒ์"""
+ tokenized_queries = [self.tokenizer(query) for query in queries]
+ result = np.array([self.bm25.get_scores(query) for query in tokenized_queries])
+ doc_scores = []
+ doc_indices = []
+
+ for scores in result:
+ sorted_result = np.argsort(scores)[-k:][::-1]
+ doc_scores.append(scores[sorted_result].tolist())
+ doc_indices.append(sorted_result.tolist())
+
+ return doc_scores, doc_indices
+
+
+if __name__ == "__main__":
+ import argparse
+ import logging
+
+ logging.basicConfig(level=logging.INFO,
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S')
+
+ parser = argparse.ArgumentParser(description="")
+ parser.add_argument("--dataset_name", default="../data/train_dataset", type=str)
+ parser.add_argument("--data_path", default="../data", type=str)
+ parser.add_argument("--context_path", default="wikipedia_documents.json", type=str)
+
+ args = parser.parse_args()
+ logging.info(args.__dict__)
+
+ org_dataset = load_from_disk(args.dataset_name)
+ full_ds = concatenate_datasets(
+ [
+ org_dataset["train"].flatten_indices(),
+ org_dataset["validation"].flatten_indices(),
+ ]
+ )
+ logging.info("*" * 40 + " query dataset " + "*" * 40)
+ logging.info(f"Full dataset: {full_ds}")
+
+ tokenizer = AutoTokenizer.from_pretrained("HANTAEK/klue-roberta-large-korquad-v1-qa-finetuned")
+
+ retriever = BM25SparseRetrieval(
+ tokenize_fn=tokenizer.tokenize,
+ args=args,
+ data_path=args.data_path,
+ context_path=args.context_path,
+ )
+ retriever.get_sparse_embedding()
+
+ # query = "๋ํต๋ น์ ํฌํจํ ๋ฏธ๊ตญ์ ํ์ ๋ถ ๊ฒฌ์ ๊ถ์ ๊ฐ๋ ๊ตญ๊ฐ ๊ธฐ๊ด์?"
+ query = "์ ๋ น์ ์ด๋ ํ์ฑ์์ ์ง๊ตฌ๋ก ์๋๊ฐ?"
+
+ # test single query
+ with timer("single query by exhaustive search using bm25"):
+        scores, contexts = retriever.retrieve(query, 20)
+        for i, context in enumerate(contexts):
+            print(f"Top-{i + 1} passage:")
+            print("---------------------------------------------")
+            print(context)
\ No newline at end of file
diff --git a/src/retrieval_Dense.py b/src/retrieval_Dense.py
new file mode 100644
index 0000000..ebe0d0d
--- /dev/null
+++ b/src/retrieval_Dense.py
@@ -0,0 +1,229 @@
+import json
+import os
+import time
+import logging
+from contextlib import contextmanager
+from typing import List, NoReturn, Optional, Tuple, Union
+
+import numpy as np
+import pandas as pd
+import scipy.sparse
+from datasets import Dataset, concatenate_datasets, load_from_disk
+from tqdm.auto import tqdm
+
+import torch
+import torch.nn.functional as F
+from transformers import AutoTokenizer, AutoModel
+
+from utils import set_seed
+
+set_seed(42)
+logger = logging.getLogger(__name__)
+
+@contextmanager
+def timer(name):
+ t0 = time.time()
+ yield
+ logging.info(f"[{name}] done in {time.time() - t0:.3f} s")
+
+def mean_pooling(model_output, attention_mask):
+ token_embeddings = model_output[0] #First element of model_output contains all token embeddings
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+class DenseRetrieval:
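+    """Bi-encoder dense retrieval with a multilingual sentence-transformers model:
+    passages and queries are embedded by attention-masked mean pooling and compared
+    with cosine similarity; passage embeddings are cached to disk."""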
+ def __init__(
+ self,
+ data_path: Optional[str] = "../data/",
+ context_path: Optional[str] = "wikipedia_documents.json",
+ corpus: Optional[pd.DataFrame] = None
+ ) -> NoReturn:
+ self.data_path = data_path
+ with open(os.path.join(data_path, context_path), "r", encoding="utf-8") as f:
+ wiki = json.load(f)
+
+ self.contexts = list(dict.fromkeys([v["text"] for v in wiki.values()]))
+ logging.info(f"Lengths of contexts : {len(self.contexts)}")
+ self.ids = list(range(len(self.contexts)))
+
+ self.tokenize_fn = AutoTokenizer.from_pretrained(
+ 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
+ )
+ self.dense_embeder = AutoModel.from_pretrained(
+ 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
+ )
+ self.dense_embeds = None
+
+ def get_dense_embedding(self, question=None, contexts=None):
+ if contexts is not None:
+ self.contexts = contexts
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ self.dense_embeder.to(device)
+ encoded_input = self.tokenize_fn(
+ self.contexts, padding=True, truncation=True, return_tensors='pt'
+ ).to(device)
+ with torch.no_grad():
+ model_output = self.dense_embeder(**encoded_input)
+ sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
+ self.dense_embeds = sentence_embeddings.cpu()
+
+ if question is None and contexts is None:
+ pickle_name = "dense_without_normalize_embedding.bin"
+ emd_path = os.path.join(self.data_path, pickle_name)
+
+ if os.path.isfile(emd_path):
+ self.dense_embeds = torch.load(emd_path)
+ print("Dense embedding loaded.")
+ else:
+ print("Building passage dense embeddings in batches.")
+ self.dense_embeds = []
+ batch_size = 64
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ self.dense_embeder.to(device)
+
+ for i in tqdm(range(0, len(self.contexts), batch_size), desc="Encoding passages"):
+ batch_contexts = self.contexts[i:i+batch_size]
+ encoded_input = self.tokenize_fn(
+ batch_contexts, padding=True, truncation=True, return_tensors='pt'
+ ).to(device)
+ with torch.no_grad():
+ model_output = self.dense_embeder(**encoded_input)
+ sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
+ self.dense_embeds.append(sentence_embeddings.cpu())
+ del encoded_input, model_output, sentence_embeddings
+ torch.cuda.empty_cache()
+
+ self.dense_embeds = torch.cat(self.dense_embeds, dim=0)
+ torch.save(self.dense_embeds, emd_path)
+ print("Dense embeddings saved.")
+ elif question is not None:
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ self.dense_embeder.to(device)
+ encoded_input = self.tokenize_fn(
+ question, padding=True, truncation=True, return_tensors='pt'
+ ).to(device)
+ with torch.no_grad():
+ model_output = self.dense_embeder(**encoded_input)
+ sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
+ return sentence_embeddings.cpu()
+
+ def get_similarity_score(self, q_vec, c_vec):
+ if isinstance(q_vec, scipy.sparse.spmatrix):
+ q_vec = q_vec.toarray()
+ if isinstance(c_vec, scipy.sparse.spmatrix):
+ c_vec = c_vec.toarray()
+
+ q_vec = torch.tensor(q_vec)
+ c_vec = torch.tensor(c_vec)
+ return q_vec.matmul(c_vec.T).numpy()
+
+ def get_cosine_score(self, q_vec, c_vec):
+ q_vec = q_vec / q_vec.norm(dim=1, keepdim=True)
+ c_vec = c_vec / c_vec.norm(dim=1, keepdim=True)
+ return torch.mm(q_vec, c_vec.T).numpy()
+
+ def retrieve(
+ self, query_or_dataset: Union[str, Dataset], topk: Optional[int] = 1
+ ) -> Union[Tuple[List[float], List[str]], pd.DataFrame]:
+        assert self.dense_embeds is not None, "You should first execute `get_dense_embedding()`"
+
+ if isinstance(query_or_dataset, str):
+ doc_scores, doc_indices = self.get_relevant_doc(query_or_dataset, k=topk)
+ logging.info(f"[Search query] {query_or_dataset}")
+
+ for i in range(topk):
+ logging.info(f"Top-{i+1} passage with score {doc_scores[i]:.6f}")
+ logging.info(self.contexts[doc_indices[i]])
+
+ return (doc_scores, [self.contexts[doc_indices[i]] for i in range(topk)])
+
+ elif isinstance(query_or_dataset, Dataset):
+ total = []
+ with timer("query exhaustive search"):
+ doc_scores, doc_indices = self.get_relevant_doc_bulk(
+ query_or_dataset["question"], k=topk
+ )
+            for idx, example in enumerate(tqdm(query_or_dataset, desc="[Dense retrieval] ")):
+ retrieved_contexts = [self.contexts[pid] for pid in doc_indices[idx]]
+ tmp = {
+ "question": example["question"],
+ "id": example["id"],
+ "context": " ".join(retrieved_contexts),
+ }
+ if "context" in example.keys() and "answers" in example.keys():
+ tmp["original_context"] = example["context"]
+ tmp["answers"] = example["answers"]
+ total.append(tmp)
+
+ cqas = pd.DataFrame(total)
+ return cqas
+
+ def get_relevant_doc(self, query: str, k: Optional[int] = 1) -> Tuple[List, List]:
+ with timer("transform"):
+ dense_qvec = self.get_dense_embedding(question=[query])
+
+ with timer("query ex search"):
+ result = self.get_cosine_score(dense_qvec, self.dense_embeds)
+ # result = self.get_similarity_score(dense_qvec, self.dense_embeds)
+ sorted_result = np.argsort(result.squeeze())[::-1]
+ doc_score = result.squeeze()[sorted_result].tolist()[:k]
+ doc_indices = sorted_result.tolist()[:k]
+ return doc_score, doc_indices
+
+ def get_relevant_doc_bulk(
+ self, queries: List[str], k: Optional[int] = 1
+ ) -> Tuple[List, List]:
+ dense_qvec = self.get_dense_embedding(question=queries)
+
+ result = self.get_cosine_score(dense_qvec, self.dense_embeds)
+ # result = self.get_similarity_score(dense_qvec, self.dense_embeds)
+ doc_scores = []
+ doc_indices = []
+ for i in range(result.shape[0]):
+ sorted_result = np.argsort(result[i, :])[::-1]
+ doc_scores.append(result[i, :][sorted_result].tolist()[:k])
+ doc_indices.append(sorted_result.tolist()[:k])
+ return doc_scores, doc_indices
+
+if __name__ == "__main__":
+ import argparse
+
+ logging.basicConfig(level=logging.INFO,
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S')
+
+ parser = argparse.ArgumentParser(description="")
+ parser.add_argument("--dataset_name", default="../data/train_dataset", type=str)
+ parser.add_argument("--data_path", default="../data", type=str)
+ parser.add_argument("--context_path", default="wikipedia_documents.json", type=str)
+
+ args = parser.parse_args()
+ logging.info(args.__dict__)
+
+ org_dataset = load_from_disk(args.dataset_name)
+ full_ds = concatenate_datasets(
+ [
+ org_dataset["train"].flatten_indices(),
+ org_dataset["validation"].flatten_indices(),
+ ]
+ )
+ logging.info("*" * 40 + " query dataset " + "*" * 40)
+ logging.info(f"Full dataset: {full_ds}")
+
+ retriever = DenseRetrieval(
+ data_path=args.data_path,
+ context_path=args.context_path,
+ )
+
+ retriever.get_dense_embedding()
+
+    # query = "대통령을 포함한 미국의 행정부 견제권을 가진 국가 기관은?"
+    query = "유령은 어느 행성에서 지구로 왔는가?"
+
+ with timer("single query by exhaustive search"):
+ scores, contexts = retriever.retrieve(query, topk=5)
+
+ with timer("bulk query by exhaustive search"):
+ df = retriever.retrieve(full_ds, topk=1)
+ if "original_context" in df.columns:
+ df["correct"] = df["original_context"] == df["context"]
+ logging.info(f'correct retrieval result by exhaustive search: {df["correct"].sum() / len(df)}')
\ No newline at end of file
diff --git a/src/retrieval_SPLADE.py b/src/retrieval_SPLADE.py
new file mode 100644
index 0000000..a2fe9c7
--- /dev/null
+++ b/src/retrieval_SPLADE.py
@@ -0,0 +1,129 @@
+import json
+import os
+import time
+import logging
+from contextlib import contextmanager
+from typing import List, NoReturn, Optional, Tuple, Union
+
+import numpy as np
+import pandas as pd
+from datasets import Dataset, concatenate_datasets, load_from_disk
+from tqdm.auto import tqdm
+
+import torch
+import torch.nn.functional as F
+from sparsembed import model, retrieve
+from transformers import AutoModelForMaskedLM, AutoTokenizer
+
+from utils import set_seed
+
+set_seed(42)
+logger = logging.getLogger(__name__)
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+@contextmanager
+def timer(name):
+ t0 = time.time()
+ yield
+ logging.info(f"[{name}] done in {time.time() - t0:.3f} s")
+
+
+class SpldRetrieval:
+ def __init__(
+ self,
+ data_path: Optional[str] = "../data/",
+ context_path: Optional[str] = "wikipedia_documents.json",
+ corpus: Optional[pd.DataFrame] = None
+ ) -> NoReturn:
+ self.data_path = data_path
+ with open(os.path.join(data_path, context_path), "r", encoding="utf-8") as f:
+ wiki = json.load(f)
+
+ self.contexts = list(dict.fromkeys([v["text"] for v in wiki.values()]))
+        logging.info(f"Number of unique contexts: {len(self.contexts)}")
+ self.ids = list(range(len(self.contexts)))
+ self.df = None
+
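+        # SPLADE encodes each text as a sparse vector over the masked-LM vocabulary,
+        # so retrieval reduces to sparse dot products. Caveat: naver/splade_v2_max
+        # was trained on English (MS MARCO) data, so its expansion terms may fit
+        # Korean passages poorly; kept here as in the original setup.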
+ self.model = model.Splade(
+ model=AutoModelForMaskedLM.from_pretrained("naver/splade_v2_max").to(device),
+ tokenizer=AutoTokenizer.from_pretrained("naver/splade_v2_max"),
+ device=device
+ )
+ self.retriever = retrieve.SpladeRetriever(
+ key="id",
+ on=["text"],
+ model=self.model
+ )
+
+ def get_sparse_embedding(self, df=None) -> NoReturn:
+ if df is None:
+ self.df = [{'id': i, 'text': c} for i, c in zip(self.ids, self.contexts)]
+ else:
+ self.df = df
+
+ def retrieve(
+ self, query_or_dataset: Union[str, Dataset], topk: Optional[int] = 1
+ ) -> Union[Tuple[List[float], List[str]], pd.DataFrame]:
+ assert self.df is not None, "You should first execute `get_sparse_embedding()`"
+
+ # print(self.df.to_dict())
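+        # NOTE: add() builds the index, and because it runs inside retrieve(),
+        # documents are re-indexed on every call; moving it into
+        # get_sparse_embedding() would make indexing a one-time cost.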
+ self.retriever = self.retriever.add(
+ documents=self.df,
+ batch_size=10,
+ k_tokens=256,
+ )
+
+ result = self.retriever(
+ query_or_dataset,
+ k_tokens=20, # Maximum number of activated tokens.
+ k=100, # Number of documents to retrieve.
+ batch_size=10
+ )
+
+ doc_scores = []
+ doc_indices = []
+ for i in range(len(result)):
+            doc_scores.append([d['similarity'] for d in result[i][:topk]])
+            doc_indices.append([d['id'] for d in result[i][:topk]])
+        for i in range(topk):
+            logging.info(f"Top-{i+1} passage with score {doc_scores[0][i]:.4f}")
+ logging.info(self.contexts[doc_indices[0][i]])
+ return doc_scores, doc_indices
+
+
+if __name__ == "__main__":
+ import argparse
+
+ logging.basicConfig(level=logging.INFO,
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S')
+
+ parser = argparse.ArgumentParser(description="")
+ parser.add_argument("--dataset_name", default="../data/train_dataset", type=str)
+ parser.add_argument("--data_path", default="../data", type=str)
+ parser.add_argument("--context_path", default="wikipedia_documents.json", type=str)
+
+ args = parser.parse_args()
+ logging.info(args.__dict__)
+
+ org_dataset = load_from_disk(args.dataset_name)
+ full_ds = concatenate_datasets(
+ [
+ org_dataset["train"].flatten_indices(),
+ org_dataset["validation"].flatten_indices(),
+ ]
+ )
+ logging.info("*" * 40 + " query dataset " + "*" * 40)
+ logging.info(f"Full dataset: {full_ds}")
+
+ retriever = SpldRetrieval(
+ data_path=args.data_path,
+ context_path=args.context_path,
+ )
+
+ retriever.get_sparse_embedding()
+
+    query = "대통령을 포함한 미국의 행정부 견제권을 가진 국가 기관은?"
+
+ with timer("single query by exhaustive search"):
+ doc_scores, doc_indices = retriever.retrieve(query, topk=1)
\ No newline at end of file
diff --git a/src/retrieval_hybridsearch.py b/src/retrieval_hybridsearch.py
new file mode 100644
index 0000000..02201ca
--- /dev/null
+++ b/src/retrieval_hybridsearch.py
@@ -0,0 +1,334 @@
+import json
+import os
+import pickle
+import time
+import torch
+import logging
+import scipy
+import scipy.sparse
+from contextlib import contextmanager
+from typing import List, Optional, Tuple, Union, NoReturn
+from tqdm.auto import tqdm
+
+import argparse
+import numpy as np
+import pandas as pd
+from datasets import Dataset, concatenate_datasets, load_from_disk
+from torch.nn.functional import normalize
+
+from sklearn.feature_extraction.text import TfidfVectorizer
+from rank_bm25 import BM25Okapi, BM25Plus
+from transformers import AutoTokenizer, AutoModel
+from utils import set_seed
+
+set_seed(42)
+logger = logging.getLogger(__name__)
+
+@contextmanager
+def timer(name):
+ t0 = time.time()
+ yield
+ logging.info(f"[{name}] done in {time.time() - t0:.3f} s")
+
+def mean_pooling(model_output, attention_mask):
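+    # Masked mean over token embeddings: the attention mask zeroes out padded
+    # positions before summing, and the clamp guards against division by zero.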
+ token_embeddings = model_output[0]
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+class HybridSearch:
+ def __init__(
+ self,
+ tokenize_fn,
+ data_path: Optional[str] = "../data/",
+ context_path: Optional[str] = "wikipedia_documents.json",
+ corpus: Optional[pd.DataFrame] = None
+ ) -> NoReturn:
+ self.data_path = data_path
+ with open(os.path.join(data_path, context_path), "r", encoding="utf-8") as f:
+ wiki = json.load(f)
+
+ self.contexts = list(dict.fromkeys([v["text"] for v in wiki.values()]))
+        logging.info(f"Number of unique contexts: {len(self.contexts)}")
+ self.ids = list(range(len(self.contexts)))
+
+ self.tokenize_fn=tokenize_fn
+
+        self.dense_model_name = 'intfloat/multilingual-e5-large-instruct'
+        # Alternatives tried: 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2', 'BM-K/KoSimCSE-roberta'
+ self.dense_tokenize_fn = AutoTokenizer.from_pretrained(
+ self.dense_model_name
+ )
+ # self.sparse_embeder = TfidfVectorizer(
+ # tokenizer=self.tokenize_fn, ngram_range=(1, 2), max_features=50000,
+ # )
+        self.sparse_embeder = None
+ self.dense_embeder = AutoModel.from_pretrained(
+ self.dense_model_name
+ )
+ self.sparse_embeds = None
+ self.dense_embeds = None
+
+ def get_sparse_embedding(self, question=None):
+ vectorizer_path = os.path.join(self.data_path, "BM25Plus_sparse_vectorizer.bin")
+ embeddings_path = os.path.join(self.data_path, "BM25Plus_sparse_embedding.bin")
+ # vectorizer_path = os.path.join(self.data_path, "sparse_vectorizer.bin")
+ # embeddings_path = os.path.join(self.data_path, "sparse_embedding.bin")
+
+ if question is None:
+ if os.path.isfile(vectorizer_path) and os.path.isfile(embeddings_path):
+ with open(vectorizer_path, "rb") as f:
+ self.sparse_embeder = pickle.load(f)
+ with open(embeddings_path, "rb") as f:
+ self.sparse_embeds = pickle.load(f)
+ print("Sparse vectorizer and embeddings loaded.")
+ else:
+ print("Fitting sparse vectorizer and building embeddings.")
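+                # BM25+ hyperparameters: k1 controls term-frequency saturation,
+                # b the strength of length normalization, and delta the lower-bound
+                # bonus BM25+ grants every matching term. The precise values below
+                # look like the output of a hyperparameter search.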
+                self.sparse_embeder = BM25Plus(
+                    [self.tokenize_fn(doc) for doc in self.contexts],
+                    k1=1.837782128608009, b=0.587622663072072, delta=1.1490,
+                )
+ # self.sparse_embeds = self.sparse_embeder.fit_transform(self.contexts)
+ with open(vectorizer_path, "wb") as f:
+ pickle.dump(self.sparse_embeder, f)
+                if self.sparse_embeds is not None:
+ with open(embeddings_path, "wb") as f:
+ pickle.dump(self.sparse_embeds, f)
+ print("Sparse vectorizer and embeddings saved.")
+ else:
+            # Use this branch only when self.sparse_embeder is a fitted vectorizer object (e.g. CountVectorizer, TfidfVectorizer).
+ if not hasattr(self.sparse_embeder, 'vocabulary_'):
+ vectorizer_path = os.path.join(self.data_path, "sparse_vectorizer.bin")
+ if os.path.isfile(vectorizer_path):
+ with open(vectorizer_path, "rb") as f:
+ self.sparse_embeder = pickle.load(f)
+ print("Sparse vectorizer loaded for transforming the query.")
+ else:
+ raise ValueError("The Sparse vectorizer is not fitted. Please run get_sparse_embedding() first.")
+ return self.sparse_embeder.transform(question)
+
+
+ def get_dense_embedding(self, question=None):
+ if question is None:
+ model_n = self.dense_model_name.split('/')[1]
+ pickle_name = f"{model_n}_dense_embedding.bin"
+ emd_path = os.path.join(self.data_path, pickle_name)
+
+ if os.path.isfile(emd_path):
+ self.dense_embeds = torch.load(emd_path)
+ print("Dense embedding loaded.")
+ else:
+ print("Building passage dense embeddings in batches.")
+ self.dense_embeds = []
+ batch_size = 64
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ self.dense_embeder.to(device)
+
+ for i in tqdm(range(0, len(self.contexts), batch_size), desc="Encoding passages"):
+ batch_contexts = self.contexts[i:i+batch_size]
+ encoded_input = self.dense_tokenize_fn(
+ batch_contexts, padding=True, truncation=True, return_tensors='pt'
+ ).to(device)
+ with torch.no_grad():
+ model_output = self.dense_embeder(**encoded_input)
+ sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
+ sentence_embeddings = normalize(sentence_embeddings, p=2, dim=1)
+ self.dense_embeds.append(sentence_embeddings.cpu())
+ del encoded_input, model_output, sentence_embeddings
+ torch.cuda.empty_cache()
+
+ self.dense_embeds = torch.cat(self.dense_embeds, dim=0)
+ torch.save(self.dense_embeds, emd_path)
+ print("Dense embeddings saved.")
+ else:
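+            # NOTE: intfloat/multilingual-e5-* models expect task prefixes ("query: " /
+            # "passage: ", or an instruction-prefixed query for the -instruct variant);
+            # embedding raw text as done here may cost some accuracy.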
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ self.dense_embeder.to(device)
+ encoded_input = self.dense_tokenize_fn(
+ question, padding=True, truncation=True, return_tensors='pt'
+ ).to(device)
+ with torch.no_grad():
+ model_output = self.dense_embeder(**encoded_input)
+ sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
+ sentence_embeddings = normalize(sentence_embeddings, p=2, dim=1)
+ return sentence_embeddings.cpu()
+
+ def hybrid_scale(self, dense_score, sparse_score, alpha: float):
+ if alpha < 0 or alpha > 1:
+ raise ValueError("Alpha must be between 0 and 1")
+
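+        # Convex combination: alpha weights the sparse (BM25) score and (1 - alpha)
+        # the dense score, so alpha=1.0 is pure lexical search and alpha=0.0 is
+        # pure dense retrieval.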
+ if isinstance(dense_score, torch.Tensor):
+ dense_score = dense_score.detach().numpy()
+ if isinstance(sparse_score, torch.Tensor):
+ sparse_score = sparse_score.detach().numpy()
+
+ def min_max_normalize(score):
+ return (score - np.min(score)) / (np.max(score) - np.min(score))
+ def z_score_normalize(score):
+ return (score - np.mean(score)) / np.std(score)
+
+ # dense_score_normalized = min_max_normalize(dense_score)
+ # sparse_score_normalized = min_max_normalize(sparse_score)
+ # dense_score_normalized = z_score_normalize(dense_score)
+ # sparse_score_normalized = z_score_normalize(sparse_score)
+
+ # result = (1 - alpha) * dense_score_normalized + alpha * sparse_score_normalized
+ result = (1 - alpha) * dense_score + alpha * sparse_score
+ return result
+
+ def get_similarity_score(self, q_vec, c_vec):
+ if isinstance(q_vec, scipy.sparse.spmatrix):
+ q_vec = q_vec.toarray()
+ if isinstance(c_vec, scipy.sparse.spmatrix):
+ c_vec = c_vec.toarray()
+
+ q_vec = torch.tensor(q_vec, dtype=torch.float32)
+ c_vec = torch.tensor(c_vec, dtype=torch.float32)
+
+ if q_vec.ndim == 1:
+ q_vec = q_vec.unsqueeze(0)
+ if c_vec.ndim == 1:
+ c_vec = c_vec.unsqueeze(0)
+
+ similarity_score = torch.matmul(q_vec, c_vec.T)
+
+ return similarity_score
+
+ def get_cosine_score(self, q_vec, c_vec):
+ q_vec = q_vec / q_vec.norm(dim=1, keepdim=True)
+ c_vec = c_vec / c_vec.norm(dim=1, keepdim=True)
+ return torch.mm(q_vec, c_vec.T)
+
+ def retrieve(self, query_or_dataset, topk: Optional[int] = 1, alpha: Optional[float] = 0.7):
+ # assert self.sparse_embeds is not None, "You should first execute `get_sparse_embedding()`"
+ assert self.dense_embeds is not None, "You should first execute `get_dense_embedding()`"
+
+ if isinstance(query_or_dataset, str):
+ doc_scores, doc_indices = self.get_relevant_doc(query_or_dataset, alpha, k=topk)
+ logging.info(f"[Search query] {query_or_dataset}")
+
+ for i in range(topk):
+                logging.info(f"Top-{i+1} passage with score {doc_scores[i]:.4f}")
+ logging.info(self.contexts[doc_indices[i]])
+
+            # return (doc_scores, [self.contexts[doc_indices[i]] for i in range(topk)], doc_indices)  # doc_indices was temporarily included in the return value; remove later
+ return (doc_scores, [self.contexts[doc_indices[i]] for i in range(topk)])
+
+ elif isinstance(query_or_dataset, Dataset):
+ total = []
+ with timer("query exhaustive search"):
+ doc_scores, doc_indices = self.get_relevant_doc_bulk(
+ query_or_dataset["question"], alpha, k=topk
+ )
+ for idx, example in enumerate(tqdm(query_or_dataset, desc="[Hybrid retrieval] ")):
+ tmp = {
+ "question": example["question"],
+ "id": example["id"],
+ "context": " ".join([self.contexts[pid] for pid in doc_indices[idx]]),
+ }
+ if "context" in example.keys() and "answers" in example.keys():
+ tmp["original_context"] = example["context"]
+ tmp["answers"] = example["answers"]
+ total.append(tmp)
+
+ cqas = pd.DataFrame(total)
+ return cqas
+
+ def get_relevant_doc(self, query: str, alpha: float, k: Optional[int] = 1) -> Tuple[List, List]:
+ with timer("transform"):
+ # sparse_qvec = self.get_sparse_embedding([query])
+ dense_qvec = self.get_dense_embedding([query])
+ # assert sparse_qvec.nnz != 0, "Error: query contains no words in vocab."
+
+ with timer("query ex search"):
+ tokenized_query = [self.tokenize_fn(query)]
+            sparse_score = np.array([self.sparse_embeder.get_scores(tq) for tq in tokenized_query])
+ # sparse_score = self.get_similarity_score(sparse_qvec, self.sparse_embeds)
+ # dense_score = self.get_cosine_score(dense_qvec, self.dense_embeds)
+ dense_score = self.get_similarity_score(dense_qvec, self.dense_embeds)
+ result = self.hybrid_scale(dense_score.numpy(), sparse_score, alpha)
+ sorted_result = np.argsort(result.squeeze())[::-1]
+ doc_score = result.squeeze()[sorted_result].tolist()[:k]
+ doc_indices = sorted_result.tolist()[:k]
+ return doc_score, doc_indices
+
+ def get_relevant_doc_bulk(
+ self, queries: List[str], alpha: float, k: Optional[int] = 1
+ ) -> Tuple[List, List]:
+ # sparse_qvec = self.get_sparse_embedding(queries)
+ dense_qvec = self.get_dense_embedding(queries)
+ # assert sparse_qvec.nnz != 0, "Error: query contains no words in vocab."
+
+ tokenized_queries = [self.tokenize_fn(query) for query in queries]
+ sparse_score = np.array([self.sparse_embeder.get_scores(query) for query in tokenized_queries])
+ # sparse_score = self.get_similarity_score(sparse_qvec, self.sparse_embeds)
+ # dense_score = self.get_cosine_score(dense_qvec, self.dense_embeds)
+ dense_score = self.get_similarity_score(dense_qvec, self.dense_embeds)
+ result = self.hybrid_scale(dense_score.numpy(), sparse_score, alpha)
+ doc_scores = []
+ doc_indices = []
+ for i in range(result.shape[0]):
+ sorted_result = np.argsort(result[i, :])[::-1]
+ doc_scores.append(result[i, :][sorted_result].tolist()[:k])
+ doc_indices.append(sorted_result.tolist()[:k])
+ return doc_scores, doc_indices
+
+
+if __name__ == "__main__":
+ import argparse
+ from transformers import AutoTokenizer
+
+ logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S'
+ )
+
+ parser = argparse.ArgumentParser(description="")
+ parser.add_argument("--dataset_name", default="../data/train_dataset", type=str)
+ parser.add_argument("--model_name_or_path", default="bert-base-multilingual-cased", type=str)
+ parser.add_argument("--data_path", default="../data", type=str)
+ parser.add_argument("--context_path", default="wikipedia_documents.json", type=str)
+    parser.add_argument("--use_faiss", action="store_true")
+
+ args = parser.parse_args()
+ logging.info(args.__dict__)
+
+ org_dataset = load_from_disk(args.dataset_name)
+ if 'train' in org_dataset and 'validation' in org_dataset:
+ full_ds = concatenate_datasets(
+ [
+ org_dataset["train"].flatten_indices(),
+ org_dataset["validation"].flatten_indices(),
+ ]
+ )
+ else:
+ full_ds = org_dataset
+ logging.info("*" * 40 + " query dataset " + "*" * 40)
+ logging.info(f"Full dataset: {full_ds}")
+
+ tokenizer = AutoTokenizer.from_pretrained("HANTAEK/klue-roberta-large-korquad-v1-qa-finetuned")
+
+ retriever = HybridSearch(
+ tokenize_fn=tokenizer.tokenize,
+ data_path=args.data_path,
+ context_path=args.context_path,
+ )
+ retriever.get_dense_embedding()
+ retriever.get_sparse_embedding()
+
+    # query = "대통령을 포함한 미국의 행정부 견제권을 가진 국가 기관은?"
+    query = "유령은 어느 행성에서 지구로 왔는가?"
+
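+    # NOTE: raw BM25 scores are unbounded and typically much larger than the dense
+    # inner products, so when scores are mixed without normalization a very small
+    # alpha is needed to keep the sparse term from dominating.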
+ with timer("single query by exhaustive search using hybrid search"):
+ scores, contexts = retriever.retrieve(query, topk=20, alpha=0.0060115995634538455)
+
+    for i, context in enumerate(contexts):
+        print(f"Top-{i+1} retrieved document:")
+        print("---------------------------------------------")
+        print(context)
+
+
diff --git a/src/retrieval_tfidf.py b/src/retrieval_tfidf.py
new file mode 100644
index 0000000..ce95be1
--- /dev/null
+++ b/src/retrieval_tfidf.py
@@ -0,0 +1,283 @@
+import json
+import os
+import pickle
+import time
+import logging
+from contextlib import contextmanager
+from typing import List, NoReturn, Optional, Tuple, Union
+
+import faiss
+import numpy as np
+import pandas as pd
+from datasets import Dataset, concatenate_datasets, load_from_disk
+from sklearn.feature_extraction.text import TfidfVectorizer
+from tqdm.auto import tqdm
+
+from utils import set_seed
+
+set_seed(42)
+logger = logging.getLogger(__name__)
+
+@contextmanager
+def timer(name):
+ t0 = time.time()
+ yield
+ logging.info(f"[{name}] done in {time.time() - t0:.3f} s")
+
+
+class TFIDFRetrieval:
+ def __init__(
+ self,
+ tokenize_fn,
+ data_path: Optional[str] = "../data/",
+ context_path: Optional[str] = "wikipedia_documents.json",
+ corpus: Optional[pd.DataFrame] = None
+ ) -> NoReturn:
+ set_seed(42)
+ self.data_path = data_path
+ with open(os.path.join(data_path, context_path), "r", encoding="utf-8") as f:
+ wiki = json.load(f)
+
+ self.contexts = list(dict.fromkeys([v["text"] for v in wiki.values()]))
+        logging.info(f"Number of unique contexts: {len(self.contexts)}")
+ self.ids = list(range(len(self.contexts)))
+
+ self.tfidfv = TfidfVectorizer(
+ tokenizer=tokenize_fn, ngram_range=(1, 2), max_features=50000,
+ )
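+        # Unigrams and bigrams capped at 50k features keep the passage embedding
+        # matrix sparse and memory-bounded.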
+
+ self.p_embedding = None
+ self.indexer = None
+
+ def get_sparse_embedding(self) -> NoReturn:
+ emd_path = os.path.join(self.data_path, "sparse_embedding.bin")
+ tfidfv_path = os.path.join(self.data_path, "tfidv.bin")
+
+ if os.path.isfile(emd_path) and os.path.isfile(tfidfv_path):
+ with open(emd_path, "rb") as file:
+ self.p_embedding = pickle.load(file)
+ with open(tfidfv_path, "rb") as file:
+ self.tfidfv = pickle.load(file)
+ logging.info(f"Embedding pickle shape: {self.p_embedding.shape}")
+ else:
+ logging.info("Embedding pickle not found. Building passage embedding...")
+ self.p_embedding = self.tfidfv.fit_transform(self.contexts)
+ logging.info(f"Embedding built with shape: {self.p_embedding.shape}")
+ with open(emd_path, "wb") as file:
+ pickle.dump(self.p_embedding, file)
+ with open(tfidfv_path, "wb") as file:
+ pickle.dump(self.tfidfv, file)
+ logging.info("Embedding pickle saved.")
+
+ def build_faiss(self, num_clusters=64) -> NoReturn:
+ indexer_path = os.path.join(self.data_path, f"faiss_clusters{num_clusters}.index")
+ if os.path.isfile(indexer_path):
+ logging.info("Load Saved Faiss Indexer.")
+ self.indexer = faiss.read_index(indexer_path)
+
+ else:
+ p_emb = self.p_embedding.astype(np.float32).toarray()
+ emb_dim = p_emb.shape[-1]
+
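+            # IVF index: passages are partitioned into num_clusters cells around
+            # trained centroids, and a query scans only its nearest cells; the
+            # scalar quantizer compresses stored vectors, trading recall for speed.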
+ quantizer = faiss.IndexFlatL2(emb_dim)
+
+ self.indexer = faiss.IndexIVFScalarQuantizer(
+ quantizer, quantizer.d, num_clusters, faiss.METRIC_L2
+ )
+ self.indexer.train(p_emb)
+ self.indexer.add(p_emb)
+ faiss.write_index(self.indexer, indexer_path)
+ logging.info("Faiss Indexer Saved.")
+
+ def retrieve(
+ self, query_or_dataset: Union[str, Dataset], topk: Optional[int] = 1
+ ) -> Union[Tuple[List, List], pd.DataFrame]:
+ assert self.p_embedding is not None, "You should first execute `get_sparse_embedding()`"
+
+ if isinstance(query_or_dataset, str):
+ doc_scores, doc_indices = self.get_relevant_doc(query_or_dataset, k=topk)
+ logging.info(f"[Search query] {query_or_dataset}")
+
+ for i in range(topk):
+                logging.info(f"Top-{i+1} passage with score {doc_scores[i]:.4f}")
+ logging.info(self.contexts[doc_indices[i]])
+
+ return (doc_scores, [self.contexts[doc_indices[i]] for i in range(topk)])
+
+ elif isinstance(query_or_dataset, Dataset):
+ total = []
+ with timer("query exhaustive search"):
+ doc_scores, doc_indices = self.get_relevant_doc_bulk(
+ query_or_dataset["question"], k=topk
+ )
+ for idx, example in enumerate(tqdm(query_or_dataset, desc="[Sparse retrieval] ")):
+ tmp = {
+ "question": example["question"],
+ "id": example["id"],
+ "context": " ".join([self.contexts[pid] for pid in doc_indices[idx]]),
+ }
+ if "context" in example.keys() and "answers" in example.keys():
+ tmp["original_context"] = example["context"]
+ tmp["answers"] = example["answers"]
+ total.append(tmp)
+
+ cqas = pd.DataFrame(total)
+ return cqas
+
+ def get_relevant_doc(self, query: str, k: Optional[int] = 1) -> Tuple[List, List]:
+ with timer("transform"):
+ query_vec = self.tfidfv.transform([query])
+ assert np.sum(query_vec) != 0, "Error: query contains no words in vocab."
+
+ with timer("query ex search"):
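+            # Sparse dot product of the query vector with every passage vector.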
+ result = query_vec * self.p_embedding.T
+ if not isinstance(result, np.ndarray):
+ result = result.toarray()
+
+ sorted_result = np.argsort(result.squeeze())[::-1]
+ doc_score = result.squeeze()[sorted_result].tolist()[:k]
+ doc_indices = sorted_result.tolist()[:k]
+ return doc_score, doc_indices
+
+ def get_relevant_doc_bulk(
+ self, queries: List, k: Optional[int] = 1
+ ) -> Tuple[List, List]:
+ query_vec = self.tfidfv.transform(queries)
+ assert np.sum(query_vec) != 0, "Error: query contains no words in vocab."
+
+ result = query_vec * self.p_embedding.T
+ if not isinstance(result, np.ndarray):
+ result = result.toarray()
+ doc_scores = []
+ doc_indices = []
+ for i in range(result.shape[0]):
+ sorted_result = np.argsort(result[i, :])[::-1]
+ doc_scores.append(result[i, :][sorted_result].tolist()[:k])
+ doc_indices.append(sorted_result.tolist()[:k])
+ return doc_scores, doc_indices
+
+ def retrieve_faiss(
+ self, query_or_dataset: Union[str, Dataset], topk: Optional[int] = 1
+ ) -> Union[Tuple[List, List], pd.DataFrame]:
+ assert self.indexer is not None, "You should first execute `build_faiss().`"
+
+ if isinstance(query_or_dataset, str):
+ doc_scores, doc_indices = self.get_relevant_doc_faiss(
+ query_or_dataset, k=topk
+ )
+ logging.info(f"[Search query] {query_or_dataset}")
+
+ for i in range(topk):
+                logging.info(f"Top-{i+1} passage with score {doc_scores[i]:.4f}")
+ logging.info(self.contexts[doc_indices[i]])
+
+ return (doc_scores, [self.contexts[doc_indices[i]] for i in range(topk)])
+
+ elif isinstance(query_or_dataset, Dataset):
+ queries = query_or_dataset["question"]
+ total = []
+
+ with timer("query faiss search"):
+ doc_scores, doc_indices = self.get_relevant_doc_bulk_faiss(
+ queries, k=topk
+ )
+ for idx, example in enumerate(
+ tqdm(query_or_dataset, desc="Sparse retrieval: ")
+ ):
+ tmp = {
+ "question": example["question"],
+ "id": example["id"],
+ "context": " ".join(
+ [self.contexts[pid] for pid in doc_indices[idx]]
+ ),
+ }
+ if "context" in example.keys() and "answers" in example.keys():
+ tmp["original_context"] = example["context"]
+ tmp["answers"] = example["answers"]
+ total.append(tmp)
+
+ return pd.DataFrame(total)
+
+ def get_relevant_doc_faiss(
+ self, query: str, k: Optional[int] = 1
+ ) -> Tuple[List, List]:
+ query_vec = self.tfidfv.transform([query])
+ assert np.sum(query_vec) != 0, "Error: query contains no words in vocab."
+
+ q_emb = query_vec.toarray().astype(np.float32)
+ with timer("query faiss search"):
+ D, I = self.indexer.search(q_emb, k)
+
+ return D.tolist()[0], I.tolist()[0]
+
+ def get_relevant_doc_bulk_faiss(
+ self, queries: List, k: Optional[int] = 1
+ ) -> Tuple[List, List]:
+ query_vecs = self.tfidfv.transform(queries)
+ assert np.sum(query_vecs) != 0, "Error: query contains no words in vocab."
+
+ q_embs = query_vecs.toarray().astype(np.float32)
+ D, I = self.indexer.search(q_embs, k)
+
+ return D.tolist(), I.tolist()
+
+
+if __name__ == "__main__":
+ import argparse
+ from transformers import AutoTokenizer
+
+ logging.basicConfig(level=logging.INFO,
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S')
+
+ parser = argparse.ArgumentParser(description="")
+ parser.add_argument("--dataset_name", default="../data/train_dataset", type=str)
+ parser.add_argument("--model_name_or_path", default="bert-base-multilingual-cased", type=str)
+ parser.add_argument("--data_path", default="../data", type=str)
+ parser.add_argument("--context_path", default="wikipedia_documents.json", type=str)
+    parser.add_argument("--use_faiss", action="store_true")
+
+ args = parser.parse_args()
+ logging.info(args.__dict__)
+
+ org_dataset = load_from_disk(args.dataset_name)
+ full_ds = concatenate_datasets(
+ [
+ org_dataset["train"].flatten_indices(),
+ org_dataset["validation"].flatten_indices(),
+ ]
+ )
+    logging.info("*" * 40 + " query dataset " + "*" * 40)
+ logging.info(f"Full dataset: {full_ds}")
+
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False,)
+
+ retriever = TFIDFRetrieval(
+ tokenize_fn=tokenizer.tokenize,
+ data_path=args.data_path,
+ context_path=args.context_path,
+ )
+
+ # query = "๋ํต๋ น์ ํฌํจํ ๋ฏธ๊ตญ์ ํ์ ๋ถ ๊ฒฌ์ ๊ถ์ ๊ฐ๋ ๊ตญ๊ฐ ๊ธฐ๊ด์?"
+ query = "์ ๋ น์ ์ด๋ ํ์ฑ์์ ์ง๊ตฌ๋ก ์๋๊ฐ?"
+
+    if args.use_faiss:
+        retriever.get_sparse_embedding()
+        retriever.build_faiss()
+        with timer("single query by faiss"):
+            scores, indices = retriever.retrieve_faiss(query)
+
+ with timer("bulk query by exhaustive search"):
+ df = retriever.retrieve_faiss(full_ds)
+ df["correct"] = df["original_context"] == df["context"]
+
+ logging.info(f'correct retrieval result by faiss {df["correct"].sum() / len(df)}')
+
+ else:
+ retriever.get_sparse_embedding()
+ with timer("bulk query by exhaustive search"):
+ df = retriever.retrieve(full_ds)
+ df["correct"] = df["original_context"] == df["context"]
+ logging.info(f'correct retrieval result by exhaustive search: {df["correct"].sum() / len(df)}')
+
+ with timer("single query by exhaustive search"):
+ scores, indices = retriever.retrieve(query)
\ No newline at end of file
diff --git a/src/utils.py b/src/utils.py
new file mode 100644
index 0000000..d6a586d
--- /dev/null
+++ b/src/utils.py
@@ -0,0 +1,252 @@
+import collections
+import json
+import logging
+import os
+import random
+from typing import Any, Optional, Tuple
+
+import numpy as np
+import torch
+from arguments import DataTrainingArguments
+from datasets import DatasetDict
+from tqdm.auto import tqdm
+from transformers import PreTrainedTokenizerFast, TrainingArguments
+from transformers.trainer_utils import get_last_checkpoint
+
+logger = logging.getLogger(__name__)
+
+def set_seed(seed: int = 42):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
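+    # NOTE: on CUDA, use_deterministic_algorithms(True) can raise for ops without
+    # deterministic kernels and may require setting the CUBLAS_WORKSPACE_CONFIG
+    # environment variable (e.g. ":4096:8").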
+ torch.use_deterministic_algorithms(True)
+
+
+def postprocess_qa_predictions(
+ examples,
+ features,
+ predictions: Tuple[np.ndarray, np.ndarray],
+ version_2_with_negative: bool = False,
+ n_best_size: int = 20,
+ max_answer_length: int = 30,
+ null_score_diff_threshold: float = 0.0,
+ output_dir: Optional[str] = None,
+ prefix: Optional[str] = None,
+ is_world_process_zero: bool = True,
+):
+ assert (len(predictions) == 2), "`predictions` should be a tuple with two elements (start_logits, end_logits)."
+ all_start_logits, all_end_logits = predictions
+
+ assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features."
+
+ example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
+ features_per_example = collections.defaultdict(list)
+ for i, feature in enumerate(features):
+ features_per_example[example_id_to_index[feature["example_id"]]].append(i)
+
+ all_predictions = collections.OrderedDict()
+ all_nbest_json = collections.OrderedDict()
+ if version_2_with_negative:
+ scores_diff_json = collections.OrderedDict()
+
+ logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN)
+ logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
+
+ for example_index, example in enumerate(tqdm(examples)):
+ feature_indices = features_per_example[example_index]
+
+ min_null_prediction = None
+ prelim_predictions = []
+
+ for feature_index in feature_indices:
+ start_logits = all_start_logits[feature_index]
+ end_logits = all_end_logits[feature_index]
+ offset_mapping = features[feature_index]["offset_mapping"]
+            # Optional `token_is_max_context`: marks whether each token has its
+            # maximal context in this feature; candidates whose start token lacks
+            # max context are skipped below.
+            token_is_max_context = features[feature_index].get("token_is_max_context", None)
+
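+            # The summed start/end logits at position 0 (the [CLS] token) serve as
+            # this feature's "no answer" score.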
+ feature_null_score = start_logits[0] + end_logits[0]
+ if (
+ min_null_prediction is None
+ or min_null_prediction["score"] > feature_null_score
+ ):
+ min_null_prediction = {
+ "offsets": (0, 0),
+ "score": feature_null_score,
+ "start_logit": start_logits[0],
+ "end_logit": end_logits[0],
+ }
+
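+            # Indices of the n_best_size largest start/end logits, in descending order.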
+ start_indexes = np.argsort(start_logits)[-1:-n_best_size - 1:-1].tolist()
+ end_indexes = np.argsort(end_logits)[-1:-n_best_size - 1:-1].tolist()
+
+ for start_index in start_indexes:
+ for end_index in end_indexes:
+ if (
+ start_index >= len(offset_mapping)
+ or end_index >= len(offset_mapping)
+ or offset_mapping[start_index] is None
+ or offset_mapping[end_index] is None
+ ):
+ continue
+ if (
+ end_index < start_index
+ or end_index - start_index + 1 > max_answer_length
+ ):
+ continue
+ if (
+ token_is_max_context is not None
+ and not token_is_max_context.get(str(start_index), False)
+ ):
+ continue
+ prelim_predictions.append(
+ {
+ "offsets": (
+ offset_mapping[start_index][0],
+ offset_mapping[end_index][1],
+ ),
+ "score": start_logits[start_index] + end_logits[end_index],
+ "start_logit": start_logits[start_index],
+ "end_logit": end_logits[end_index],
+ }
+ )
+
+ if version_2_with_negative:
+ prelim_predictions.append(min_null_prediction)
+ null_score = min_null_prediction["score"]
+
+ predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]
+
+ if version_2_with_negative and not any(
+ p["offsets"] == (0, 0) for p in predictions
+ ):
+ predictions.append(min_null_prediction)
+
+ context = example["context"]
+ for pred in predictions:
+ offsets = pred.pop("offsets")
+ pred["text"] = context[offsets[0] : offsets[1]]
+
+ if len(predictions) == 0 or (
+ len(predictions) == 1 and predictions[0]["text"] == ""
+ ):
+
+ predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0})
+
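+        # Convert the n-best span scores to probabilities with a numerically
+        # stable softmax (subtract the max before exponentiating).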
+ scores = np.array([pred.pop("score") for pred in predictions])
+ exp_scores = np.exp(scores - np.max(scores))
+ probs = exp_scores / exp_scores.sum()
+
+ for prob, pred in zip(probs, predictions):
+ pred["probability"] = prob
+
+ if not version_2_with_negative:
+ all_predictions[example["id"]] = predictions[0]["text"]
+ else:
+ i = 0
+ while predictions[i]["text"] == "":
+ i += 1
+ best_non_null_pred = predictions[i]
+
+ score_diff = (
+ null_score
+ - best_non_null_pred["start_logit"]
+ - best_non_null_pred["end_logit"]
+ )
+ scores_diff_json[example["id"]] = float(score_diff)
+ if score_diff > null_score_diff_threshold:
+ all_predictions[example["id"]] = ""
+ else:
+ all_predictions[example["id"]] = best_non_null_pred["text"]
+
+ all_nbest_json[example["id"]] = [
+ {
+ k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()
+ }
+ for pred in predictions
+ ]
+
+ if output_dir is not None:
+ assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
+
+ prediction_file = os.path.join(
+ output_dir,
+            "predictions.json" if prefix is None else f"predictions_{prefix}.json",
+ )
+ nbest_file = os.path.join(
+ output_dir,
+ "nbest_predictions.json"
+ if prefix is None
+            else f"nbest_predictions_{prefix}.json",
+ )
+ if version_2_with_negative:
+ null_odds_file = os.path.join(
+ output_dir,
+                "null_odds.json" if prefix is None else f"null_odds_{prefix}.json",
+ )
+
+ logger.info(f"Saving predictions to {prediction_file}.")
+ with open(prediction_file, "w", encoding="utf-8") as writer:
+ writer.write(
+ json.dumps(all_predictions, indent=4, ensure_ascii=False) + "\n"
+ )
+ logger.info(f"Saving nbest_preds to {nbest_file}.")
+ with open(nbest_file, "w", encoding="utf-8") as writer:
+ writer.write(
+ json.dumps(all_nbest_json, indent=4, ensure_ascii=False) + "\n"
+ )
+ if version_2_with_negative:
+ logger.info(f"Saving null_odds to {null_odds_file}.")
+ with open(null_odds_file, "w", encoding="utf-8") as writer:
+ writer.write(
+ json.dumps(scores_diff_json, indent=4, ensure_ascii=False) + "\n"
+ )
+
+ return all_predictions
+
+
+def check_no_error(
+ data_args: DataTrainingArguments,
+ training_args: TrainingArguments,
+ datasets: DatasetDict,
+ tokenizer,
+) -> Tuple[Any, int]:
+ last_checkpoint = None
+ if (
+ os.path.isdir(training_args.output_dir)
+ and training_args.do_train
+ and not training_args.overwrite_output_dir
+ ):
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+ raise ValueError(
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+ "Use --overwrite_output_dir to overcome."
+ )
+ elif last_checkpoint is not None:
+ logger.info(
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+ )
+
+ if not isinstance(tokenizer, PreTrainedTokenizerFast):
+ raise ValueError(
+            "This example script only works for models that have a fast tokenizer. Check out the big table of models "
+ "at https://huggingface.co/transformers/index.html#bigtable to find the model types that meet this "
+ "requirement"
+ )
+
+ if data_args.max_seq_length > tokenizer.model_max_length:
+ logger.warning(
+            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
+            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+ )
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
+ if "validation" not in datasets:
+ raise ValueError("--do_eval requires a validation dataset")
+ return last_checkpoint, max_seq_length