diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7c8be28 --- /dev/null +++ b/.gitignore @@ -0,0 +1,19 @@ +# python +__pycache__ +.idea + +# dataset +data +code +EDA + +# outputs +models +output + +# src +src/test +src/wandb + +# wandb +wandb \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..69610b3 --- /dev/null +++ b/README.md @@ -0,0 +1,163 @@ +
+ + # ๐Ÿ† LV.2 NLP ํ”„๋กœ์ ํŠธ : Open-Domain Question Answering + +
+ +## โœ๏ธ ๋Œ€ํšŒ ์†Œ๊ฐœ + +| ํŠน์ง• | ์„ค๋ช… | +|:------:| --- | +| ๋Œ€ํšŒ ์ฃผ์ œ | ๋„ค์ด๋ฒ„ ๋ถ€์ŠคํŠธ์บ ํ”„ AI Tech 7๊ธฐ NLP Track์˜ Level 2 ๋„๋ฉ”์ธ ๊ธฐ์ดˆ ๋Œ€ํšŒ 'Open-Domain Question Answering (Machine Reading Comprehension)'์ž…๋‹ˆ๋‹ค. | +| ๋Œ€ํšŒ ์„ค๋ช… | ์ฃผ์–ด์ง€๋Š” Documents์˜ ๋‚ด์šฉ์„ ๊ธฐ๋ฐ˜์œผ๋กœ ์งˆ๋ฌธ์ด ์ฃผ์–ด์ง€๋ฉด, ๊ทธ ์งˆ๋ฌธ์— ๋Œ€ํ•œ ์ •ํ™•ํ•œ ๋‹ต๋ณ€์„ ๋ฌธ์„œ์—์„œ ์ฐพ์•„๋‚ด๋Š” ๊ฒƒ์„ ๋ชฉํ‘œ๋กœ ํ•ฉ๋‹ˆ๋‹ค. | +| ๋ฐ์ดํ„ฐ ๊ตฌ์„ฑ | ๋ฐ์ดํ„ฐ๋Š” ์œ„ํ‚คํ”ผ๋””์•„์˜ ๋‚ด์šฉ์œผ๋กœ ๋Œ€๋ถ€๋ถ„ ์ด๋ฃจ์–ด์ง„ ๋ฌธ์„œ ๋ฐ์ดํ„ฐ, ๊ทธ๋ฆฌ๊ณ  Question๊ณผ Answer๋กœ ๊ตฌ์„ฑ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค. | +| ํ‰๊ฐ€ ์ง€ํ‘œ | ๋‹ต๋ณ€์„ ์ •ํ™•ํžˆ ์ถ”์ถœํ•˜๋Š”์ง€๋ฅผ ํ™•์ธํ•˜๊ธฐ ์œ„ํ•ด EM(Exact Match) ์ง€ํ‘œ๊ฐ€ ์‚ฌ์šฉ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.| + + +## ๐ŸŽ–๏ธ Leader Board +ํ”„๋กœ์ ํŠธ ๊ฒฐ๊ณผ Public ๋ฆฌ๋”๋ณด๋“œ 2๋“ฑ, Private ๋ฆฌ๋”๋ณด๋“œ 2๋“ฑ์„ ๊ธฐ๋กํ•˜์˜€์Šต๋‹ˆ๋‹ค. +### ๐Ÿฅˆ Public Leader Board (2์œ„) +![leaderboard_mid](./docs/leaderboard_mid.png) + +### ๐Ÿฅˆ Private Leader Board (2์œ„) +![leaderboard_final](./docs/leaderboard_final.png) + +## ๐Ÿ‘จโ€๐Ÿ’ป 15์กฐ๊ฐ€์‹ญ์˜ค์กฐ ๋ฉค๋ฒ„ +
| ๊น€์ง„์žฌ [GitHub](https://github.com/jin-jae) | ๋ฐ•๊ทœํƒœ [GitHub](https://github.com/doraemon500) | ์œค์„ ์›… [GitHub](https://github.com/ssunbear) | ์ด์ •๋ฏผ [GitHub](https://github.com/simigami) | ์ž„ํ•œํƒ [GitHub](https://github.com/LHANTAEK) |
|:-:|:-:|:-:|:-:|:-:|
| ![๊น€์ง„์žฌ](https://avatars.githubusercontent.com/u/97018331) | ![๋ฐ•๊ทœํƒœ](https://avatars.githubusercontent.com/u/64678476) | ![์œค์„ ์›…](https://avatars.githubusercontent.com/u/117508164) | ![์ด์ •๋ฏผ](https://avatars.githubusercontent.com/u/46891822) | ![์ž„ํ•œํƒ](https://avatars.githubusercontent.com/u/143519383) |

+ + +## ๐Ÿ‘ผ ์—ญํ•  ๋ถ„๋‹ด +
|Member | Role |
|------| --- |
| ๊น€์ง„์žฌ | (Lead) Baseline code authoring and improvement, project management and environment setup, josa (particle) preprocessing algorithm, new methodology proposals, ensembling |
| ๋ฐ•๊ทœํƒœ | Data analysis, EDA, retrieval implementation, comparative experiments and improvements (hybrid search, re-ranking, dense, SPLADE, etc.), reader model fine-tuning |
| ์œค์„ ์›… | KorQuAD 1.0 data augmentation, model fine-tuning, reader model improvement (added CNN layers), retrieval model implementation (BM25), ensembling |
| ์ด์ •๋ฏผ | Data augmentation (AEDA, truncation, etc.), question dataset tuning, KorQuAD dataset tuning |
| ์ž„ํ•œํƒ | EDA, retrieval model improvement (BM25Plus, re-ranking hyperparameter optimization), reader model improvement (PLM selection and Trainer parameter optimization) |

+ + +## ๐Ÿƒ ํ”„๋กœ์ ํŠธ +### ๐Ÿ–ฅ๏ธ ํ”„๋กœ์ ํŠธ ๊ฐœ์š” +|๊ฐœ์š”| ์„ค๋ช… | +|:------:| --- | +| ์ฃผ์ œ | ๊ธฐ๊ณ„ ๋…ํ•ด MRC (Machine Reading Comprehension) ์ค‘ โ€˜Open-Domain Question Answeringโ€™ ๋ฅผ ์ฃผ์ œ๋กœ, ์ฃผ์–ด์ง„ ์งˆ์˜์™€ ๊ด€๋ จ๋œ ๋ฌธ์„œ๋ฅผ ํƒ์ƒ‰ํ•˜๊ณ , ํ•ด๋‹น ๋ฌธ์„œ์—์„œ ์ ์ ˆํ•œ ๋‹ต๋ณ€์„ ์ฐพ๊ฑฐ๋‚˜ ์ƒ์„ฑํ•˜๋Š” task๋ฅผ ์ˆ˜ํ–‰ | +| ๊ตฌ์กฐ | Retrieval ๋‹จ๊ณ„์™€ Reader ๋‹จ๊ณ„์˜ two-stage ๊ตฌ์กฐ ์‚ฌ์šฉ | +| ํ‰๊ฐ€ ์ง€ํ‘œ | ํ‰๊ฐ€ ์ง€ํ‘œ๋กœ๋Š” EM Score(Exact Match Score)์ด ์‚ฌ์šฉ๋˜์—ˆ๊ณ , ๋ชจ๋ธ์ด ์˜ˆ์ธกํ•œ text์™€ ์ •๋‹ต text๊ฐ€ ๊ธ€์ž ๋‹จ์œ„๋กœ ์™„์ „ํžˆ ๋˜‘๊ฐ™์€ ๊ฒฝ์šฐ์—๋งŒ ์ ์ˆ˜ ๋ถ€์—ฌ | +| ๊ฐœ๋ฐœ ํ™˜๊ฒฝ | `GPU` : Tesla V100 Server 4๋Œ€, `IDE` : Vscode, Jupyter Notebook | +| ํ˜‘์—… ํ™˜๊ฒฝ | Notion(์ง„ํ–‰ ์ƒํ™ฉ ๊ณต์œ ), Github(์ฝ”๋“œ ๋ฐ ๋ฐ์ดํ„ฐ ๊ณต์œ ), Slack(์‹ค์‹œ๊ฐ„ ์†Œํ†ต), W&B(์‹œ๊ฐํ™”, ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ ํŠœ๋‹) | + +### ๐Ÿ“… ํ”„๋กœ์ ํŠธ ํƒ€์ž„๋ผ์ธ +- ํ”„๋กœ์ ํŠธ๋Š” 2024-09-30 ~ 2024-10-25๊นŒ์ง€ ์ง„ํ–‰๋˜์—ˆ์Šต๋‹ˆ๋‹ค. +![ํƒ€์ž„๋ผ์ธ](./docs/ํƒ€์ž„๋ผ์ธ.png) + +### ๐Ÿ•ต๏ธ ํ”„๋กœ์ ํŠธ ์ง„ํ–‰ +- ํ”„๋กœ์ ํŠธ๋ฅผ ์ง„ํ–‰ํ•˜๋ฉฐ ๋‹จ๊ณ„๋ณ„๋กœ ์‹คํ—˜ํ•˜์—ฌ ์ ์šฉํ•œ ๋‚ด์šฉ๋“ค์„ ์•„๋ž˜์™€ ๊ฐ™์Šต๋‹ˆ๋‹ค. + + +| ํ”„๋กœ์„ธ์Šค | ์„ค๋ช… | +|:-----------------:| --- | +| ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ | AEDA, Swap Sentence, Truncation, Mecab์„ ํ™œ์šฉํ•œ Question ๊ฐ•์กฐ, LLM๊ธฐ๋ฐ˜ ์กฐ์‚ฌ์ œ๊ฑฐ | +| ๋ชจ๋ธ Finetuning | Korquad dataset ์ถ”๊ฐ€, Korquad1 PLM์— Korquad2 ๋ฐ์ดํ„ฐ์…‹ fine-tuning | +| Retriever ๋ชจ๋ธ ๊ฐœ์„  | BM25Plus, DPR, Hybrid Search, Re-rank(2-stage) | +| Reader ๋ชจ๋ธ ๊ฐœ์„  | CNN Layer ์ถ”๊ฐ€, Head Customizing, Dropout, Learning rate ํŠœ๋‹ | +| ์•™์ƒ๋ธ” ๋ฐฉ๋ฒ• | Soft Voting: nbest_predictions.json์—์„œ ์ œ๊ณตํ•˜๋Š” ๋‹จ์–ด๋ณ„ ํ™•๋ฅ ๊ฐ’์„ ํ™œ์šฉํ•ด์„œ, ๊ฐ ํŒŒ์ผ์—์„œ ๋‹จ์–ด์˜ ํ™•๋ฅ ๊ฐ’์„ ํ‰๊ท ๋‚ธ ํ›„ ๊ฐ€์žฅ ๋†’์€ ๊ฐ’์„ ์„ ํƒํ•˜๋Š” ๋ฐฉ์‹ | + + +### ๐Ÿค– Ensemble +| ๋ฒˆํ˜ธ | ๋ชจ๋ธ+๊ธฐ๋ฒ• | EM(Public) | +|------|--------------------------------------------------------|------------| +| 1 | uomnf97+BM25+CNN | 66.67 | +| 7 | Curtis+CNN+dropout(only_FC_0.05)+BM25Plus | 66.25 | +| 8 | Curtis+Truncation | 66.25 | +| 9 | HANTAEK_hybrid_optuna_topk20(k1=1.84) | 63.15 | +| 10 | HANTAEK_hybrid_optuna_topk20(k1=0.73) | 63.75 | +| 11 | HANTAEK_hybrid_optuna_topk10(k1=0.73) | 63.75 | +| 12 | uomnf97+BM25 | 67.08 | +| 13 | uomnf97+CNN+Re_rank500_20+Cosine | 67.08 | +| 14 | curtis+CNN+Re_rank_500_20 | 65.42 | +| 15 | nlp04_finetuned+CNN+BM25Plus+epoch1_predictions | 67.5 | + +### ๐Ÿ“ƒ Results +| ์ตœ์ข…์ œ์ถœ | Ensemble | EM(Public) | EM(Private) | +|----------|----------------------------------------------------|------------|--------------| +| O | ๋ชจ๋ธ 7,8,9,10,11,12,13 1:1:1:1:1:2:3 ์•™์ƒ๋ธ” + ์กฐ์‚ฌ LLM | **77.08** | 71.11 | +| O | ๋ชจ๋ธ 14,8,15,10,11,12,13 1:1:1:1:2:3 ์•™์ƒ๋ธ” + ์กฐ์‚ฌ LLM | **77.08** | 71.67 | +| | ๋ชจ๋ธ 1,7,8,9,10,11,12 ํ‰๊ท ์•™์ƒ๋ธ” + ์กฐ์‚ฌ LLM | 76.67 | 71.67 | +| | ๋ชจ๋ธ 7,8,9,10,11,12,13 1:1:1:2:2:2 ์•™์ƒ๋ธ” + ์กฐ์‚ฌ LLM | 76.67 | 70.83 | +| 1st SOTA | ๋ชจ๋ธ 15,9,10,11,12,13 1:1:1:1:3:3 ์•™์ƒ๋ธ” + ์กฐ์‚ฌ LLM | 75.42 | **74.17** | +| | ๋ชจ๋ธ 7,8,9,10,11,12,13 1:1:1:1:2:3 ์•™์ƒ๋ธ” + ์กฐ์‚ฌ LLM(n=5) | 75.42 | 71.67 | +| | ๋ชจ๋ธ 7,8,9,10,11,12,13,14 1:1:1:1:2:2 ์•™์ƒ๋ธ” + ์กฐ์‚ฌ LLM | 74.58 | 71.11 | +| 2nd SOTA | ๋ชจ๋ธ 14,8,9,10,11,12,13 1:1:1:1:2:3 ์•™์ƒ๋ธ” + ์กฐ์‚ฌ LLM | 74.58 | **72.22** | + + + +## ๐Ÿ“ ํ”„๋กœ์ ํŠธ ๊ตฌ์กฐ +ํ”„๋กœ์ ํŠธ ํด๋” ๊ตฌ์กฐ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค. 
+``` +level2-mrc-nlp-15 +โ”œโ”€โ”€ data +โ”‚ โ”œโ”€โ”€ test_dataset +โ”‚ โ”œโ”€โ”€ train_dataset +โ”‚ โ””โ”€โ”€ wikipedia_documents.json +โ”œโ”€โ”€ docs +โ”‚ โ”œโ”€โ”€ github_official_logo.png +โ”‚ โ”œโ”€โ”€ leaderboard_final.png +โ”‚ โ””โ”€โ”€ leaderboard_mid.png +โ”œโ”€โ”€ models +โ”œโ”€โ”€ output +โ”œโ”€โ”€ README.md +โ”œโ”€โ”€ requirements.txt +โ”œโ”€โ”€ run.py +โ””โ”€โ”€ src + โ”œโ”€โ”€ arguments.py + โ”œโ”€โ”€ CNN_layer_model.py + โ”œโ”€โ”€ data_analysis.py + โ”œโ”€โ”€ ensemble + โ”‚ โ”œโ”€โ”€ probs_voting_ensemble_n.py + โ”‚ โ”œโ”€โ”€ probs_voting_ensemble.py + โ”‚ โ””โ”€โ”€ scores_voting_ensemble.py + โ”œโ”€โ”€ korquad_finetuning_v2.ipynb + โ”œโ”€โ”€ main.py + โ”œโ”€โ”€ optimize_retriever.py + โ”œโ”€โ”€ preprocess_answer.ipynb + โ”œโ”€โ”€ qa_trainer.py + โ”œโ”€โ”€ retrieval_2s_rerank.py + โ”œโ”€โ”€ retrieval_BM25.py + โ”œโ”€โ”€ retrieval_Dense.py + โ”œโ”€โ”€ retrieval_hybridsearch.py + โ”œโ”€โ”€ retrieval.py + โ”œโ”€โ”€ retrieval_SPLADE.py + โ”œโ”€โ”€ retrieval_tfidf.py + โ”œโ”€โ”€ utils.py + โ””โ”€โ”€ wandb +``` + +### ๐Ÿ“ฆ src ํด๋” ๊ตฌ์กฐ ์„ค๋ช… +- arguments.py : ๋ฐ์ดํ„ฐ ์ฆ๊ฐ•์„ ํ•˜๋Š” ํŒŒ์ผ +- CNN_layer_model.py : PLM์— CNN Layer๋ฅผ ์ถ”๊ฐ€ํ•œ ํด๋ž˜์Šค ํŒŒ์ผ +- data_analysis.py : ๋ฐ์ดํ„ฐ์…‹์„ ๋ถ„์„ํ•˜๋Š” ํŒŒ์ผ +- ensemble : ๋ชจ๋ธ ์•™์ƒ๋ธ”์„ ํ•˜๋Š” ํด๋” (Soft, Hard ์ง€์›) +- main.py : ๋ชจ๋ธ train, eval, prediction ์„ ์ˆ˜ํ–‰ํ•˜๋Š” ํŒŒ์ผ +- optimize_retriever.py : ๋ฆฌํŠธ๋ฆฌ๋ฒ„์˜ ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์ตœ์ ํ™” ํ•˜๋Š” ํŒŒ์ผ +- qa_trainer.py : MRC Task์— ๋Œ€ํ•œ ์ปค์Šคํ…€ Trainer ํด๋ž˜์Šค ํŒŒ์ผ +- retrieval_2s_rerank.py : rerank ๋ฆฌํŠธ๋ฆฌ๋ฒ„ ํŒŒ์ผ +- retrieval_BM25.py : bm25 ๋ฆฌํŠธ๋ฆฌ๋ฒ„ ํŒŒ์ผ +- retrieval_Dense.py : DPR ๋ฆฌํŠธ๋ฆฌ๋ฒ„ ํŒŒ์ผ +- retrieval_hybridsearch.py : hybrid-search ๋ฆฌํŠธ๋ฆฌ๋ฒ„ ํŒŒ์ผ +- retrieval_SPLADE.py : SPLADE ๋ฆฌํŠธ๋ฆฌ๋ฒ„ ํŒŒ์ผ +- retrieval_tfidf.py : TF-IDF ๋ฆฌํŠธ๋ฆฌ๋ฒ„ ํŒŒ์ผ + + +### ๐Ÿ’พ Installation +- `python=3.10` ํ™˜๊ฒฝ์—์„œ requirements.txt๋ฅผ pip๋กœ install ํ•ฉ๋‹ˆ๋‹ค. (```pip install -r requirements.txt```) +- `python run.py`๋ฅผ ์ž…๋ ฅํ•˜์—ฌ ํ”„๋กœ๊ทธ๋žจ์„ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค. 
diff --git a/[NLP-15] Wrap-Up Report of ODQA.pdf b/[NLP-15] Wrap-Up Report of ODQA.pdf new file mode 100644 index 0000000..a83c5db Binary files /dev/null and b/[NLP-15] Wrap-Up Report of ODQA.pdf differ diff --git a/data/.gitkeep b/data/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/docs/github_official_logo.png b/docs/github_official_logo.png new file mode 100644 index 0000000..6cb3b70 Binary files /dev/null and b/docs/github_official_logo.png differ diff --git a/docs/leaderboard_final.png b/docs/leaderboard_final.png new file mode 100644 index 0000000..327d48a Binary files /dev/null and b/docs/leaderboard_final.png differ diff --git a/docs/leaderboard_mid.png b/docs/leaderboard_mid.png new file mode 100644 index 0000000..439d138 Binary files /dev/null and b/docs/leaderboard_mid.png differ diff --git "a/docs/\355\203\200\354\236\204\353\235\274\354\235\270.png" "b/docs/\355\203\200\354\236\204\353\235\274\354\235\270.png" new file mode 100644 index 0000000..81e5054 Binary files /dev/null and "b/docs/\355\203\200\354\236\204\353\235\274\354\235\270.png" differ diff --git a/models/.gitkeep b/models/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/output/.gitkeep b/output/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..367c563 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +datasets==2.15.0 +faiss-gpu==1.7.2 +networkx==3.1 +rank-bm25==0.2.2 +scikit-learn==1.4.0 +torchaudio==2.1.0 +torchvision==0.16.0 +nltk==3.9.1 +sentence_transformers==2.2.2 +sentencepiece==0.2.0 +tokenizers==0.13.0 +huggingface_hub==0.24.7 +ipykernel==6.29.5 +scipy==1.7.3 +torch==2.1.0 +transformers==4.25.1 +wandb==0.18.3 \ No newline at end of file diff --git a/run.py b/run.py new file mode 100644 index 0000000..d4261b1 --- /dev/null +++ b/run.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 + +import os +import subprocess +from datetime import datetime, timedelta + +# Get current time (UTC + 9 hours) +current_time = datetime.utcnow() + timedelta(hours=9) +current_time_str = current_time.strftime('%Y%m%d_%H%M%S') + +# Root directory (adjust this if necessary) +root_dir = os.getcwd() +#root_dir = os.path.join(os.sep, 'data', 'ephemeral', 'home', 'level2-mrc-nlp-15') + +# Ensure root directory exists +if not os.path.exists(root_dir): + raise FileNotFoundError(f"The root directory {root_dir} does not exist. Please adjust the path accordingly.") + +# Set up directories +train_dir = os.path.join(root_dir, 'models', f'train_{current_time_str}') +predict_dir = os.path.join(root_dir, 'output', f'test_{current_time_str}') +predict_dataset_name = os.path.join(root_dir, 'data', 'test_dataset') + +# Change to src directory +src_dir = os.path.join(root_dir, 'src') +if not os.path.exists(src_dir): + raise FileNotFoundError(f"The source directory {src_dir} does not exist. 
Please adjust the path accordingly.") +os.chdir(src_dir) + +# Perform training +subprocess.run([ + "python", "main.py", + "--output_dir", train_dir, + "--do_train", + "--overwrite_output_dir", + "--per_device_train_batch_size", "16", + "--learning_rate", "1e-5", + "--num_train_epochs", "3" +], check=True) + +# Perform evaluation (optional) +eval_dir = os.path.join(root_dir, 'output', f'train_dataset_{current_time_str}') +subprocess.run([ + "python", "main.py", + "--output_dir", eval_dir, + "--do_eval" +], check=True) + +# Perform prediction (inference) +subprocess.run([ + "python", "main.py", + "--output_dir", predict_dir, + "--dataset_name", predict_dataset_name, + "--model_name_or_path", train_dir, + "--do_predict" +], check=True) + +# Print Done +print(f"All Done. Check the output in {predict_dir}") \ No newline at end of file diff --git a/src/CNN_layer_model.py b/src/CNN_layer_model.py new file mode 100644 index 0000000..96beb4c --- /dev/null +++ b/src/CNN_layer_model.py @@ -0,0 +1,117 @@ +from typing import Optional, Union, Tuple, List + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers import RobertaPreTrainedModel +from transformers.modeling_outputs import QuestionAnsweringModelOutput +from transformers.models.roberta.modeling_roberta import RobertaModel + + +class CNN_block(nn.Module): + def __init__(self, input_size, hidden_size): + super(CNN_block, self).__init__() + self.conv1 = nn.Conv1d(in_channels=input_size, out_channels=input_size, kernel_size=3, padding=1) + self.conv2 = nn.Conv1d(in_channels=input_size, out_channels=input_size, kernel_size=1) + self.relu = nn.ReLU() + self.layer_norm = nn.LayerNorm(hidden_size) + + def forward(self, x): + # Transpose the input to match Conv1d input shape (batch_size, channels, sequence_length) + x = x.transpose(1, 2) + output = self.conv1(x) + output = self.conv2(output) + output = x + self.relu(output) + # Transpose back to original shape (batch_size, sequence_length, channels) + output = output.transpose(1, 2) + output = self.layer_norm(output) + return output + +class CNN_RobertaForQuestionAnswering(RobertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.roberta = RobertaModel(config, add_pooling_layer=False) + + self.cnn_block1 = CNN_block(config.hidden_size, config.hidden_size) + self.cnn_block2 = CNN_block(config.hidden_size, config.hidden_size) + self.cnn_block3 = CNN_block(config.hidden_size, config.hidden_size) + self.cnn_block4 = CNN_block(config.hidden_size, config.hidden_size) + self.cnn_block5 = CNN_block(config.hidden_size, config.hidden_size) + + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]: + return_dict = return_dict if return_dict is not None else 
self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + # Apply CNN layers + sequence_output = self.cnn_block1(sequence_output) + sequence_output = self.cnn_block2(sequence_output) + sequence_output = self.cnn_block3(sequence_output) + sequence_output = self.cnn_block4(sequence_output) + sequence_output = self.cnn_block5(sequence_output) + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/src/arguments.py b/src/arguments.py new file mode 100644 index 0000000..913ab0d --- /dev/null +++ b/src/arguments.py @@ -0,0 +1,149 @@ +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + # CurtisJeon/klue-roberta-large-korquad_v1_qa + # uomnf97/klue-roberta-finetuned-korquad-v2 + model_name_or_path: str = field( + default="uomnf97/klue-roberta-finetuned-korquad-v2", + metadata={ + "help": "Path to pretrained model or model identifier from huggingface.co/models" + }, + ) + config_name: Optional[str] = field( + default=None, + metadata={ + "help": "Pretrained config name or path if not the same as model_name" + }, + ) + config_name_dpr: Optional[str] = field( + default=None, + metadata={ + "help": "Pretrained config name or path if not the same as model_name" + }, + ) + tokenizer_name: Optional[str] = field( + default=None, + metadata={ + "help": "Pretrained tokenizer name or path if not the same as model_name" + }, + ) + ################################################################################# + batch_size: int = field( + default=16 + ) + + num_epochs: int = field( + default=3 + ) + + ################################################################################# + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. 
+ """ + + dataset_name: Optional[str] = field( + default="../data/train_dataset", + metadata={ + "help": "The name of the dataset to use." + }, + ) + overwrite_cache: bool = field( + default=False, + metadata={ + "help": "Overwrite the cached training and evaluation sets" + }, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={ + "help": "The number of processes to use for the preprocessing." + }, + ) + max_seq_length: int = field( + default=384, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": "Whether to pad all samples to `max_seq_length`. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can " + "be faster on GPU but will be slower on TPU)." + }, + ) + doc_stride: int = field( + default=128, + metadata={ + "help": "When splitting up a long document into chunks, how much stride to take between chunks." + }, + ) + max_answer_length: int = field( + default=64, + metadata={ + "help": "The maximum length of an answer that can be generated. This is needed because the start " + "and end predictions are not conditioned on one another." + }, + ) + eval_retrieval: bool = field( + default=True, + metadata={ + "help": "Whether to run passage retrieval using sparse embedding." + }, + ) + num_clusters: int = field( + default=64, metadata={ + "help": "Define how many clusters to use for faiss." + }, + ) + top_k_retrieval: int = field( + default=20, + metadata={ + "help": "Define how many top-k passages to retrieve based on similarity." + }, + ) + use_faiss: bool = field( + default=True, metadata={ + "help": "Whether to build with faiss" + }, + ) + dense_encoder_type: str = field( + default = 'hybrid', metadata = { + "help": "Whether to run passage retrieval using dense embedding." 
+ }, + ) + remove_char: bool = field( + default=True, metadata={ + "help": "Whether to remove special character before embedding" + }, + ) + data_path: str = field( + default="../data/", + metadata={ + "help": "The path of the data directory" + }, + ) + context_path: str = field( + default="wikipedia_documents.json", + metadata={ + "help": "The name of the context file" + }, + ) + alpha_retrieval: Optional[float] = field( + default=0.7, + metadata={ + "help": "Value for hybridizing embedding scores in HybridSearch" + } + ) \ No newline at end of file diff --git a/src/data_analysis.py b/src/data_analysis.py new file mode 100644 index 0000000..626904f --- /dev/null +++ b/src/data_analysis.py @@ -0,0 +1,103 @@ +import pandas as pd + +from datasets import ( + Dataset, + DatasetDict, + Features, + Value, + Sequence, + load_from_disk, + load_metric +) +from transformers import AutoTokenizer + +from mecab import MeCab +from sklearn.feature_extraction.text import TfidfVectorizer + +data_path = "/data/ephemeral/home/level2-mrc-nlp-15/data/train_dataset" +datasets = load_from_disk(data_path) +print(datasets) + +train_datasets = pd.DataFrame(datasets['train']) +val_datasets = pd.DataFrame(datasets['validation']) + +# val_datasets.head(20) +# for i , data in val_datasets.iterrows(): +# if i == 5: break +# print(data['context']) +# print(data['question']) +# print("-----------------") + +# model_path = 'klue/bert-base' +# tokenizer = AutoTokenizer.from_pretrained(model_path) + +with open('/data/ephemeral/home/level2-mrc-nlp-15/data/stopwords.txt', encoding='utf-8') as f: + stop_words = f.read().splitlines() + +tokenizer = MeCab() + +def filter_stop_words(tokenized_corpus, stop_words: list) -> list: + return [x for x in tokenized_corpus if x not in stop_words] + +tokenize_fn = lambda x: filter_stop_words(tokenizer.morphs(x), stop_words) + +# tokenized_context = train_datasets['context'].map(lambda x: tokenizer.morphs(x)) +# filtered_tokenized_context = [] +# for context in tokenized_context: +# filtered_tokenized_context.append([x for x in context if x not in stop_words]) +# print(filtered_tokenized_context[123]) + +# print("-------------------------------------") + +# tokenized_q = train_datasets['question'].map(lambda x: tokenizer.morphs(x)) +# filtered_tokenized_q = [] +# for q in tokenized_q: +# filtered_tokenized_q.append([x for x in q if x not in stop_words]) +# print(filtered_tokenized_q[123]) + +# ํ•„ํ„ฐ๋ง๋œ ์ฝ”ํผ์Šค์—์„œ ๊ฐ context์— ๋Œ€ํ•ด์„œ tf-idf ๋ฅผ ํ†ตํ•ด์„œ ํ‚ค์›Œ๋“œ ์ถ”์ถœ ํ›„ ํ•ด๋‹น ์งˆ๋ฌธ์— ํ‚ค์›Œ๋“œ๊ฐ€ ์กด์žฌํ•˜๋Š”์ง€ ํ™•์ธํ•˜๊ณ  ์ตœ์ข… ๋น„์œจ ๋„์ถœ + +tfidf = TfidfVectorizer( + tokenizer=tokenize_fn, + token_pattern=None +) + +tfidf_matrix = tfidf.fit_transform(train_datasets['context']) +feature_names = tfidf.get_feature_names_out() + +top_tokens_per_doc = [] + +for doc_idx, doc in enumerate(tfidf_matrix): + # Sparse ๋ฒกํ„ฐ๋ฅผ ๋ฐ€์ง‘ ๋ฐฐ์—ด๋กœ ๋ณ€ํ™˜ + doc_array = doc.toarray().flatten() + # TF-IDF ๊ฐ’์ด ํฐ ์ˆœ์„œ๋Œ€๋กœ ์ƒ์œ„ 5๊ฐœ ์ธ๋ฑ์Šค ์ถ”์ถœ + top_n = 5 + if len(doc_array) < top_n: + top_n = len(doc_array) + top_n_idx = doc_array.argsort()[-top_n:][::-1] + # ์ƒ์œ„ ํ† ํฐ๊ณผ ํ•ด๋‹น TF-IDF ๊ฐ’ ์ถ”์ถœ + top_tokens = [feature_names[i] for i in top_n_idx] + top_tfidf_scores = [doc_array[i] for i in top_n_idx] + # ๊ฒฐ๊ณผ ์ €์žฅ + top_tokens_per_doc.append(top_tokens) + # ์ถœ๋ ฅ + # print(f"๋ฌธ์„œ {doc_idx}:") + # for token, score in zip(top_tokens, top_tfidf_scores): + # print(f"๋‹จ์–ด: {token}, TF-IDF ๊ฐ’: {score}") + +tokenized_q = 
train_datasets['question'].map(lambda x: tokenize_fn(x))

cnt = 0
for idx, q in enumerate(tokenized_q):
    ansdoc_tokens = top_tokens_per_doc[idx]

    val = 0
    for token in ansdoc_tokens:
        if token in q:
            val += 1

    if val > 0: cnt += 1

print(len(top_tokens_per_doc))
print(cnt)

diff --git a/src/ensemble/probs_voting_ensemble.py b/src/ensemble/probs_voting_ensemble.py
new file mode 100644
index 0000000..8d32368
--- /dev/null
+++ b/src/ensemble/probs_voting_ensemble.py
@@ -0,0 +1,89 @@
import collections
import argparse
import json
import pandas as pd
from datasets import load_from_disk

def probs_voting_ensemble(weights, path, number, test_df):
    """Soft-ensemble that considers only each file's single highest-probability candidate.

    Args:
        weights (list): weight for each predictions file
        path (str): folder containing the prediction files
        number (int): number of files to ensemble
        test_df (pd.DataFrame): test data DataFrame
    """

    test_ids = test_df['id'].tolist()
    nbest_prediction = collections.OrderedDict()
    prediction = collections.OrderedDict()
    weights = [weights[i] / sum(weights) for i in range(len(weights))]

    nbest_hubo = []
    best_hubo = []

    for i in range(number):
        # File names are expected to follow the pattern nbest_0, nbest_1, nbest_2, ...
        nbest_path = f'{path}/nbest_{i}.json'
        best_path = f'{path}/predictions_{i}.json'

        with open(nbest_path, 'r') as json_file:
            json_data = json.load(json_file)
            nbest_hubo.append(json_data)
        with open(best_path, 'r') as json_file:
            json_data = json.load(json_file)
            best_hubo.append(json_data)


    for i in range(len(test_ids)):
        id = test_ids[i]
        max_doc_num = None
        max_probs = 0

        for j in range(number):
            pred = nbest_hubo[j][id][0]
            score = (pred["probability"]) * weights[j]

            if max_probs <= score:
                max_doc_num = j
                max_probs = score

        nbest_prediction[id] = nbest_hubo[max_doc_num][id]
        prediction[id] = best_hubo[max_doc_num][id]

    nbest_file = f'{path}/nbest_predictions.json'
    best_file = f'{path}/predictions.json'

    with open(nbest_file, "w", encoding="utf-8") as writer:
        writer.write(
            json.dumps(nbest_prediction, indent=4, ensure_ascii=False) + "\n"
        )
    with open(best_file, "w", encoding="utf-8") as writer:
        writer.write(
            json.dumps(prediction, indent=4, ensure_ascii=False) + "\n"
        )

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")

    parser.add_argument(
        # Provide the per-file weights here.
        "--scores_list", nargs='+', type=float, default=[0.2,0.2,0.2,0.2,0.2], help="list of float"
    )
    parser.add_argument(
        # Provide the folder path here.
        "--folder_path", default=f"/data/ephemeral/home/level2-mrc-nlp-15/ensemble/nbest", type=str, help="folder path"
    )
    parser.add_argument(
        # Provide the number of files to ensemble here.
        "--file_number", default=5, type=int, help="ensemble file number"
    )

    # Set the path to the test dataset here.
    test_dataset = load_from_disk("/data/ephemeral/home/level2-mrc-nlp-15/data/test_dataset")
    test_df = pd.DataFrame(test_dataset['validation'])

    args = parser.parse_args()

    probs_voting_ensemble(args.scores_list, args.folder_path, args.file_number, test_df)
    
\ No newline at end of file
diff --git a/src/ensemble/probs_voting_ensemble_n.py b/src/ensemble/probs_voting_ensemble_n.py
new file mode 100644
index 0000000..892cb29
--- /dev/null
+++ b/src/ensemble/probs_voting_ensemble_n.py
@@ -0,0 +1,78 @@
import collections
import argparse
import json
import pandas as pd
from datasets import load_from_disk

def probs_voting_ensemble_n(weights, path, number, nbest, test_df):
    """Soft-ensemble that aggregates the probabilities of the top-nbest candidates per file.

    Args:
        weights (list): weight for each nbest_predictions file
        path (str): folder containing the nbest_prediction files
        number (int): number of files to ensemble
        nbest (int): how many nbest candidates per file to include in the vote
        test_df (pd.DataFrame): test data DataFrame
    """

    test_ids = test_df['id'].tolist()
    prediction = collections.OrderedDict()
    weights = [weights[i] / sum(weights) for i in range(len(weights))]

    nbest_hubo = []

    for i in range(number):
        # File names are expected to follow the pattern nbest_0, nbest_1, nbest_2, ...
        nbest_path = f'{path}/nbest_{i}.json'

        with open(nbest_path, 'r') as json_file:
            json_data = json.load(json_file)
            nbest_hubo.append(json_data)

    for id in test_ids:
        hubo = collections.defaultdict(float)

        for i in range(number):
            preds = nbest_hubo[i][id][:nbest]
            for pred in preds:
                hubo[pred["text"]] += pred["probability"] * weights[i]

        max_text = max(hubo, key=hubo.get)
        prediction[id] = max_text


    best_file = f'{path}/predictions.json'

    with open(best_file, "w", encoding="utf-8") as writer:
        writer.write(
            json.dumps(prediction, indent=4, ensure_ascii=False) + "\n"
        )

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")

    parser.add_argument(
        # Provide the per-file weights here.
        "--scores_list", nargs='+', type=float, default=[0.2,0.2,0.2,0.2,0.2], help="list of float"
    )
    parser.add_argument(
        # Provide the folder path here.
        "--folder_path", default=f"/data/ephemeral/home/level2-mrc-nlp-15/ensemble/nbest", type=str, help="folder path"
    )
    parser.add_argument(
        # Provide the number of files to ensemble here.
        "--file_number", default=5, type=int, help="ensemble file number"
    )
    parser.add_argument(
        # Provide how many nbest candidates to include here.
        "--nbest", default=3, type=int, help="nbest to include"
    )

    # Set the path to the test dataset here.
    test_dataset = load_from_disk("/data/ephemeral/home/level2-mrc-nlp-15/data/test_dataset")
    test_df = pd.DataFrame(test_dataset['validation'])

    args = parser.parse_args()

    probs_voting_ensemble_n(args.scores_list, args.folder_path, args.file_number, args.nbest, test_df)
    
\ No newline at end of file
diff --git a/src/ensemble/scores_voting_ensemble.py b/src/ensemble/scores_voting_ensemble.py
new file mode 100644
index 0000000..6dd5a74
--- /dev/null
+++ b/src/ensemble/scores_voting_ensemble.py
@@ -0,0 +1,91 @@
import collections
import argparse
import json
import pandas as pd
from datasets import load_from_disk

def scores_voting_ensemble(weights, path, number, test_df):
    """Soft-ensemble that considers only each file's single best logit score.

    Args:
        weights (list): weight for each predictions file
        path (str): folder containing the prediction files
        number (int): number of files to ensemble
        test_df (pd.DataFrame): test data DataFrame
    """
    test_ids = test_df['id'].tolist()
    nbest_prediction = collections.OrderedDict()
    prediction = collections.OrderedDict()
    weights = [weights[i] / sum(weights) for i in range(len(weights))]

    nbest_hubo = []
    best_hubo = []

    for i in range(number):
        # File names are expected to follow the pattern nbest_0, nbest_1, nbest_2, ...
        nbest_path = f'{path}/nbest_{i}.json'
        best_path = f'{path}/predictions_{i}.json'

        with open(nbest_path, 'r') as json_file:
            json_data = json.load(json_file)
            nbest_hubo.append(json_data)
        with open(best_path, 'r') as json_file:
            json_data = json.load(json_file)
            best_hubo.append(json_data)


    for i in range(len(test_ids)):
        id = test_ids[i]
        max_doc_num = None
        max_logits = -200

        for j in range(number):
            pred = nbest_hubo[j][id][0]
            score = (pred['start_logit'] + pred['end_logit'])

            if score < 0:
                score = score * (1-weights[j])
            else:
                score = score * weights[j]

            if max_logits <= score:
                max_doc_num = j
                max_logits = score

        nbest_prediction[id] = nbest_hubo[max_doc_num][id]
        prediction[id] = best_hubo[max_doc_num][id]

    nbest_file = f'{path}/nbest_predictions.json'
    best_file = f'{path}/predictions.json'

    with open(nbest_file, "w", encoding="utf-8") as writer:
        writer.write(
            json.dumps(nbest_prediction, indent=4, ensure_ascii=False) + "\n"
        )
    with open(best_file, "w", encoding="utf-8") as writer:
        writer.write(
            json.dumps(prediction, indent=4, ensure_ascii=False) + "\n"
        )

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")

    parser.add_argument(
        # Provide the per-file weights here.
        "--scores_list", nargs='+', type=float, default=[0.2,0.2,0.2,0.2,0.2], help="list of float"
    )
    parser.add_argument(
        # Provide the folder path here.
        "--folder_path", default=f"/data/ephemeral/home/level2-mrc-nlp-15/ensemble/nbest", type=str, help="folder path"
    )
    parser.add_argument(
        # Provide the number of files to ensemble here.
        "--file_number", default=5, type=int, help="ensemble file number"
    )

    # Set the path to the test dataset here.
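    # Unlike probs_voting_ensemble.py, which votes on nbest probabilities,
    # this script votes on the summed start/end logits of each file's top
    # candidate (negative sums are re-scaled by 1 - weight above).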
    test_dataset = load_from_disk("/data/ephemeral/home/level2-mrc-nlp-15/data/test_dataset")
    test_df = pd.DataFrame(test_dataset['validation'])

    args = parser.parse_args()

    scores_voting_ensemble(args.scores_list, args.folder_path, args.file_number, test_df)
    
\ No newline at end of file
diff --git a/src/korquad_finetuning_v2.ipynb b/src/korquad_finetuning_v2.ipynb
new file mode 100644
index 0000000..6ace7b4
--- /dev/null
+++ b/src/korquad_finetuning_v2.ipynb
@@ -0,0 +1,3376 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-tuning Roberta-large with KorQuAD 1.0\n",
    "-> Upload the model to huggingface and load it from there! https://www.youtube.com/watch?v=ovD_87gHZO4\n",
    "\n",
    "- Method0. Add CNN layers \n",
    "- Method1. First fine-tune on the KorQuAD 1.0 (train+validation) dataset alone -> ssunbear/klue_roberta_large_finetuned_korquad_v1\n",
    "- Method2. Fine-tune Method1 once more on the mrc_train dataset (reloading the model) -> ssunbear/klue_roberta_large_finetuned_korquad_v2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0mRequirement already satisfied: huggingface in /opt/conda/lib/python3.10/site-packages (0.0.1)\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0mRequirement already satisfied: datasets==2.14.6 in /opt/conda/lib/python3.10/site-packages (2.14.6)\n",
      "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.10/site-packages (from datasets==2.14.6) (1.26.0)\n",
      "Requirement already satisfied: pyarrow>=8.0.0 in /opt/conda/lib/python3.10/site-packages (from datasets==2.14.6) (17.0.0)\n",
      "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /opt/conda/lib/python3.10/site-packages (from datasets==2.14.6) (0.3.7)\n",
      "Requirement already satisfied: pandas in /opt/conda/lib/python3.10/site-packages (from datasets==2.14.6) (2.2.3)\n",
      "Requirement already satisfied: requests>=2.19.0 in /opt/conda/lib/python3.10/site-packages (from datasets==2.14.6) (2.32.3)\n",
      "Requirement already satisfied: tqdm>=4.62.1 in /opt/conda/lib/python3.10/site-packages (from datasets==2.14.6) (4.66.5)\n",
      "Requirement already satisfied: xxhash in /opt/conda/lib/python3.10/site-packages (from datasets==2.14.6) (3.5.0)\n",
      "Requirement already satisfied: multiprocess in /opt/conda/lib/python3.10/site-packages (from datasets==2.14.6) (0.70.15)\n",
      "Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /opt/conda/lib/python3.10/site-packages (from datasets==2.14.6) (2023.9.2)\n",
      "Requirement already satisfied: aiohttp in /opt/conda/lib/python3.10/site-packages (from datasets==2.14.6) (3.10.10)\n",
      "Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /opt/conda/lib/python3.10/site-packages (from datasets==2.14.6) (0.25.2)\n",
      "Requirement already satisfied: packaging in /opt/conda/lib/python3.10/site-packages (from datasets==2.14.6) (23.1)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.10/site-packages (from 
datasets==2.14.6) (6.0)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets==2.14.6) (2.4.3)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets==2.14.6) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets==2.14.6) (23.1.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets==2.14.6) (1.4.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets==2.14.6) (6.1.0)\n", + "Requirement already satisfied: yarl<2.0,>=1.12.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets==2.14.6) (1.15.2)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets==2.14.6) (4.0.3)\n", + "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.6) (3.9.0)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets==2.14.6) (4.7.1)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests>=2.19.0->datasets==2.14.6) (2.0.4)\n", + "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests>=2.19.0->datasets==2.14.6) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests>=2.19.0->datasets==2.14.6) (1.26.16)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests>=2.19.0->datasets==2.14.6) (2024.8.30)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets==2.14.6) (2.9.0)\n", + "Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets==2.14.6) (2023.3.post1)\n", + "Requirement already satisfied: tzdata>=2022.7 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets==2.14.6) (2024.2)\n", + "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas->datasets==2.14.6) (1.16.0)\n", + "Requirement already satisfied: propcache>=0.2.0 in /opt/conda/lib/python3.10/site-packages (from yarl<2.0,>=1.12.0->aiohttp->datasets==2.14.6) (0.2.0)\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: wandb in /opt/conda/lib/python3.10/site-packages (0.18.3)\n", + "Requirement already satisfied: click!=8.0.0,>=7.1 in /opt/conda/lib/python3.10/site-packages (from wandb) (8.0.4)\n", + "Requirement already satisfied: docker-pycreds>=0.4.0 in /opt/conda/lib/python3.10/site-packages (from wandb) (0.4.0)\n", + "Requirement already satisfied: gitpython!=3.1.29,>=1.0.0 in /opt/conda/lib/python3.10/site-packages (from wandb) (3.1.43)\n", + "Requirement already satisfied: platformdirs in /opt/conda/lib/python3.10/site-packages (from wandb) (4.3.6)\n", + "Requirement already satisfied: protobuf!=4.21.0,!=5.28.0,<6,>=3.19.0 in /opt/conda/lib/python3.10/site-packages (from wandb) (5.28.2)\n", + "Requirement already satisfied: psutil>=5.0.0 in /opt/conda/lib/python3.10/site-packages (from wandb) (5.9.0)\n", + "Requirement already satisfied: pyyaml in /opt/conda/lib/python3.10/site-packages (from wandb) (6.0)\n", + "Requirement already satisfied: requests<3,>=2.0.0 in /opt/conda/lib/python3.10/site-packages (from wandb) (2.32.3)\n", + "Requirement already satisfied: sentry-sdk>=1.0.0 in /opt/conda/lib/python3.10/site-packages (from wandb) (2.16.0)\n", + "Requirement already satisfied: setproctitle in /opt/conda/lib/python3.10/site-packages (from wandb) (1.3.3)\n", + "Requirement already satisfied: setuptools in /opt/conda/lib/python3.10/site-packages (from wandb) (68.0.0)\n", + "Requirement already satisfied: six>=1.4.0 in /opt/conda/lib/python3.10/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n", + "Requirement already satisfied: gitdb<5,>=4.0.1 in /opt/conda/lib/python3.10/site-packages (from gitpython!=3.1.29,>=1.0.0->wandb) (4.0.11)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests<3,>=2.0.0->wandb) (2.0.4)\n", + "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests<3,>=2.0.0->wandb) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests<3,>=2.0.0->wandb) (1.26.16)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests<3,>=2.0.0->wandb) (2024.8.30)\n", + "Requirement already satisfied: smmap<6,>=3.0.1 in /opt/conda/lib/python3.10/site-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb) (5.0.1)\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install transformers==4.24.0 -q\n", + "!pip install huggingface\n", + "!pip install datasets==2.14.6\n", + "!pip install wandb" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## korquad ๋ฐ์ดํ„ฐ์…‹ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from datasets import load_dataset, concatenate_datasets\n", + "\n", + "dataset = load_dataset('squad_kor_v1')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['id', 'title', 'context', 'question', 'answers'],\n", + " num_rows: 60407\n", + "})" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset['train']" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['id', 'title', 'context', 'question', 'answers'],\n", + " num_rows: 5774\n", + "})" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset['validation']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## mrc valid ๋ฐ์ดํ„ฐ์…‹ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_from_disk\n", + "\n", + "# ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ\n", + "mrc_train_dataset_path = \"/data/ephemeral/home/level2-mrc-nlp-15/data/train_dataset/train\" # ์‹ค์ œ ๋ฐ์ดํ„ฐ์…‹ ๊ฒฝ๋กœ๋กœ ์ˆ˜์ •\n", + "mrc_train_dataset = load_from_disk(mrc_train_dataset_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['title', 'context', 'question', 'id', 'answers', 'document_id', '__index_level_0__'],\n", + " num_rows: 3952\n", + "})" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mrc_train_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_from_disk\n", + "\n", + "# ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ\n", + "mrc_validation_dataset_path = \"/data/ephemeral/home/level2-mrc-nlp-15/data/train_dataset/validation\" # ์‹ค์ œ ๋ฐ์ดํ„ฐ์…‹ ๊ฒฝ๋กœ๋กœ ์ˆ˜์ •\n", + "mrc_validation_dataset = load_from_disk(mrc_validation_dataset_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['title', 'context', 'question', 'id', 'answers', 'document_id', '__index_level_0__'],\n", + " num_rows: 240\n", + "})" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mrc_validation_dataset\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_from_disk\n", + "\n", + "# ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ\n", + "mrc_train_validation_path = \"/data/ephemeral/home/level2-mrc-nlp-15/data/train_dataset/validation\" # ์‹ค์ œ ๋ฐ์ดํ„ฐ์…‹ ๊ฒฝ๋กœ๋กœ ์ˆ˜์ •\n", + "mrc_train_validation = load_from_disk(mrc_train_validation_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "# korquad ๋ฐ์ดํ„ฐ์…‹์ด๋ž‘ ํ˜•์‹ ๋˜‘๊ฐ™์ด ๋งŒ๋“ค์–ด์ฃผ๊ธฐ\n", + "id_list = []\n", + "title_list = []\n", + "context_list = []\n", + "question_list = []\n", + "answers_list = []\n", + "\n", + "for index, row in pd.DataFrame(mrc_train_dataset).iterrows():\n", + " 
id_list.append(row['id'])\n", + " title_list.append(str(row['title']))\n", + " context_list.append(str(row['context']))\n", + " question_list.append(str(row['question']))\n", + " answers_list.append(row['answers'])" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "mrc_train_dataset = {\n", + " \"id\" : id_list,\n", + " \"title\" : title_list,\n", + " \"context\" : context_list,\n", + " \"question\" : question_list,\n", + " \"answers\" : answers_list,}" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "'Dataset' object has no attribute 'items'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[19], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdatasets\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Dataset\n\u001b[0;32m----> 3\u001b[0m mrc_train_dataset\u001b[38;5;241m=\u001b[39m \u001b[43mDataset\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmrc_train_dataset\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4\u001b[0m mrc_train_dataset\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/datasets/arrow_dataset.py:900\u001b[0m, in \u001b[0;36mDataset.from_dict\u001b[0;34m(cls, mapping, features, info, split)\u001b[0m\n\u001b[1;32m 898\u001b[0m features \u001b[38;5;241m=\u001b[39m features \u001b[38;5;28;01mif\u001b[39;00m features \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m info\u001b[38;5;241m.\u001b[39mfeatures \u001b[38;5;28;01mif\u001b[39;00m info \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 899\u001b[0m arrow_typed_mapping \u001b[38;5;241m=\u001b[39m {}\n\u001b[0;32m--> 900\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m col, data \u001b[38;5;129;01min\u001b[39;00m \u001b[43mmapping\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mitems\u001b[49m():\n\u001b[1;32m 901\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data, (pa\u001b[38;5;241m.\u001b[39mArray, pa\u001b[38;5;241m.\u001b[39mChunkedArray)):\n\u001b[1;32m 902\u001b[0m data \u001b[38;5;241m=\u001b[39m cast_array_to_feature(data, features[col]) \u001b[38;5;28;01mif\u001b[39;00m features \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m data\n", + "\u001b[0;31mAttributeError\u001b[0m: 'Dataset' object has no attribute 'items'" + ] + } + ], + "source": [ + "from datasets import Dataset\n", + "\n", + "mrc_train_dataset= Dataset.from_dict(mrc_train_dataset)\n", + "mrc_train_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [], + "source": [ + "# korquad ๋ฐ์ดํ„ฐ์…‹์ด๋ž‘ ํ˜•์‹ ๋˜‘๊ฐ™์ด ๋งŒ๋“ค์–ด์ฃผ๊ธฐ\n", + "id_list2 = []\n", + "title_list2 = []\n", + "context_list2 = []\n", + "question_list2 = []\n", + "answers_list2 = []\n", + "\n", + "for index, row in pd.DataFrame(mrc_validation_dataset).iterrows():\n", + " id_list2.append(row['id'])\n", + " 
title_list2.append(str(row['title']))\n", + " context_list2.append(str(row['context']))\n", + " question_list2.append(str(row['question']))\n", + " answers_list2.append(row['answers'])" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": {}, + "outputs": [], + "source": [ + "mrc_validation_dataset = {\n", + " \"id\" : id_list,\n", + " \"title\" : title_list,\n", + " \"context\" : context_list,\n", + " \"question\" : question_list,\n", + " \"answers\" : answers_list,}" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "'Dataset' object has no attribute 'items'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[67], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdatasets\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Dataset\n\u001b[0;32m----> 3\u001b[0m mrc_validation_dataset\u001b[38;5;241m=\u001b[39m \u001b[43mDataset\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmrc_validation_dataset\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4\u001b[0m mrc_validation_dataset\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/datasets/arrow_dataset.py:900\u001b[0m, in \u001b[0;36mDataset.from_dict\u001b[0;34m(cls, mapping, features, info, split)\u001b[0m\n\u001b[1;32m 898\u001b[0m features \u001b[38;5;241m=\u001b[39m features \u001b[38;5;28;01mif\u001b[39;00m features \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m info\u001b[38;5;241m.\u001b[39mfeatures \u001b[38;5;28;01mif\u001b[39;00m info \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 899\u001b[0m arrow_typed_mapping \u001b[38;5;241m=\u001b[39m {}\n\u001b[0;32m--> 900\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m col, data \u001b[38;5;129;01min\u001b[39;00m \u001b[43mmapping\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mitems\u001b[49m():\n\u001b[1;32m 901\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data, (pa\u001b[38;5;241m.\u001b[39mArray, pa\u001b[38;5;241m.\u001b[39mChunkedArray)):\n\u001b[1;32m 902\u001b[0m data \u001b[38;5;241m=\u001b[39m cast_array_to_feature(data, features[col]) \u001b[38;5;28;01mif\u001b[39;00m features \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m data\n", + "\u001b[0;31mAttributeError\u001b[0m: 'Dataset' object has no attribute 'items'" + ] + } + ], + "source": [ + "from datasets import Dataset\n", + "\n", + "mrc_validation_dataset= Dataset.from_dict(mrc_validation_dataset)\n", + "mrc_validation_dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## korquad ๋ฐ์ดํ„ฐ์…‹ filtering\n", + "- Korquad ๋ฐ์ดํ„ฐ์…‹๊ณผ train ๋ฐ์ดํ„ฐ์…‹์˜ context ๊ธธ์ด ๋ถ„ํฌ ๋งž์ถฐ์ฃผ๊ธฐ\n", + "- train ๋ฐ์ดํ„ฐ์…‹ context ๊ธธ์ด 2064 ์ดํ•˜์ด๋ฏ€๋กœ, korquad ๋ฐ์ดํ„ฐ์…‹ ์ค‘ context ๊ธธ์ด๊ฐ€ 2064๊ฐœ ์ด์ƒ์ธ ๋ฐ์ดํ„ฐ๋“ค์€ ์‚ญ์ œํ•ด์ค๋‹ˆ๋‹ค." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "filtered_dataset = dataset['train'].filter(lambda example: len(example['context']) <= 2064)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "filtered_dataset_validation = dataset['validation'].filter(lambda example: len(example['context']) <= 2064)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "id_list3 = []\n", + "title_list3 = []\n", + "context_list3 = []\n", + "question_list3 = []\n", + "answers_list3 = []\n", + "\n", + "for index, row in pd.DataFrame(filtered_dataset_validation).iterrows():\n", + " id_list3.append(row['id'])\n", + " title_list3.append(str(row['title']))\n", + " context_list3.append(str(row['context']))\n", + " question_list3.append(str(row['question']))\n", + " answers_list3.append(row['answers'])" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "val_dataset = {\n", + " \"id\" : id_list3,\n", + " \"title\" : title_list3,\n", + " \"context\" : context_list3,\n", + " \"question\" : question_list3,\n", + " \"answers\" : answers_list3,}" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['id', 'title', 'context', 'question', 'answers'],\n", + " num_rows: 5735\n", + "})" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "val_dataset= Dataset.from_dict(val_dataset)\n", + "val_dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pre-trained ๋ชจ๋ธ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ\n" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/site-packages/huggingface_hub/file_download.py:1142: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n", + "Some weights of the model checkpoint at klue/roberta-large were not used when initializing CNN_RobertaForQuestionAnswering: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.bias']\n", + "- This IS expected if you are initializing CNN_RobertaForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing CNN_RobertaForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of CNN_RobertaForQuestionAnswering were not initialized from the model checkpoint at klue/roberta-large and are newly initialized: ['cnn_block2.conv1.weight', 'cnn_block1.layer_norm.weight', 'cnn_block1.conv2.bias', 'cnn_block5.conv1.bias', 'cnn_block3.conv1.bias', 'cnn_block5.conv2.bias', 'cnn_block4.conv1.bias', 'cnn_block2.conv2.bias', 'cnn_block5.conv1.weight', 'cnn_block2.layer_norm.weight', 'cnn_block3.conv1.weight', 'cnn_block5.layer_norm.weight', 'cnn_block4.conv1.weight', 'cnn_block1.conv1.weight', 'cnn_block3.layer_norm.bias', 'cnn_block4.conv2.bias', 'qa_outputs.bias', 'qa_outputs.weight', 'cnn_block2.conv1.bias', 'cnn_block2.conv2.weight', 'cnn_block1.conv1.bias', 'cnn_block5.conv2.weight', 'cnn_block3.conv2.weight', 'cnn_block4.layer_norm.weight', 'cnn_block4.layer_norm.bias', 'cnn_block5.layer_norm.bias', 'cnn_block2.layer_norm.bias', 'cnn_block1.layer_norm.bias', 'cnn_block1.conv2.weight', 'cnn_block3.conv2.bias', 'cnn_block4.conv2.weight', 'cnn_block3.layer_norm.weight']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "from transformers import (\n", + " AutoConfig,\n", + " #AutoModelForQuestionAnswering,\n", + " AutoTokenizer\n", + ")\n", + "from CNN_layer_model import CNN_RobertaForQuestionAnswering\n", + "\n", + "model_name = \"klue/roberta-large\"\n", + "\n", + "config = AutoConfig.from_pretrained(\n", + " model_name\n", + ")\n", + "tokenizer = AutoTokenizer.from_pretrained(\n", + " model_name,\n", + " use_fast=True\n", + ")\n", + "model = CNN_RobertaForQuestionAnswering.from_pretrained(\n", + " model_name,\n", + " config=config)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Korquad ๋ฐ์ดํ„ฐ์…‹ ์ „์ฒ˜๋ฆฌ\n" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "max_seq_length = 512 # ์งˆ๋ฌธ๊ณผ ์ปจํ…์ŠคํŠธ, special token์„ ํ•ฉํ•œ ๋ฌธ์ž์—ด์˜ ์ตœ๋Œ€ ๊ธธ์ด (์ผ์ • ๊ฐœ์ˆ˜๊ฐ€ ๋„˜์–ด๊ฐ€์ง€ ์•Š๋„๋ก!)\n", + "pad_to_max_length = False\n", + "doc_stride = 128 # ์ปจํ…์ŠคํŠธ๊ฐ€ ๋„ˆ๋ฌด ๊ธธ์–ด์„œ ๋‚˜๋ˆด์„ ๋•Œ ์˜ค๋ฒ„๋žฉ๋˜๋Š” ์‹œํ€€์Šค ๊ธธ์ด, ๋ฌธ์„œ 2๊ฐœ๋กœ ์ชผ๊ฐœ๊ณ , 128๊ฐœ ์‹œํ€€์Šค๊ฐ€ ๊ฒน์น˜๋„๋ก\n", + "preprocessing_num_workers = None\n", + "batch_size = 16\n", + "num_train_epochs = 1\n", + "n_best_size = 20\n", + "max_answer_length = 30" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "def prepare_train_features(examples): # examples: ๋ฐ์ดํ„ฐ์…‹ row..\n", + " # ์ฃผ์–ด์ง„ ํ…์ŠคํŠธ๋ฅผ ํ† ํฌ๋‚˜์ด์ง• ํ•œ๋‹ค. ์ด ๋•Œ ํ…์ŠคํŠธ์˜ ๊ธธ์ด๊ฐ€ max_seq_length๋ฅผ ๋„˜์œผ๋ฉด stride๋งŒํผ ์Šฌ๋ผ์ด๋”ฉํ•˜๋ฉฐ ์—ฌ๋Ÿฌ ๊ฐœ๋กœ ์ชผ๊ฐฌ.\n", + " # ์ฆ‰, ํ•˜๋‚˜์˜ example์—์„œ ์ผ๋ถ€๋ถ„์ด ๊ฒน์น˜๋Š” ์—ฌ๋Ÿฌ sequence(feature)๊ฐ€ ์ƒ๊ธธ ์ˆ˜ ์žˆ์Œ.\n", + " tokenized_examples = tokenizer(\n", + " examples[\"question\"],\n", + " examples[\"context\"],\n", + " truncation=\"only_second\", # max_seq_length๊นŒ์ง€ truncateํ•œ๋‹ค. 
pair์˜ ๋‘๋ฒˆ์งธ ํŒŒํŠธ(context)๋งŒ ์ž˜๋ผ๋ƒ„.\n", + " max_length=max_seq_length,\n", + " stride=doc_stride,\n", + " return_overflowing_tokens=True, # ๊ธธ์ด๋ฅผ ๋„˜์–ด๊ฐ€๋Š” ํ† ํฐ๋“ค์„ ๋ฐ˜ํ™˜ํ•  ๊ฒƒ์ธ์ง€\n", + " return_offsets_mapping=True, # ๊ฐ ํ† ํฐ์— ๋Œ€ํ•ด (char_start, char_end) ์ •๋ณด๋ฅผ ๋ฐ˜ํ™˜ํ•œ ๊ฒƒ์ธ์ง€\n", + " padding=\"max_length\", return_token_type_ids=False\n", + " )\n", + "\n", + " # example ํ•˜๋‚˜๊ฐ€ ์—ฌ๋Ÿฌ sequence์— ๋Œ€์‘ํ•˜๋Š” ๊ฒฝ์šฐ๋ฅผ ์œ„ํ•ด ๋งคํ•‘์ด ํ•„์š”ํ•จ.\n", + " overflow_to_sample_mapping = tokenized_examples.pop(\"overflow_to_sample_mapping\")\n", + " # offset_mappings์œผ๋กœ ํ† ํฐ์ด ์›๋ณธ context ๋‚ด ๋ช‡๋ฒˆ์งธ ๊ธ€์ž๋ถ€ํ„ฐ ๋ช‡๋ฒˆ์งธ ๊ธ€์ž๊นŒ์ง€ ํ•ด๋‹นํ•˜๋Š”์ง€ ์•Œ ์ˆ˜ ์žˆ์Œ.\n", + " offset_mapping = tokenized_examples.pop(\"offset_mapping\")\n", + "\n", + " # ์ •๋‹ต์ง€๋ฅผ ๋งŒ๋“ค๊ธฐ ์œ„ํ•œ ๋ฆฌ์ŠคํŠธ\n", + " tokenized_examples[\"start_positions\"] = []\n", + " tokenized_examples[\"end_positions\"] = []\n", + "\n", + " for i, offsets in enumerate(offset_mapping):\n", + " input_ids = tokenized_examples[\"input_ids\"][i]\n", + " cls_index = input_ids.index(tokenizer.cls_token_id)\n", + "\n", + " # ํ•ด๋‹น example์— ํ•ด๋‹นํ•˜๋Š” sequence๋ฅผ ์ฐพ์Œ.\n", + " sequence_ids = tokenized_examples.sequence_ids(i)\n", + "\n", + " # sequence๊ฐ€ ์†ํ•˜๋Š” example์„ ์ฐพ๋Š”๋‹ค.\n", + " example_index = overflow_to_sample_mapping[i]\n", + " answers = examples[\"answers\"][example_index]\n", + "\n", + " # ํ…์ŠคํŠธ์—์„œ answer์˜ ์‹œ์ž‘์ , ๋์ \n", + " answer_start_offset = answers[\"answer_start\"][0]\n", + " answer_end_offset = answer_start_offset + len(answers[\"text\"][0])\n", + "\n", + " # ํ…์ŠคํŠธ์—์„œ ํ˜„์žฌ span์˜ ์‹œ์ž‘ ํ† ํฐ ์ธ๋ฑ์Šค\n", + " token_start_index = 0\n", + " while sequence_ids[token_start_index] != 1:\n", + " token_start_index += 1\n", + "\n", + " # ํ…์ŠคํŠธ์—์„œ ํ˜„์žฌ span ๋ ํ† ํฐ ์ธ๋ฑ์Šค\n", + " token_end_index = len(input_ids) - 1\n", + " while sequence_ids[token_end_index] != 1:\n", + " token_end_index -= 1\n", + "\n", + " # answer๊ฐ€ ํ˜„์žฌ span์„ ๋ฒ—์–ด๋‚ฌ๋Š”์ง€ ์ฒดํฌ\n", + " if not (offsets[token_start_index][0] <= answer_start_offset and offsets[token_end_index][1] >= answer_end_offset):\n", + " tokenized_examples[\"start_positions\"].append(cls_index)\n", + " tokenized_examples[\"end_positions\"].append(cls_index)\n", + " else:\n", + " # token_start_index์™€ token_end_index๋ฅผ answer์˜ ์‹œ์ž‘์ ๊ณผ ๋์ ์œผ๋กœ ์˜ฎ๊น€\n", + " while token_start_index < len(offsets) and offsets[token_start_index][0] <= answer_start_offset:\n", + " token_start_index += 1\n", + " tokenized_examples[\"start_positions\"].append(token_start_index - 1)\n", + " while offsets[token_end_index][1] >= answer_end_offset:\n", + " token_end_index -= 1\n", + " tokenized_examples[\"end_positions\"].append(token_end_index + 1)\n", + "\n", + " return tokenized_examples" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "column_names = filtered_dataset.column_names\n", + "train_dataset = filtered_dataset.map(\n", + " prepare_train_features,\n", + " batched=True,\n", + " num_proc=preprocessing_num_workers,\n", + " remove_columns=column_names,\n", + " load_from_cache_file=True,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "def prepare_validation_features(examples):\n", + " tokenized_examples = tokenizer(\n", + " examples['question'],\n", + " examples['context'],\n", + " truncation=\"only_second\",\n", + " 
max_length=max_seq_length,\n", + " stride=doc_stride,\n", + " return_overflowing_tokens=True,\n", + " return_offsets_mapping=True,\n", + " padding=\"max_length\",\n", + " )\n", + "\n", + " sample_mapping = tokenized_examples.pop(\"overflow_to_sample_mapping\")\n", + "\n", + " tokenized_examples[\"example_id\"] = []\n", + "\n", + " for i in range(len(tokenized_examples[\"input_ids\"])):\n", + " sequence_ids = tokenized_examples.sequence_ids(i)\n", + " context_index = 1\n", + "\n", + " sample_index = sample_mapping[i]\n", + " tokenized_examples[\"example_id\"].append(examples[\"id\"][sample_index])\n", + "\n", + " tokenized_examples[\"offset_mapping\"][i] = [\n", + " (o if sequence_ids[k] == context_index else None)\n", + " for k, o in enumerate(tokenized_examples[\"offset_mapping\"][i])\n", + " ]\n", + "\n", + " return tokenized_examples" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Map: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 5735/5735 [00:04<00:00, 1429.30 examples/s]\n" + ] + } + ], + "source": [ + "eval_dataset = val_dataset.map(\n", + " prepare_validation_features,\n", + " batched=True,\n", + " num_proc=preprocessing_num_workers,\n", + " remove_columns=column_names,\n", + " load_from_cache_file=True,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Question Answering Class ์ •์˜\n" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [], + "source": [ + "# default_data_collator: ์—ฌ๋Ÿฌ๊ฐœ example๋“ค์„ collatorํ•ด์ฃผ๋Š” ์—ญํ• ,\n", + "# TrainingArguments : ํ•œ๋ฒˆ์— training arguments๋“ค์„ ํ•ฉ์ณ์„œ ์ฃผ๋Š”..!\n", + "from transformers import default_data_collator, TrainingArguments, EvalPrediction" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "# coding=utf-8\n", + "# Copyright 2020 The HuggingFace Team All rights reserved.\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License.\n", + "\"\"\"\n", + "Question-Answering task์™€ ๊ด€๋ จ๋œ 'Trainer'์˜ subclass ์ฝ”๋“œ ์ž…๋‹ˆ๋‹ค.\n", + "\"\"\"\n", + "\n", + "from transformers import Trainer, is_datasets_available, is_torch_tpu_available\n", + "from transformers.trainer_utils import PredictionOutput\n", + "\n", + "if is_datasets_available():\n", + " import datasets\n", + "\n", + "# Huggingface์˜ Trainer๋ฅผ ์ƒ์†๋ฐ›์•„ QuestionAnswering์„ ์œ„ํ•œ Trainer๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.\n", + "class QuestionAnsweringTrainer(Trainer):\n", + " def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs):\n", + " super().__init__(*args, **kwargs)\n", + " self.eval_examples = eval_examples\n", + " self.post_process_function = post_process_function\n", + "\n", + " def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None):\n", + " eval_dataset = self.eval_dataset if eval_dataset is None else 
eval_dataset\n", + " eval_dataloader = self.get_eval_dataloader(eval_dataset)\n", + " eval_examples = self.eval_examples if eval_examples is None else eval_examples\n", + "\n", + " # ์ผ์‹œ์ ์œผ๋กœ metric computation๋ฅผ ๋ถˆ๊ฐ€๋Šฅํ•˜๊ฒŒ ํ•œ ์ƒํƒœ์ด๋ฉฐ, ํ•ด๋‹น ์ฝ”๋“œ์—์„œ๋Š” loop ๋‚ด์—์„œ metric ๊ณ„์‚ฐ์„ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค.\n", + " compute_metrics = self.compute_metrics\n", + " self.compute_metrics = None\n", + " try:\n", + " output = self.prediction_loop(\n", + " eval_dataloader,\n", + " description=\"Evaluation\",\n", + " # metric์ด ์—†์œผ๋ฉด ์˜ˆ์ธก๊ฐ’์„ ๋ชจ์œผ๋Š” ์ด์œ ๊ฐ€ ์—†์œผ๋ฏ€๋กœ ์•„๋ž˜์˜ ์ฝ”๋“œ๋ฅผ ๋”ฐ๋ฅด๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.\n", + " # self.args.prediction_loss_only\n", + " prediction_loss_only=True if compute_metrics is None else None,\n", + " ignore_keys=ignore_keys,\n", + " )\n", + " finally:\n", + " self.compute_metrics = compute_metrics\n", + "\n", + " if isinstance(eval_dataset, datasets.Dataset):\n", + " eval_dataset.set_format(\n", + " type=eval_dataset.format[\"type\"],\n", + " columns=list(eval_dataset.features.keys()),\n", + " )\n", + "\n", + " if self.post_process_function is not None and self.compute_metrics is not None:\n", + " eval_preds = self.post_process_function(\n", + " eval_examples, eval_dataset, output.predictions, self.args\n", + " )\n", + " metrics = self.compute_metrics(eval_preds)\n", + "\n", + " self.log(metrics)\n", + " else:\n", + " metrics = {}\n", + "\n", + " self.control = self.callback_handler.on_evaluate(\n", + " self.args, self.state, self.control, metrics\n", + " )\n", + " return metrics\n", + "\n", + " def predict(self, test_dataset, test_examples, ignore_keys=None):\n", + " test_dataloader = self.get_test_dataloader(test_dataset)\n", + "\n", + " # ์ผ์‹œ์ ์œผ๋กœ metric computation๋ฅผ ๋ถˆ๊ฐ€๋Šฅํ•˜๊ฒŒ ํ•œ ์ƒํƒœ์ด๋ฉฐ, ํ•ด๋‹น ์ฝ”๋“œ์—์„œ๋Š” loop ๋‚ด์—์„œ metric ๊ณ„์‚ฐ์„ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค.\n", + " # evaluate ํ•จ์ˆ˜์™€ ๋™์ผํ•˜๊ฒŒ ๊ตฌ์„ฑ๋˜์–ด์žˆ์Šต๋‹ˆ๋‹ค\n", + " compute_metrics = self.compute_metrics\n", + " self.compute_metrics = None\n", + " try:\n", + " output = self.prediction_loop(\n", + " test_dataloader,\n", + " description=\"Evaluation\",\n", + " # metric์ด ์—†์œผ๋ฉด ์˜ˆ์ธก๊ฐ’์„ ๋ชจ์œผ๋Š” ์ด์œ ๊ฐ€ ์—†์œผ๋ฏ€๋กœ ์•„๋ž˜์˜ ์ฝ”๋“œ๋ฅผ ๋”ฐ๋ฅด๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.\n", + " # self.args.prediction_loss_only\n", + " prediction_loss_only=True if compute_metrics is None else None,\n", + " ignore_keys=ignore_keys,\n", + " )\n", + " finally:\n", + " self.compute_metrics = compute_metrics\n", + "\n", + " if self.post_process_function is None or self.compute_metrics is None:\n", + " return output\n", + "\n", + " if isinstance(test_dataset, datasets.Dataset):\n", + " test_dataset.set_format(\n", + " type=test_dataset.format[\"type\"],\n", + " columns=list(test_dataset.features.keys()),\n", + " )\n", + "\n", + " predictions = self.post_process_function(\n", + " test_examples, test_dataset, output.predictions, self.args\n", + " )\n", + " return predictions\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## ํ›„์ฒ˜๋ฆฌ ํด๋ž˜์Šค ์ •์˜\n" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "# coding=utf-8\n", + "# Copyright 2020 The HuggingFace Team All rights reserved.\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the 'License');\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by 
applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an 'AS IS' BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License.\n", + "\"\"\"\n", + "Pre-processing\n", + "Post-processing utilities for question answering.\n", + "\"\"\"\n", + "import collections\n", + "import json\n", + "import logging\n", + "import os\n", + "import random\n", + "from typing import Any, Optional, Tuple\n", + "\n", + "import numpy as np\n", + "import torch\n", + "from arguments import DataTrainingArguments, ModelArguments\n", + "from datasets import DatasetDict\n", + "from tqdm.auto import tqdm\n", + "from transformers import PreTrainedTokenizerFast, TrainingArguments, is_torch_available\n", + "from transformers.trainer_utils import get_last_checkpoint\n", + "\n", + "#from utils.datetime_helper import get_seoul_datetime_str\n", + "\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "\n", + "def set_seed(seed: int = 2024):\n", + " \"\"\"\n", + " seed ๊ณ ์ •ํ•˜๋Š” ํ•จ์ˆ˜ (random, numpy, torch)\n", + "\n", + " Args:\n", + " seed (:obj:`int`): The seed to set.\n", + " \"\"\"\n", + " random.seed(seed)\n", + " np.random.seed(seed)\n", + " if is_torch_available():\n", + " torch.manual_seed(seed)\n", + " torch.cuda.manual_seed(seed)\n", + " torch.cuda.manual_seed_all(seed) # if use multi-GPU\n", + " torch.backends.cudnn.deterministic = True\n", + " torch.backends.cudnn.benchmark = False\n", + "\n", + "\n", + "def postprocess_qa_predictions(\n", + " examples,\n", + " features,\n", + " predictions: Tuple[np.ndarray, np.ndarray],\n", + " version_2_with_negative: bool = False,\n", + " n_best_size: int = 20,\n", + " max_answer_length: int = 30,\n", + " null_score_diff_threshold: float = 0.0,\n", + " output_dir: Optional[str] = None,\n", + " prefix: Optional[str] = None,\n", + " is_world_process_zero: bool = True,\n", + "):\n", + " \"\"\"\n", + " Post-processes : qa model์˜ prediction ๊ฐ’์„ ํ›„์ฒ˜๋ฆฌํ•˜๋Š” ํ•จ์ˆ˜\n", + " ๋ชจ๋ธ์€ start logit๊ณผ end logit์„ ๋ฐ˜ํ™˜ํ•˜๊ธฐ ๋•Œ๋ฌธ์—, ์ด๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ original text๋กœ ๋ณ€๊ฒฝํ•˜๋Š” ํ›„์ฒ˜๋ฆฌ๊ฐ€ ํ•„์š”ํ•จ\n", + "\n", + " Args:\n", + " examples: ์ „์ฒ˜๋ฆฌ ๋˜์ง€ ์•Š์€ ๋ฐ์ดํ„ฐ์…‹ (see the main script for more information).\n", + " features: ์ „์ฒ˜๋ฆฌ๊ฐ€ ์ง„ํ–‰๋œ ๋ฐ์ดํ„ฐ์…‹ (see the main script for more information).\n", + " predictions (:obj:`Tuple[np.ndarray, np.ndarray]`):\n", + " ๋ชจ๋ธ์˜ ์˜ˆ์ธก๊ฐ’ :start logits๊ณผ the end logits์„ ๋‚˜ํƒ€๋‚ด๋Š” two arrays ์ฒซ๋ฒˆ์งธ ์ฐจ์›์€ :obj:`features`์˜ element์™€ ๊ฐฏ์ˆ˜๊ฐ€ ๋งž์•„์•ผํ•จ.\n", + " version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`):\n", + " ์ •๋‹ต์ด ์—†๋Š” ๋ฐ์ดํ„ฐ์…‹์ด ํฌํ•จ๋˜์–ด์žˆ๋Š”์ง€ ์—ฌ๋ถ€๋ฅผ ๋‚˜ํƒ€๋ƒ„\n", + " n_best_size (:obj:`int`, `optional`, defaults to 20):\n", + " ๋‹ต๋ณ€์„ ์ฐพ์„ ๋•Œ ์ƒ์„ฑํ•  n-best prediction ์ด ๊ฐœ์ˆ˜\n", + " max_answer_length (:obj:`int`, `optional`, defaults to 30):\n", + " ์ƒ์„ฑํ•  ์ˆ˜ ์žˆ๋Š” ๋‹ต๋ณ€์˜ ์ตœ๋Œ€ ๊ธธ์ด\n", + " null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0):\n", + " null ๋‹ต๋ณ€์„ ์„ ํƒํ•˜๋Š” ๋ฐ ์‚ฌ์šฉ๋˜๋Š” threshold\n", + " : if the best answer has a score that is less than the score of\n", + " the null answer minus this threshold, the null answer is selected for this example (note that the score of\n", + " the null answer for an example giving several features is the minimum of the scores for the null 
answer on\n", + " each feature: all features must be aligned on the fact they `want` to predict a null answer).\n", + " Only useful when :obj:`version_2_with_negative` is :obj:`True`.\n", + " output_dir (:obj:`str`, `optional`):\n", + " ์•„๋ž˜์˜ ๊ฐ’์ด ์ €์žฅ๋˜๋Š” ๊ฒฝ๋กœ\n", + " dictionary : predictions, n_best predictions (with their scores and logits) if:obj:`version_2_with_negative=True`,\n", + " dictionary : the scores differences between best and null answers\n", + " prefix (:obj:`str`, `optional`):\n", + " dictionary์— `prefix`๊ฐ€ ํฌํ•จ๋˜์–ด ์ €์žฅ๋จ\n", + " is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):\n", + " ์ด ํ”„๋กœ์„ธ์Šค๊ฐ€ main process์ธ์ง€ ์—ฌ๋ถ€(logging/save๋ฅผ ์ˆ˜ํ–‰ํ•ด์•ผ ํ•˜๋Š”์ง€ ์—ฌ๋ถ€๋ฅผ ๊ฒฐ์ •ํ•˜๋Š” ๋ฐ ์‚ฌ์šฉ๋จ)\n", + " \"\"\"\n", + " assert (\n", + " len(predictions) == 2\n", + " ), \"`predictions` should be a tuple with two elements (start_logits, end_logits).\"\n", + " all_start_logits, all_end_logits = predictions\n", + "\n", + " assert len(predictions[0]) == len(\n", + " features\n", + " ), f\"Got {len(predictions[0])} predictions and {len(features)} features.\"\n", + "\n", + " # example๊ณผ mapping๋˜๋Š” feature ์ƒ์„ฑ\n", + " example_id_to_index = {k: i for i, k in enumerate(examples[\"id\"])}\n", + " features_per_example = collections.defaultdict(list)\n", + " for i, feature in enumerate(features):\n", + " features_per_example[example_id_to_index[feature[\"example_id\"]]].append(i)\n", + "\n", + " # prediction, nbest์— ํ•ด๋‹นํ•˜๋Š” OrderedDict ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.\n", + " all_predictions = collections.OrderedDict()\n", + " all_nbest_json = collections.OrderedDict()\n", + " if version_2_with_negative:\n", + " scores_diff_json = collections.OrderedDict()\n", + "\n", + " # Logging.\n", + " logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN)\n", + " logger.info(\n", + " f\"Post-processing {len(examples)} example predictions split into {len(features)} features.\"\n", + " )\n", + "\n", + " # ์ „์ฒด example๋“ค์— ๋Œ€ํ•œ main Loop\n", + " for example_index, example in enumerate(tqdm(examples)):\n", + " # ํ•ด๋‹นํ•˜๋Š” ํ˜„์žฌ example index\n", + " feature_indices = features_per_example[example_index]\n", + "\n", + " min_null_prediction = None\n", + " prelim_predictions = []\n", + "\n", + " # ํ˜„์žฌ example์— ๋Œ€ํ•œ ๋ชจ๋“  feature ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.\n", + " for feature_index in feature_indices:\n", + " # ๊ฐ featureure์— ๋Œ€ํ•œ ๋ชจ๋“  prediction์„ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.\n", + " start_logits = all_start_logits[feature_index]\n", + " end_logits = all_end_logits[feature_index]\n", + " # logit๊ณผ original context์˜ logit์„ mappingํ•ฉ๋‹ˆ๋‹ค.\n", + " offset_mapping = features[feature_index][\"offset_mapping\"]\n", + " # Optional : `token_is_max_context`, ์ œ๊ณต๋˜๋Š” ๊ฒฝ์šฐ ํ˜„์žฌ ๊ธฐ๋Šฅ์—์„œ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋Š” max context๊ฐ€ ์—†๋Š” answer๋ฅผ ์ œ๊ฑฐํ•ฉ๋‹ˆ๋‹ค\n", + " token_is_max_context = features[feature_index].get(\n", + " \"token_is_max_context\", None\n", + " )\n", + "\n", + " # minimum null prediction์„ ์—…๋ฐ์ดํŠธ ํ•ฉ๋‹ˆ๋‹ค.\n", + " feature_null_score = start_logits[0] + end_logits[0]\n", + " if (\n", + " min_null_prediction is None\n", + " or min_null_prediction[\"score\"] > feature_null_score\n", + " ):\n", + " min_null_prediction = {\n", + " \"offsets\": (0, 0),\n", + " \"score\": feature_null_score,\n", + " \"start_logit\": start_logits[0],\n", + " \"end_logit\": end_logits[0],\n", + " }\n", + "\n", + " # `n_best_size`๋ณด๋‹ค ํฐ start and end logits์„ ์‚ดํŽด๋ด…๋‹ˆ๋‹ค.\n", + " start_indexes = 
np.argsort(start_logits)[\n", + " -1 : -n_best_size - 1 : -1\n", + " ].tolist()\n", + "\n", + " end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()\n", + "\n", + " for start_index in start_indexes:\n", + " for end_index in end_indexes:\n", + " # out-of-scope answers๋Š” ๊ณ ๋ คํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.\n", + " if (\n", + " start_index >= len(offset_mapping)\n", + " or end_index >= len(offset_mapping)\n", + " or offset_mapping[start_index] is None\n", + " or offset_mapping[end_index] is None\n", + " ):\n", + " continue\n", + " # ๊ธธ์ด๊ฐ€ < 0 ๋˜๋Š” > max_answer_length์ธ answer๋„ ๊ณ ๋ คํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.\n", + " if (\n", + " end_index < start_index\n", + " or end_index - start_index + 1 > max_answer_length\n", + " ):\n", + " continue\n", + " # ์ตœ๋Œ€ context๊ฐ€ ์—†๋Š” answer๋„ ๊ณ ๋ คํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.\n", + " if (\n", + " token_is_max_context is not None\n", + " and not token_is_max_context.get(str(start_index), False)\n", + " ):\n", + " continue\n", + " prelim_predictions.append(\n", + " {\n", + " \"offsets\": (\n", + " offset_mapping[start_index][0],\n", + " offset_mapping[end_index][1],\n", + " ),\n", + " \"score\": start_logits[start_index] + end_logits[end_index],\n", + " \"start_logit\": start_logits[start_index],\n", + " \"end_logit\": end_logits[end_index],\n", + " }\n", + " )\n", + "\n", + " if version_2_with_negative:\n", + " # minimum null prediction์„ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค.\n", + " prelim_predictions.append(min_null_prediction)\n", + " null_score = min_null_prediction[\"score\"]\n", + "\n", + " # ๊ฐ€์žฅ ์ข‹์€ `n_best_size` predictions๋งŒ ์œ ์ง€ํ•ฉ๋‹ˆ๋‹ค.\n", + " predictions = sorted(\n", + " prelim_predictions, key=lambda x: x[\"score\"], reverse=True\n", + " )[:n_best_size]\n", + "\n", + " # ๋‚ฎ์€ ์ ์ˆ˜๋กœ ์ธํ•ด ์ œ๊ฑฐ๋œ ๊ฒฝ์šฐ minimum null prediction์„ ๋‹ค์‹œ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค.\n", + " if version_2_with_negative and not any(\n", + " p[\"offsets\"] == (0, 0) for p in predictions\n", + " ):\n", + " predictions.append(min_null_prediction)\n", + "\n", + " # offset์„ ์‚ฌ์šฉํ•˜์—ฌ original context์—์„œ answer text๋ฅผ ์ˆ˜์ง‘ํ•ฉ๋‹ˆ๋‹ค.\n", + " context = example[\"context\"]\n", + " for pred in predictions:\n", + " offsets = pred.pop(\"offsets\")\n", + " pred[\"text\"] = context[offsets[0] : offsets[1]]\n", + "\n", + " # rare edge case์—๋Š” null์ด ์•„๋‹Œ ์˜ˆ์ธก์ด ํ•˜๋‚˜๋„ ์—†์œผ๋ฉฐ failure๋ฅผ ํ”ผํ•˜๊ธฐ ์œ„ํ•ด fake prediction์„ ๋งŒ๋“ญ๋‹ˆ๋‹ค.\n", + " if len(predictions) == 0 or (\n", + " len(predictions) == 1 and predictions[0][\"text\"] == \"\"\n", + " ):\n", + "\n", + " predictions.insert(\n", + " 0, {\"text\": \"empty\", \"start_logit\": 0.0, \"end_logit\": 0.0, \"score\": 0.0}\n", + " )\n", + "\n", + " # ๋ชจ๋“  ์ ์ˆ˜์˜ ์†Œํ”„ํŠธ๋งฅ์Šค๋ฅผ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค(we do it with numpy to stay independent from torch/tf in this file, using the LogSumExp trick).\n", + " scores = np.array([pred.pop(\"score\") for pred in predictions])\n", + " exp_scores = np.exp(scores - np.max(scores))\n", + " probs = exp_scores / exp_scores.sum()\n", + "\n", + " # ์˜ˆ์ธก๊ฐ’์— ํ™•๋ฅ ์„ ํฌํ•จํ•ฉ๋‹ˆ๋‹ค.\n", + " for prob, pred in zip(probs, predictions):\n", + " pred[\"probability\"] = prob\n", + "\n", + " # best prediction์„ ์„ ํƒํ•ฉ๋‹ˆ๋‹ค.\n", + " if not version_2_with_negative:\n", + " all_predictions[example[\"id\"]] = predictions[0][\"text\"]\n", + " else:\n", + " # else case : ๋จผ์ € ๋น„์–ด ์žˆ์ง€ ์•Š์€ ์ตœ์ƒ์˜ ์˜ˆ์ธก์„ ์ฐพ์•„์•ผ ํ•ฉ๋‹ˆ๋‹ค\n", + " i = 0\n", + " while predictions[i][\"text\"] == \"\":\n", + " i += 1\n", + " best_non_null_pred = predictions[i]\n", + "\n", + " 
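# (๋ณด์ถฉ) null ์ ์ˆ˜๋Š” feature๋ณ„ start_logits[0] + end_logits[0] (CLS ์œ„์น˜) ์ค‘ ์ตœ์†Ÿ๊ฐ’์ด๊ณ ,\n", + "            # ์•„๋ž˜ score_diff๊ฐ€ null_score_diff_threshold๋ณด๋‹ค ํฌ๋ฉด ๋นˆ ๋ฌธ์ž์—ด์„ ๋‹ต์œผ๋กœ ์„ ํƒํ•œ๋‹ค.\n", + "            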
# threshold๋ฅผ ์‚ฌ์šฉํ•ด์„œ null prediction์„ ๋น„๊ตํ•ฉ๋‹ˆ๋‹ค.\n", + " score_diff = (\n", + " null_score\n", + " - best_non_null_pred[\"start_logit\"]\n", + " - best_non_null_pred[\"end_logit\"]\n", + " )\n", + " scores_diff_json[example[\"id\"]] = float(score_diff) # JSON-serializable ๊ฐ€๋Šฅ\n", + " if score_diff > null_score_diff_threshold:\n", + " all_predictions[example[\"id\"]] = \"\"\n", + " else:\n", + " all_predictions[example[\"id\"]] = best_non_null_pred[\"text\"]\n", + "\n", + " # np.float๋ฅผ ๋‹ค์‹œ float๋กœ casting -> `predictions`์€ JSON-serializable ๊ฐ€๋Šฅ\n", + " all_nbest_json[example[\"id\"]] = [\n", + " {\n", + " k: (\n", + " float(v)\n", + " if isinstance(v, (np.float16, np.float32, np.float64))\n", + " else v\n", + " )\n", + " for k, v in pred.items()\n", + " }\n", + " for pred in predictions\n", + " ]\n", + "\n", + " # output_dir์ด ์žˆ์œผ๋ฉด ๋ชจ๋“  dicts๋ฅผ ์ €์žฅํ•ฉ๋‹ˆ๋‹ค.\n", + " if output_dir is not None:\n", + " assert os.path.isdir(output_dir), f\"{output_dir} is not a directory.\"\n", + "\n", + " prediction_file = os.path.join(\n", + " output_dir,\n", + " \"predictions.json\" if prefix is None else f\"predictions_{prefix}.json\",\n", + " )\n", + " nbest_file = os.path.join(\n", + " output_dir,\n", + " \"nbest_predictions.json\"\n", + " if prefix is None\n", + " else f\"nbest_predictions_{prefix}.json\",\n", + " )\n", + " if version_2_with_negative:\n", + " null_odds_file = os.path.join(\n", + " output_dir,\n", + " \"null_odds.json\" if prefix is None else f\"null_odds_{prefix}.json\",\n", + " )\n", + "\n", + " logger.info(f\"Saving predictions to {prediction_file}.\")\n", + " with open(prediction_file, \"w\", encoding=\"utf-8\") as writer:\n", + " writer.write(\n", + " json.dumps(all_predictions, indent=4, ensure_ascii=False) + \"\\n\"\n", + " )\n", + " logger.info(f\"Saving nbest_preds to {nbest_file}.\")\n", + " with open(nbest_file, \"w\", encoding=\"utf-8\") as writer:\n", + " writer.write(\n", + " json.dumps(all_nbest_json, indent=4, ensure_ascii=False) + \"\\n\"\n", + " )\n", + " if version_2_with_negative:\n", + " logger.info(f\"Saving null_odds to {null_odds_file}.\")\n", + " with open(null_odds_file, \"w\", encoding=\"utf-8\") as writer:\n", + " writer.write(\n", + " json.dumps(scores_diff_json, indent=4, ensure_ascii=False) + \"\\n\"\n", + " )\n", + "\n", + " return all_predictions\n", + "\n", + "\n", + "def check_no_error(\n", + " data_args: DataTrainingArguments,\n", + " training_args: TrainingArguments,\n", + " datasets: DatasetDict,\n", + " tokenizer,\n", + ") -> Tuple[Any, int]:\n", + "\n", + " # last checkpoint ์ฐพ๊ธฐ.\n", + " last_checkpoint = None\n", + " if (\n", + " os.path.isdir(training_args.output_dir)\n", + " and training_args.do_train\n", + " and not training_args.overwrite_output_dir\n", + " ):\n", + " last_checkpoint = get_last_checkpoint(training_args.output_dir)\n", + " if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:\n", + " raise ValueError(\n", + " f\"Output directory ({training_args.output_dir}) already exists and is not empty. \"\n", + " \"Use --overwrite_output_dir to overcome.\"\n", + " )\n", + " elif last_checkpoint is not None:\n", + " logger.info(\n", + " f\"Checkpoint detected, resuming training at {last_checkpoint}. 
To avoid this behavior, change \"\n", + " \"the `--output_dir` or add `--overwrite_output_dir` to train from scratch.\"\n", + " )\n", + "\n", + " # Tokenizer check: ํ•ด๋‹น script๋Š” Fast tokenizer๋ฅผ ํ•„์š”๋กœํ•ฉ๋‹ˆ๋‹ค.\n", + " if not isinstance(tokenizer, PreTrainedTokenizerFast):\n", + " raise ValueError(\n", + " \"This example script only works for models that have a fast tokenizer. Checkout the big table of models \"\n", + " \"at https://huggingface.co/transformers/index.html#bigtable to find the model types that meet this \"\n", + " \"requirement\"\n", + " )\n", + "\n", + " if data_args.max_seq_length > tokenizer.model_max_length:\n", + " logger.warn(\n", + " f\"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the\"\n", + " f\"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}.\"\n", + " )\n", + " max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)\n", + "\n", + " if \"validation\" not in datasets:\n", + " raise ValueError(\"--do_eval requires a validation dataset\")\n", + " return last_checkpoint, max_seq_length\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### ํ›„์ฒ˜๋ฆฌ ํ•จ์ˆ˜ ์ •์˜\n" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "# ๋ชจ๋ธ์ด ์ดํ•ดํ•˜๋Š” ํ˜•ํƒœ์—์„œ ์‚ฌ๋žŒ์ด ์ดํ•ดํ•˜๋Š” ํ˜•ํƒœ๋กœ ๋‹ต๋ณ€ ๋งค์นญ\n", + "def post_processing_function(examples, features, predictions):\n", + " # Post-processing: we match the start logits and end logits to answers in the original context.\n", + " predictions = postprocess_qa_predictions(\n", + " examples=examples,\n", + " features=features,\n", + " predictions=predictions,\n", + " version_2_with_negative=False,\n", + " n_best_size=n_best_size,\n", + " max_answer_length=max_answer_length,\n", + " null_score_diff_threshold=0.0,\n", + " output_dir=training_args.output_dir,\n", + " is_world_process_zero=trainer.is_world_process_zero(),\n", + " )\n", + "\n", + " # Format the result to the format the metric expects.\n", + " formatted_predictions = [{\"id\": k, \"prediction_text\": v} for k, v in predictions.items()]\n", + " references = [{\"id\": ex[\"id\"], \"answers\": ex[\"answers\"]} for ex in datasets[\"validation\"]]\n", + " return EvalPrediction(predictions=formatted_predictions, label_ids=references)" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "def compute_metrics(p: EvalPrediction):\n", + " return metric.compute(predictions=p.predictions, references=p.label_ids)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train !\n" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: W&B API key is configured. 
Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.18.3" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /data/ephemeral/home/level2-mrc-nlp-15/src/wandb/run-20241017_040517-ua6lvn4p" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run dandy-rain-976 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/nlp15/odqa" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/nlp15/odqa/runs/ua6lvn4p" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import wandb\n", + "\n", + "wandb.login()\n", + "wandb.init(project='odqa', # ์‹คํ—˜๊ธฐ๋ก์„ ๊ด€๋ฆฌํ•œ ํ”„๋กœ์ ํŠธ ์ด๋ฆ„\n", + " entity='nlp15', # ์‚ฌ์šฉ์ž๋ช… ๋˜๋Š” ํŒ€ ์ด๋ฆ„ \n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "\n", + "#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "torch.cuda.is_available()" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [], + "source": [ + "training_args = TrainingArguments(\n", + " output_dir=\"outputs\",\n", + " do_train=True,\n", + " do_eval=True,\n", + " per_device_train_batch_size=batch_size,\n", + " per_device_eval_batch_size=batch_size,\n", + " num_train_epochs=num_train_epochs,\n", + " save_total_limit=1, # ์ €์žฅํ•  ์ฒดํฌํฌ์ธํŠธ์˜ ์ตœ๋Œ€ ์ˆ˜\n", + " #evaluation_strategy=\"steps\",\n", + " #eval_steps=500, # ๋ช‡ ์Šคํ…๋งˆ๋‹ค ํ‰๊ฐ€ํ• ์ง€ ์„ค์ •\n", + " #logging_steps=500, # ๋ช‡ ์Šคํ…๋งˆ๋‹ค ๋กœ๊น…ํ• ์ง€ ์„ค์ •,\n", + " #load_best_model_at_end=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = QuestionAnsweringTrainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=train_dataset,\n", + " eval_dataset=eval_dataset,\n", + " eval_examples=val_dataset,\n", + " tokenizer=tokenizer,\n", + " data_collator=default_data_collator, # ๋ณดํ†ต default\n", + " post_process_function=post_processing_function, # function์„ input์œผ๋กœ ๋ฐ›์Œ!\n", + " compute_metrics=compute_metrics\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", + " warnings.warn(\n", + "***** Running training *****\n", + " Num examples = 62641\n", + " Num Epochs = 1\n", + " Instantaneous batch size per device = 16\n", + " Total train batch size (w. parallel, distributed & accumulation) = 16\n", + " Gradient Accumulation steps = 1\n", + " Total optimization steps = 3916\n", + " Number of trainable parameters = 356600834\n", + "Automatic Weights & Biases logging enabled, to disable set os.environ[\"WANDB_DISABLED\"] = \"true\"\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [3916/3916 1:43:53, Epoch 1/1]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining Loss
5000.769300
10000.541500
15000.523800
20000.483600
25000.410800
30000.379300
35000.350300

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Saving model checkpoint to outputs/checkpoint-500\n", + "Configuration saved in outputs/checkpoint-500/config.json\n", + "Model weights saved in outputs/checkpoint-500/pytorch_model.bin\n", + "tokenizer config file saved in outputs/checkpoint-500/tokenizer_config.json\n", + "Special tokens file saved in outputs/checkpoint-500/special_tokens_map.json\n", + "Deleting older checkpoint [outputs/checkpoint-3500] due to args.save_total_limit\n", + "Saving model checkpoint to outputs/checkpoint-1000\n", + "Configuration saved in outputs/checkpoint-1000/config.json\n", + "Model weights saved in outputs/checkpoint-1000/pytorch_model.bin\n", + "tokenizer config file saved in outputs/checkpoint-1000/tokenizer_config.json\n", + "Special tokens file saved in outputs/checkpoint-1000/special_tokens_map.json\n", + "Deleting older checkpoint [outputs/checkpoint-500] due to args.save_total_limit\n", + "Saving model checkpoint to outputs/checkpoint-1500\n", + "Configuration saved in outputs/checkpoint-1500/config.json\n", + "Model weights saved in outputs/checkpoint-1500/pytorch_model.bin\n", + "tokenizer config file saved in outputs/checkpoint-1500/tokenizer_config.json\n", + "Special tokens file saved in outputs/checkpoint-1500/special_tokens_map.json\n", + "Deleting older checkpoint [outputs/checkpoint-1000] due to args.save_total_limit\n", + "Saving model checkpoint to outputs/checkpoint-2000\n", + "Configuration saved in outputs/checkpoint-2000/config.json\n", + "Model weights saved in outputs/checkpoint-2000/pytorch_model.bin\n", + "tokenizer config file saved in outputs/checkpoint-2000/tokenizer_config.json\n", + "Special tokens file saved in outputs/checkpoint-2000/special_tokens_map.json\n", + "Deleting older checkpoint [outputs/checkpoint-1500] due to args.save_total_limit\n", + "Saving model checkpoint to outputs/checkpoint-2500\n", + "Configuration saved in outputs/checkpoint-2500/config.json\n", + "Model weights saved in outputs/checkpoint-2500/pytorch_model.bin\n", + "tokenizer config file saved in outputs/checkpoint-2500/tokenizer_config.json\n", + "Special tokens file saved in outputs/checkpoint-2500/special_tokens_map.json\n", + "Deleting older checkpoint [outputs/checkpoint-2000] due to args.save_total_limit\n", + "Saving model checkpoint to outputs/checkpoint-3000\n", + "Configuration saved in outputs/checkpoint-3000/config.json\n", + "Model weights saved in outputs/checkpoint-3000/pytorch_model.bin\n", + "tokenizer config file saved in outputs/checkpoint-3000/tokenizer_config.json\n", + "Special tokens file saved in outputs/checkpoint-3000/special_tokens_map.json\n", + "Deleting older checkpoint [outputs/checkpoint-2500] due to args.save_total_limit\n", + "Saving model checkpoint to outputs/checkpoint-3500\n", + "Configuration saved in outputs/checkpoint-3500/config.json\n", + "Model weights saved in outputs/checkpoint-3500/pytorch_model.bin\n", + "tokenizer config file saved in outputs/checkpoint-3500/tokenizer_config.json\n", + "Special tokens file saved in outputs/checkpoint-3500/special_tokens_map.json\n", + "Deleting older checkpoint [outputs/checkpoint-3000] due to args.save_total_limit\n", + "\n", + "\n", + "Training completed. 
Do not forget to share your model on huggingface.co/models =)\n", + "\n", + "\n" + ] + } + ], + "source": [ + "train_result = trainer.train()" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "

<h3>Run history:</h3>\n", + "<table><tr><td>train/epoch</td><td>โ–โ–‚โ–ƒโ–„โ–…โ–†โ–‡โ–ˆ</td></tr><tr><td>train/global_step</td><td>โ–โ–‚โ–ƒโ–„โ–…โ–†โ–‡โ–ˆ</td></tr><tr><td>train/learning_rate</td><td>โ–ˆโ–‡โ–†โ–„โ–ƒโ–‚โ–</td></tr><tr><td>train/loss</td><td>โ–ˆโ–„โ–„โ–ƒโ–‚โ–โ–</td></tr><tr><td>train/total_flos</td><td>โ–</td></tr><tr><td>train/train_loss</td><td>โ–</td></tr><tr><td>train/train_runtime</td><td>โ–</td></tr><tr><td>train/train_samples_per_second</td><td>โ–</td></tr><tr><td>train/train_steps_per_second</td><td>โ–</td></tr></table>\n", + "<h3>Run summary:</h3>\n", + "<table><tr><td>train/epoch</td><td>1</td></tr><tr><td>train/global_step</td><td>3916</td></tr><tr><td>train/learning_rate</td><td>1e-05</td></tr><tr><td>train/loss</td><td>0.3503</td></tr><tr><td>train/total_flos</td><td>6.221469142067405e+16</td></tr><tr><td>train/train_loss</td><td>0.47893</td></tr><tr><td>train/train_runtime</td><td>6235.4731</td></tr><tr><td>train/train_samples_per_second</td><td>10.046</td></tr><tr><td>train/train_steps_per_second</td><td>0.628</td></tr></table>\n
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run dandy-rain-976 at: https://wandb.ai/nlp15/odqa/runs/ua6lvn4p
View project at: https://wandb.ai/nlp15/odqa
Synced 4 W&B file(s), 0 media file(s), 3 artifact file(s) and 0 other file(s)" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find logs at: ./wandb/run-20241017_040517-ua6lvn4p/logs" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "wandb.finish()" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "CNN_RobertaForQuestionAnswering(\n", + " (roberta): RobertaModel(\n", + " (embeddings): RobertaEmbeddings(\n", + " (word_embeddings): Embedding(32000, 1024, padding_idx=1)\n", + " (position_embeddings): Embedding(514, 1024, padding_idx=1)\n", + " (token_type_embeddings): Embedding(1, 1024)\n", + " (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (encoder): RobertaEncoder(\n", + " (layer): ModuleList(\n", + " (0-23): 24 x RobertaLayer(\n", + " (attention): RobertaAttention(\n", + " (self): RobertaSelfAttention(\n", + " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): RobertaSelfOutput(\n", + " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): RobertaIntermediate(\n", + " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (intermediate_act_fn): GELUActivation()\n", + " )\n", + " (output): RobertaOutput(\n", + " (dense): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (cnn_block1): CNN_block(\n", + " (conv1): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (conv2): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,))\n", + " (relu): ReLU()\n", + " (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (cnn_block2): CNN_block(\n", + " (conv1): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (conv2): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,))\n", + " (relu): ReLU()\n", + " (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (cnn_block3): CNN_block(\n", + " (conv1): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (conv2): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,))\n", + " (relu): ReLU()\n", + " (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (cnn_block4): CNN_block(\n", + " (conv1): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (conv2): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,))\n", + " (relu): ReLU()\n", + " (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (cnn_block5): CNN_block(\n", + " (conv1): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (conv2): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,))\n", + " (relu): ReLU()\n", + " (layer_norm): LayerNorm((1024,), eps=1e-05, 
elementwise_affine=True)\n", + "  )\n", + "  (qa_outputs): Linear(in_features=1024, out_features=2, bias=True)\n", + ")" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## ํ—ˆ๊น…ํŽ˜์ด์Šค์— ๋ชจ๋ธ ์—…๋กœ๋“œ" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "/bin/bash: sudo: command not found\n" + ] + } + ], + "source": [ + "!sudo apt-get install git-lfs" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Configuration saved in /tmp/tmp_r5qc71i/config.json\n", + "Model weights saved in /tmp/tmp_r5qc71i/pytorch_model.bin\n", + "Uploading the following files to ssunbear/klue_roberta_large_finetuned_korquad_v1: pytorch_model.bin,config.json\n", + "pytorch_model.bin: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 1.43G/1.43G [00:48<00:00, 29.1MB/s] \n", + "tokenizer config file saved in /tmp/tmpxrg7mu12/tokenizer_config.json\n", + "Special tokens file saved in /tmp/tmpxrg7mu12/special_tokens_map.json\n", + "Uploading the following files to ssunbear/klue_roberta_large_finetuned_korquad_v1: tokenizer_config.json,vocab.txt,tokenizer.json,special_tokens_map.json\n" + ] + }, + { + "data": { + "text/plain": [ + "CommitInfo(commit_url='https://huggingface.co/ssunbear/klue_roberta_large_finetuned_korquad_v1/commit/bd3c260ee793ce46ab444055515bf02be74f2a80', commit_message='Upload tokenizer', commit_description='', oid='bd3c260ee793ce46ab444055515bf02be74f2a80', pr_url=None, repo_url=RepoUrl('https://huggingface.co/ssunbear/klue_roberta_large_finetuned_korquad_v1', endpoint='https://huggingface.co', repo_type='model', repo_id='ssunbear/klue_roberta_large_finetuned_korquad_v1'), pr_revision=None, pr_num=None)" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + }, + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mํ˜„์žฌ ์…€ ๋˜๋Š” ์ด์ „ ์…€์—์„œ ์ฝ”๋“œ๋ฅผ ์‹คํ–‰ํ•˜๋Š” ๋™์•ˆ Kernel์ด ์ถฉ๋Œํ–ˆ์Šต๋‹ˆ๋‹ค. \n", + "\u001b[1;31m์…€์˜ ์ฝ”๋“œ๋ฅผ ๊ฒ€ํ† ํ•˜์—ฌ ๊ฐ€๋Šฅํ•œ ์˜ค๋ฅ˜ ์›์ธ์„ ์‹๋ณ„ํ•˜์„ธ์š”. \n", + "\u001b[1;31m์ž์„ธํ•œ ๋‚ด์šฉ์„ ๋ณด๋ ค๋ฉด ์—ฌ๊ธฐ๋ฅผ ํด๋ฆญํ•˜์„ธ์š”. \n", + "\u001b[1;31m์ž์„ธํ•œ ๋‚ด์šฉ์€ Jupyter ๋กœ๊ทธ๋ฅผ ์ฐธ์กฐํ•˜์„ธ์š”." + ] + } + ], + "source": [ + "from transformers import AutoModel, AutoTokenizer\n", + "\n", + "# Huggingface Access Token\n", + "ACCESS_TOKEN = \"\"  # ํ† ํฐ์•„์ด๋””๋ฅผ ์ž…๋ ฅํ•˜์‹œ๋ฉด ๋ฉ๋‹ˆ๋‹ค\n", + "\n", + "# Upload to Huggingface\n", + "model.push_to_hub('klue_roberta_large_finetuned_korquad_v1', use_temp_dir=True, use_auth_token=ACCESS_TOKEN)\n", + "tokenizer.push_to_hub('klue_roberta_large_finetuned_korquad_v1', use_temp_dir=True, use_auth_token=ACCESS_TOKEN)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Method 2\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- Method2. 
Method1์— mrc_train ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ํ•œ๋ฒˆ ๋” fine-tuning(๋ชจ๋ธ ์žฌํ˜ธ์ถœ) -> ssunbear/klue_roberta_large_finetuned_korquad_v2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pre-trained ๋ชจ๋ธ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/site-packages/huggingface_hub/file_download.py:1142: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n", + "loading configuration file config.json from cache at /data/ephemeral/home/.cache/huggingface/hub/models--ssunbear--klue_roberta_large_finetuned_korquad_v1/snapshots/0ebea4e740f7702b26666664e61deaca4f7cb0dc/config.json\n", + "Model config RobertaConfig {\n", + " \"_name_or_path\": \"ssunbear/klue_roberta_large_finetuned_korquad_v1\",\n", + " \"architectures\": [\n", + " \"RobertaForQuestionAnswering\"\n", + " ],\n", + " \"attention_probs_dropout_prob\": 0.1,\n", + " \"bos_token_id\": 0,\n", + " \"classifier_dropout\": null,\n", + " \"eos_token_id\": 2,\n", + " \"gradient_checkpointing\": false,\n", + " \"hidden_act\": \"gelu\",\n", + " \"hidden_dropout_prob\": 0.1,\n", + " \"hidden_size\": 1024,\n", + " \"initializer_range\": 0.02,\n", + " \"intermediate_size\": 4096,\n", + " \"layer_norm_eps\": 1e-05,\n", + " \"max_position_embeddings\": 514,\n", + " \"model_type\": \"roberta\",\n", + " \"num_attention_heads\": 16,\n", + " \"num_hidden_layers\": 24,\n", + " \"pad_token_id\": 1,\n", + " \"position_embedding_type\": \"absolute\",\n", + " \"tokenizer_class\": \"BertTokenizer\",\n", + " \"torch_dtype\": \"float32\",\n", + " \"transformers_version\": \"4.24.0\",\n", + " \"type_vocab_size\": 1,\n", + " \"use_cache\": true,\n", + " \"vocab_size\": 32000\n", + "}\n", + "\n", + "loading file vocab.txt from cache at /data/ephemeral/home/.cache/huggingface/hub/models--ssunbear--klue_roberta_large_finetuned_korquad_v1/snapshots/0ebea4e740f7702b26666664e61deaca4f7cb0dc/vocab.txt\n", + "loading file tokenizer.json from cache at /data/ephemeral/home/.cache/huggingface/hub/models--ssunbear--klue_roberta_large_finetuned_korquad_v1/snapshots/0ebea4e740f7702b26666664e61deaca4f7cb0dc/tokenizer.json\n", + "loading file added_tokens.json from cache at None\n", + "loading file special_tokens_map.json from cache at /data/ephemeral/home/.cache/huggingface/hub/models--ssunbear--klue_roberta_large_finetuned_korquad_v1/snapshots/0ebea4e740f7702b26666664e61deaca4f7cb0dc/special_tokens_map.json\n", + "loading file tokenizer_config.json from cache at /data/ephemeral/home/.cache/huggingface/hub/models--ssunbear--klue_roberta_large_finetuned_korquad_v1/snapshots/0ebea4e740f7702b26666664e61deaca4f7cb0dc/tokenizer_config.json\n", + "loading weights file pytorch_model.bin from cache at /data/ephemeral/home/.cache/huggingface/hub/models--ssunbear--klue_roberta_large_finetuned_korquad_v1/snapshots/0ebea4e740f7702b26666664e61deaca4f7cb0dc/pytorch_model.bin\n", + "All model checkpoint weights were used when initializing RobertaForQuestionAnswering.\n", + "\n", + "All the weights of RobertaForQuestionAnswering were initialized from the model checkpoint at ssunbear/klue_roberta_large_finetuned_korquad_v1.\n", + "If your task is similar to the task the model of the checkpoint was trained on, you can already use RobertaForQuestionAnswering for 
predictions without further training.\n" + ] + } + ], + "source": [ + "from transformers import (\n", + " AutoConfig,\n", + " AutoModelForQuestionAnswering,\n", + " AutoTokenizer\n", + ")\n", + "\n", + "model_name = \"ssunbear/klue_roberta_large_finetuned_korquad_v1\"\n", + "\n", + "config = AutoConfig.from_pretrained(\n", + " model_name\n", + ")\n", + "tokenizer = AutoTokenizer.from_pretrained(\n", + " model_name,\n", + " use_fast=True\n", + ")\n", + "model = AutoModelForQuestionAnswering.from_pretrained(\n", + " model_name,\n", + " config=config)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "vscode": { + "languageId": "bat" + } + }, + "source": [ + "## Train ๋ฐ์ดํ„ฐ์…‹ ์ „์ฒ˜๋ฆฌ" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [], + "source": [ + "max_seq_length = 512 # ์งˆ๋ฌธ๊ณผ ์ปจํ…์ŠคํŠธ, special token์„ ํ•ฉํ•œ ๋ฌธ์ž์—ด์˜ ์ตœ๋Œ€ ๊ธธ์ด (์ผ์ • ๊ฐœ์ˆ˜๊ฐ€ ๋„˜์–ด๊ฐ€์ง€ ์•Š๋„๋ก!)\n", + "pad_to_max_length = False\n", + "doc_stride = 128 # ์ปจํ…์ŠคํŠธ๊ฐ€ ๋„ˆ๋ฌด ๊ธธ์–ด์„œ ๋‚˜๋ˆด์„ ๋•Œ ์˜ค๋ฒ„๋žฉ๋˜๋Š” ์‹œํ€€์Šค ๊ธธ์ด, ๋ฌธ์„œ 2๊ฐœ๋กœ ์ชผ๊ฐœ๊ณ , 128๊ฐœ ์‹œํ€€์Šค๊ฐ€ ๊ฒน์น˜๋„๋ก\n", + "preprocessing_num_workers = None\n", + "batch_size = 16\n", + "num_train_epochs = 4\n", + "n_best_size = 20\n", + "max_answer_length = 30\n" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [], + "source": [ + "def prepare_train_features(examples): # examples: ๋ฐ์ดํ„ฐ์…‹ row..\n", + " # ์ฃผ์–ด์ง„ ํ…์ŠคํŠธ๋ฅผ ํ† ํฌ๋‚˜์ด์ง• ํ•œ๋‹ค. ์ด ๋•Œ ํ…์ŠคํŠธ์˜ ๊ธธ์ด๊ฐ€ max_seq_length๋ฅผ ๋„˜์œผ๋ฉด stride๋งŒํผ ์Šฌ๋ผ์ด๋”ฉํ•˜๋ฉฐ ์—ฌ๋Ÿฌ ๊ฐœ๋กœ ์ชผ๊ฐฌ.\n", + " # ์ฆ‰, ํ•˜๋‚˜์˜ example์—์„œ ์ผ๋ถ€๋ถ„์ด ๊ฒน์น˜๋Š” ์—ฌ๋Ÿฌ sequence(feature)๊ฐ€ ์ƒ๊ธธ ์ˆ˜ ์žˆ์Œ.\n", + " \n", + " question_column_name = \"question\" if \"question\" in column_names else column_names[0]\n", + " context_column_name = \"context\" if \"context\" in column_names else column_names[1]\n", + " answer_column_name = \"answers\" if \"answers\" in column_names else column_names[2]\n", + "\n", + " # Padding์— ๋Œ€ํ•œ ์˜ต์…˜์„ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค.\n", + " tokenized_examples = tokenizer(\n", + " examples[\"question\"],\n", + " examples[\"context\"],\n", + " truncation=\"only_second\", # max_seq_length๊นŒ์ง€ truncateํ•œ๋‹ค. 
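\n", + "        # (๋ณด์ถฉ) v1 ๋ชจ๋ธ์„ mrc_train์œผ๋กœ ํ•œ ๋ฒˆ ๋” fine-tuningํ•˜๋Š” Method 2์—์„œ๋„ ๋™์ผํ•œ ์ „์ฒ˜๋ฆฌ๋ฅผ ์‚ฌ์šฉ.\n", + "        # question(์ฒซ๋ฒˆ์งธ ํŒŒํŠธ)์€ ๊ทธ๋Œ€๋กœ ๋‘๊ณ  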
pair์˜ ๋‘๋ฒˆ์งธ ํŒŒํŠธ(context)๋งŒ ์ž˜๋ผ๋ƒ„.\n", + " max_length=max_seq_length,\n", + " stride=doc_stride,\n", + " return_overflowing_tokens=True, # ๊ธธ์ด๋ฅผ ๋„˜์–ด๊ฐ€๋Š” ํ† ํฐ๋“ค์„ ๋ฐ˜ํ™˜ํ•  ๊ฒƒ์ธ์ง€\n", + " return_offsets_mapping=True, # ๊ฐ ํ† ํฐ์— ๋Œ€ํ•ด (char_start, char_end) ์ •๋ณด๋ฅผ ๋ฐ˜ํ™˜ํ•œ ๊ฒƒ์ธ์ง€\n", + " padding=\"max_length\", return_token_type_ids=False\n", + " )\n", + "\n", + " # example ํ•˜๋‚˜๊ฐ€ ์—ฌ๋Ÿฌ sequence์— ๋Œ€์‘ํ•˜๋Š” ๊ฒฝ์šฐ๋ฅผ ์œ„ํ•ด ๋งคํ•‘์ด ํ•„์š”ํ•จ.\n", + " overflow_to_sample_mapping = tokenized_examples.pop(\"overflow_to_sample_mapping\")\n", + " # offset_mappings์œผ๋กœ ํ† ํฐ์ด ์›๋ณธ context ๋‚ด ๋ช‡๋ฒˆ์งธ ๊ธ€์ž๋ถ€ํ„ฐ ๋ช‡๋ฒˆ์งธ ๊ธ€์ž๊นŒ์ง€ ํ•ด๋‹นํ•˜๋Š”์ง€ ์•Œ ์ˆ˜ ์žˆ์Œ.\n", + " offset_mapping = tokenized_examples.pop(\"offset_mapping\")\n", + "\n", + " # ์ •๋‹ต์ง€๋ฅผ ๋งŒ๋“ค๊ธฐ ์œ„ํ•œ ๋ฆฌ์ŠคํŠธ\n", + " tokenized_examples[\"start_positions\"] = []\n", + " tokenized_examples[\"end_positions\"] = []\n", + "\n", + " for i, offsets in enumerate(offset_mapping):\n", + " input_ids = tokenized_examples[\"input_ids\"][i]\n", + " cls_index = input_ids.index(tokenizer.cls_token_id)\n", + "\n", + " # ํ•ด๋‹น example์— ํ•ด๋‹นํ•˜๋Š” sequence๋ฅผ ์ฐพ์Œ.\n", + " sequence_ids = tokenized_examples.sequence_ids(i)\n", + "\n", + " # sequence๊ฐ€ ์†ํ•˜๋Š” example์„ ์ฐพ๋Š”๋‹ค.\n", + " example_index = overflow_to_sample_mapping[i]\n", + " answers = examples[\"answers\"][example_index]\n", + "\n", + " # ํ…์ŠคํŠธ์—์„œ answer์˜ ์‹œ์ž‘์ , ๋์ \n", + " answer_start_offset = answers[\"answer_start\"][0]\n", + " answer_end_offset = answer_start_offset + len(answers[\"text\"][0])\n", + "\n", + " # ํ…์ŠคํŠธ์—์„œ ํ˜„์žฌ span์˜ ์‹œ์ž‘ ํ† ํฐ ์ธ๋ฑ์Šค\n", + " token_start_index = 0\n", + " while sequence_ids[token_start_index] != 1:\n", + " token_start_index += 1\n", + "\n", + " # ํ…์ŠคํŠธ์—์„œ ํ˜„์žฌ span ๋ ํ† ํฐ ์ธ๋ฑ์Šค\n", + " token_end_index = len(input_ids) - 1\n", + " while sequence_ids[token_end_index] != 1:\n", + " token_end_index -= 1\n", + "\n", + " # answer๊ฐ€ ํ˜„์žฌ span์„ ๋ฒ—์–ด๋‚ฌ๋Š”์ง€ ์ฒดํฌ\n", + " if not (offsets[token_start_index][0] <= answer_start_offset and offsets[token_end_index][1] >= answer_end_offset):\n", + " tokenized_examples[\"start_positions\"].append(cls_index)\n", + " tokenized_examples[\"end_positions\"].append(cls_index)\n", + " else:\n", + " # token_start_index์™€ token_end_index๋ฅผ answer์˜ ์‹œ์ž‘์ ๊ณผ ๋์ ์œผ๋กœ ์˜ฎ๊น€\n", + " while token_start_index < len(offsets) and offsets[token_start_index][0] <= answer_start_offset:\n", + " token_start_index += 1\n", + " tokenized_examples[\"start_positions\"].append(token_start_index - 1)\n", + " while offsets[token_end_index][1] >= answer_end_offset:\n", + " token_end_index -= 1\n", + " tokenized_examples[\"end_positions\"].append(token_end_index + 1)\n", + "\n", + " return tokenized_examples" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Map: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 3952/3952 [00:02<00:00, 1426.98 examples/s]\n" + ] + } + ], + "source": [ + "column_names = mrc_train_dataset.column_names\n", + "train_dataset = mrc_train_dataset.map(\n", + " prepare_train_features,\n", + " batched=True,\n", + " num_proc=preprocessing_num_workers,\n", + " remove_columns=column_names,\n", + " load_from_cache_file=True,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + 
"Dataset({\n", + " features: ['title', 'context', 'question', 'id', 'answers', 'document_id', '__index_level_0__'],\n", + " num_rows: 240\n", + "})" + ] + }, + "execution_count": 82, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from datasets import load_from_disk\n", + "\n", + "# ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ\n", + "mrc_validation_dataset_path = \"/data/ephemeral/home/level2-mrc-nlp-15/data/train_dataset/validation\" # ์‹ค์ œ ๋ฐ์ดํ„ฐ์…‹ ๊ฒฝ๋กœ๋กœ ์ˆ˜์ •\n", + "mrc_validation_dataset = load_from_disk(mrc_validation_dataset_path)\n", + "mrc_validation_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "metadata": {}, + "outputs": [], + "source": [ + "# korquad ๋ฐ์ดํ„ฐ์…‹์ด๋ž‘ ํ˜•์‹ ๋˜‘๊ฐ™์ด ๋งŒ๋“ค์–ด์ฃผ๊ธฐ\n", + "id_list0 = []\n", + "title_list0 = []\n", + "context_list0 = []\n", + "question_list0 = []\n", + "answers_list0 = []" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 84, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "id_list0" + ] + }, + { + "cell_type": "code", + "execution_count": 85, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "for index, row in pd.DataFrame(mrc_validation_dataset).iterrows():\n", + " id_list0.append(row['id'])\n", + " title_list0.append(str(row['title']))\n", + " context_list0.append(str(row['context']))\n", + " question_list0.append(str(row['question']))\n", + " answers_list0.append(row['answers'])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "metadata": {}, + "outputs": [], + "source": [ + "mrc_validation_dataset = {\n", + " \"id\" : id_list0,\n", + " \"title\" : title_list0,\n", + " \"context\" : context_list0,\n", + " \"question\" : question_list0,\n", + " \"answers\" : answers_list0,}" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['id', 'title', 'context', 'question', 'answers'],\n", + " num_rows: 240\n", + "})" + ] + }, + "execution_count": 91, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from datasets import Dataset\n", + "\n", + "mrc_validation_dataset= Dataset.from_dict(mrc_validation_dataset)\n", + "mrc_validation_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [], + "source": [ + "def prepare_validation_features(examples):\n", + " tokenized_examples = tokenizer(\n", + " examples['question'],\n", + " examples['context'],\n", + " truncation=\"only_second\",\n", + " max_length=max_seq_length,\n", + " stride=doc_stride,\n", + " return_overflowing_tokens=True,\n", + " return_offsets_mapping=True,\n", + " padding=\"max_length\",\n", + " )\n", + "\n", + " sample_mapping = tokenized_examples.pop(\"overflow_to_sample_mapping\")\n", + "\n", + " tokenized_examples[\"example_id\"] = []\n", + "\n", + " for i in range(len(tokenized_examples[\"input_ids\"])):\n", + " sequence_ids = tokenized_examples.sequence_ids(i)\n", + " context_index = 1\n", + "\n", + " sample_index = sample_mapping[i]\n", + " tokenized_examples[\"example_id\"].append(examples[\"id\"][sample_index])\n", + "\n", + " tokenized_examples[\"offset_mapping\"][i] = [\n", + " (o if sequence_ids[k] == context_index else None)\n", + " for k, o in enumerate(tokenized_examples[\"offset_mapping\"][i])\n", + " ]\n", + "\n", + " return tokenized_examples" + ] + }, + { + "cell_type": 
"code", + "execution_count": 92, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Map: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 240/240 [00:00<00:00, 846.63 examples/s]\n" + ] + } + ], + "source": [ + "eval_dataset = mrc_validation_dataset.map(\n", + " prepare_validation_features,\n", + " batched=True,\n", + " num_proc=preprocessing_num_workers,\n", + " remove_columns=column_names,\n", + " load_from_cache_file=True,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Question Answering Class ์ •์˜" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "metadata": {}, + "outputs": [], + "source": [ + "# default_data_collator: ์—ฌ๋Ÿฌ๊ฐœ example๋“ค์„ collatorํ•ด์ฃผ๋Š” ์—ญํ• ,\n", + "# TrainingArguments : ํ•œ๋ฒˆ์— training arguments๋“ค์„ ํ•ฉ์ณ์„œ ์ฃผ๋Š”..!\n", + "from transformers import default_data_collator, TrainingArguments, EvalPrediction" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "metadata": {}, + "outputs": [], + "source": [ + "# coding=utf-8\n", + "# Copyright 2020 The HuggingFace Team All rights reserved.\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License.\n", + "\"\"\"\n", + "Question-Answering task์™€ ๊ด€๋ จ๋œ 'Trainer'์˜ subclass ์ฝ”๋“œ ์ž…๋‹ˆ๋‹ค.\n", + "\"\"\"\n", + "\n", + "from transformers import Trainer, is_datasets_available, is_torch_tpu_available\n", + "from transformers.trainer_utils import PredictionOutput\n", + "\n", + "if is_datasets_available():\n", + " import datasets\n", + "\n", + "# Huggingface์˜ Trainer๋ฅผ ์ƒ์†๋ฐ›์•„ QuestionAnswering์„ ์œ„ํ•œ Trainer๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.\n", + "class QuestionAnsweringTrainer(Trainer):\n", + " def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs):\n", + " super().__init__(*args, **kwargs)\n", + " self.eval_examples = eval_examples\n", + " self.post_process_function = post_process_function\n", + "\n", + " def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None):\n", + " eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset\n", + " eval_dataloader = self.get_eval_dataloader(eval_dataset)\n", + " eval_examples = self.eval_examples if eval_examples is None else eval_examples\n", + "\n", + " # ์ผ์‹œ์ ์œผ๋กœ metric computation๋ฅผ ๋ถˆ๊ฐ€๋Šฅํ•˜๊ฒŒ ํ•œ ์ƒํƒœ์ด๋ฉฐ, ํ•ด๋‹น ์ฝ”๋“œ์—์„œ๋Š” loop ๋‚ด์—์„œ metric ๊ณ„์‚ฐ์„ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค.\n", + " compute_metrics = self.compute_metrics\n", + " self.compute_metrics = None\n", + " try:\n", + " output = self.prediction_loop(\n", + " eval_dataloader,\n", + " description=\"Evaluation\",\n", + " # metric์ด ์—†์œผ๋ฉด ์˜ˆ์ธก๊ฐ’์„ ๋ชจ์œผ๋Š” ์ด์œ ๊ฐ€ ์—†์œผ๋ฏ€๋กœ ์•„๋ž˜์˜ ์ฝ”๋“œ๋ฅผ ๋”ฐ๋ฅด๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.\n", + " # self.args.prediction_loss_only\n", + " prediction_loss_only=True if compute_metrics is None else None,\n", + " ignore_keys=ignore_keys,\n", + " )\n", + " finally:\n", + " self.compute_metrics = compute_metrics\n", + "\n", + " if 
isinstance(eval_dataset, datasets.Dataset):\n", + " eval_dataset.set_format(\n", + " type=eval_dataset.format[\"type\"],\n", + " columns=list(eval_dataset.features.keys()),\n", + " )\n", + "\n", + " if self.post_process_function is not None and self.compute_metrics is not None:\n", + " eval_preds = self.post_process_function(\n", + " eval_examples, eval_dataset, output.predictions, self.args\n", + " )\n", + " metrics = self.compute_metrics(eval_preds)\n", + "\n", + " self.log(metrics)\n", + " else:\n", + " metrics = {}\n", + "\n", + " self.control = self.callback_handler.on_evaluate(\n", + " self.args, self.state, self.control, metrics\n", + " )\n", + " return metrics\n", + "\n", + " def predict(self, test_dataset, test_examples, ignore_keys=None):\n", + " test_dataloader = self.get_test_dataloader(test_dataset)\n", + "\n", + " # ์ผ์‹œ์ ์œผ๋กœ metric computation๋ฅผ ๋ถˆ๊ฐ€๋Šฅํ•˜๊ฒŒ ํ•œ ์ƒํƒœ์ด๋ฉฐ, ํ•ด๋‹น ์ฝ”๋“œ์—์„œ๋Š” loop ๋‚ด์—์„œ metric ๊ณ„์‚ฐ์„ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค.\n", + " # evaluate ํ•จ์ˆ˜์™€ ๋™์ผํ•˜๊ฒŒ ๊ตฌ์„ฑ๋˜์–ด์žˆ์Šต๋‹ˆ๋‹ค\n", + " compute_metrics = self.compute_metrics\n", + " self.compute_metrics = None\n", + " try:\n", + " output = self.prediction_loop(\n", + " test_dataloader,\n", + " description=\"Evaluation\",\n", + " # metric์ด ์—†์œผ๋ฉด ์˜ˆ์ธก๊ฐ’์„ ๋ชจ์œผ๋Š” ์ด์œ ๊ฐ€ ์—†์œผ๋ฏ€๋กœ ์•„๋ž˜์˜ ์ฝ”๋“œ๋ฅผ ๋”ฐ๋ฅด๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.\n", + " # self.args.prediction_loss_only\n", + " prediction_loss_only=True if compute_metrics is None else None,\n", + " ignore_keys=ignore_keys,\n", + " )\n", + " finally:\n", + " self.compute_metrics = compute_metrics\n", + "\n", + " if self.post_process_function is None or self.compute_metrics is None:\n", + " return output\n", + "\n", + " if isinstance(test_dataset, datasets.Dataset):\n", + " test_dataset.set_format(\n", + " type=test_dataset.format[\"type\"],\n", + " columns=list(test_dataset.features.keys()),\n", + " )\n", + "\n", + " predictions = self.post_process_function(\n", + " test_examples, test_dataset, output.predictions, self.args\n", + " )\n", + " return predictions\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## ํ›„์ฒ˜๋ฆฌ ํด๋ž˜์Šค ์ •์˜" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "metadata": {}, + "outputs": [], + "source": [ + "# coding=utf-8\n", + "# Copyright 2020 The HuggingFace Team All rights reserved.\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the 'License');\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an 'AS IS' BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License.\n", + "\"\"\"\n", + "Pre-processing\n", + "Post-processing utilities for question answering.\n", + "\"\"\"\n", + "import collections\n", + "import json\n", + "import logging\n", + "import os\n", + "import random\n", + "from typing import Any, Optional, Tuple\n", + "\n", + "import numpy as np\n", + "import torch\n", + "from arguments import DataTrainingArguments, ModelArguments\n", + "from datasets import DatasetDict\n", + "from tqdm.auto import tqdm\n", + "from transformers import PreTrainedTokenizerFast, TrainingArguments, is_torch_available\n", + "from 
transformers.trainer_utils import get_last_checkpoint\n", + "\n", + "#from utils.datetime_helper import get_seoul_datetime_str\n", + "\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "\n", + "def set_seed(seed: int = 2024):\n", + " \"\"\"\n", + " seed ๊ณ ์ •ํ•˜๋Š” ํ•จ์ˆ˜ (random, numpy, torch)\n", + "\n", + " Args:\n", + " seed (:obj:`int`): The seed to set.\n", + " \"\"\"\n", + " random.seed(seed)\n", + " np.random.seed(seed)\n", + " if is_torch_available():\n", + " torch.manual_seed(seed)\n", + " torch.cuda.manual_seed(seed)\n", + " torch.cuda.manual_seed_all(seed) # if use multi-GPU\n", + " torch.backends.cudnn.deterministic = True\n", + " torch.backends.cudnn.benchmark = False\n", + "\n", + "\n", + "def postprocess_qa_predictions(\n", + " examples,\n", + " features,\n", + " predictions: Tuple[np.ndarray, np.ndarray],\n", + " version_2_with_negative: bool = False,\n", + " n_best_size: int = 20,\n", + " max_answer_length: int = 30,\n", + " null_score_diff_threshold: float = 0.0,\n", + " output_dir: Optional[str] = None,\n", + " prefix: Optional[str] = None,\n", + " is_world_process_zero: bool = True,\n", + "):\n", + " \"\"\"\n", + " Post-processes : qa model์˜ prediction ๊ฐ’์„ ํ›„์ฒ˜๋ฆฌํ•˜๋Š” ํ•จ์ˆ˜\n", + " ๋ชจ๋ธ์€ start logit๊ณผ end logit์„ ๋ฐ˜ํ™˜ํ•˜๊ธฐ ๋•Œ๋ฌธ์—, ์ด๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ original text๋กœ ๋ณ€๊ฒฝํ•˜๋Š” ํ›„์ฒ˜๋ฆฌ๊ฐ€ ํ•„์š”ํ•จ\n", + "\n", + " Args:\n", + " examples: ์ „์ฒ˜๋ฆฌ ๋˜์ง€ ์•Š์€ ๋ฐ์ดํ„ฐ์…‹ (see the main script for more information).\n", + " features: ์ „์ฒ˜๋ฆฌ๊ฐ€ ์ง„ํ–‰๋œ ๋ฐ์ดํ„ฐ์…‹ (see the main script for more information).\n", + " predictions (:obj:`Tuple[np.ndarray, np.ndarray]`):\n", + " ๋ชจ๋ธ์˜ ์˜ˆ์ธก๊ฐ’ :start logits๊ณผ the end logits์„ ๋‚˜ํƒ€๋‚ด๋Š” two arrays ์ฒซ๋ฒˆ์งธ ์ฐจ์›์€ :obj:`features`์˜ element์™€ ๊ฐฏ์ˆ˜๊ฐ€ ๋งž์•„์•ผํ•จ.\n", + " version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`):\n", + " ์ •๋‹ต์ด ์—†๋Š” ๋ฐ์ดํ„ฐ์…‹์ด ํฌํ•จ๋˜์–ด์žˆ๋Š”์ง€ ์—ฌ๋ถ€๋ฅผ ๋‚˜ํƒ€๋ƒ„\n", + " n_best_size (:obj:`int`, `optional`, defaults to 20):\n", + " ๋‹ต๋ณ€์„ ์ฐพ์„ ๋•Œ ์ƒ์„ฑํ•  n-best prediction ์ด ๊ฐœ์ˆ˜\n", + " max_answer_length (:obj:`int`, `optional`, defaults to 30):\n", + " ์ƒ์„ฑํ•  ์ˆ˜ ์žˆ๋Š” ๋‹ต๋ณ€์˜ ์ตœ๋Œ€ ๊ธธ์ด\n", + " null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0):\n", + " null ๋‹ต๋ณ€์„ ์„ ํƒํ•˜๋Š” ๋ฐ ์‚ฌ์šฉ๋˜๋Š” threshold\n", + " : if the best answer has a score that is less than the score of\n", + " the null answer minus this threshold, the null answer is selected for this example (note that the score of\n", + " the null answer for an example giving several features is the minimum of the scores for the null answer on\n", + " each feature: all features must be aligned on the fact they `want` to predict a null answer).\n", + " Only useful when :obj:`version_2_with_negative` is :obj:`True`.\n", + " output_dir (:obj:`str`, `optional`):\n", + " ์•„๋ž˜์˜ ๊ฐ’์ด ์ €์žฅ๋˜๋Š” ๊ฒฝ๋กœ\n", + " dictionary : predictions, n_best predictions (with their scores and logits) if:obj:`version_2_with_negative=True`,\n", + " dictionary : the scores differences between best and null answers\n", + " prefix (:obj:`str`, `optional`):\n", + " dictionary์— `prefix`๊ฐ€ ํฌํ•จ๋˜์–ด ์ €์žฅ๋จ\n", + " is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):\n", + " ์ด ํ”„๋กœ์„ธ์Šค๊ฐ€ main process์ธ์ง€ ์—ฌ๋ถ€(logging/save๋ฅผ ์ˆ˜ํ–‰ํ•ด์•ผ ํ•˜๋Š”์ง€ ์—ฌ๋ถ€๋ฅผ ๊ฒฐ์ •ํ•˜๋Š” ๋ฐ ์‚ฌ์šฉ๋จ)\n", + " \"\"\"\n", + " assert (\n", + " len(predictions) == 2\n", + " ), \"`predictions` should be a 
tuple with two elements (start_logits, end_logits).\"\n", + " all_start_logits, all_end_logits = predictions\n", + "\n", + " assert len(predictions[0]) == len(\n", + " features\n", + " ), f\"Got {len(predictions[0])} predictions and {len(features)} features.\"\n", + "\n", + " # example๊ณผ mapping๋˜๋Š” feature ์ƒ์„ฑ\n", + " example_id_to_index = {k: i for i, k in enumerate(examples[\"id\"])}\n", + " features_per_example = collections.defaultdict(list)\n", + " for i, feature in enumerate(features):\n", + " features_per_example[example_id_to_index[feature[\"example_id\"]]].append(i)\n", + "\n", + " # prediction, nbest์— ํ•ด๋‹นํ•˜๋Š” OrderedDict ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.\n", + " all_predictions = collections.OrderedDict()\n", + " all_nbest_json = collections.OrderedDict()\n", + " if version_2_with_negative:\n", + " scores_diff_json = collections.OrderedDict()\n", + "\n", + " # Logging.\n", + " logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN)\n", + " logger.info(\n", + " f\"Post-processing {len(examples)} example predictions split into {len(features)} features.\"\n", + " )\n", + "\n", + " # ์ „์ฒด example๋“ค์— ๋Œ€ํ•œ main Loop\n", + " for example_index, example in enumerate(tqdm(examples)):\n", + " # ํ•ด๋‹นํ•˜๋Š” ํ˜„์žฌ example index\n", + " feature_indices = features_per_example[example_index]\n", + "\n", + " min_null_prediction = None\n", + " prelim_predictions = []\n", + "\n", + " # ํ˜„์žฌ example์— ๋Œ€ํ•œ ๋ชจ๋“  feature ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.\n", + " for feature_index in feature_indices:\n", + " # ๊ฐ featureure์— ๋Œ€ํ•œ ๋ชจ๋“  prediction์„ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.\n", + " start_logits = all_start_logits[feature_index]\n", + " end_logits = all_end_logits[feature_index]\n", + " # logit๊ณผ original context์˜ logit์„ mappingํ•ฉ๋‹ˆ๋‹ค.\n", + " offset_mapping = features[feature_index][\"offset_mapping\"]\n", + " # Optional : `token_is_max_context`, ์ œ๊ณต๋˜๋Š” ๊ฒฝ์šฐ ํ˜„์žฌ ๊ธฐ๋Šฅ์—์„œ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋Š” max context๊ฐ€ ์—†๋Š” answer๋ฅผ ์ œ๊ฑฐํ•ฉ๋‹ˆ๋‹ค\n", + " token_is_max_context = features[feature_index].get(\n", + " \"token_is_max_context\", None\n", + " )\n", + "\n", + " # minimum null prediction์„ ์—…๋ฐ์ดํŠธ ํ•ฉ๋‹ˆ๋‹ค.\n", + " feature_null_score = start_logits[0] + end_logits[0]\n", + " if (\n", + " min_null_prediction is None\n", + " or min_null_prediction[\"score\"] > feature_null_score\n", + " ):\n", + " min_null_prediction = {\n", + " \"offsets\": (0, 0),\n", + " \"score\": feature_null_score,\n", + " \"start_logit\": start_logits[0],\n", + " \"end_logit\": end_logits[0],\n", + " }\n", + "\n", + " # `n_best_size`๋ณด๋‹ค ํฐ start and end logits์„ ์‚ดํŽด๋ด…๋‹ˆ๋‹ค.\n", + " start_indexes = np.argsort(start_logits)[\n", + " -1 : -n_best_size - 1 : -1\n", + " ].tolist()\n", + "\n", + " end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()\n", + "\n", + " for start_index in start_indexes:\n", + " for end_index in end_indexes:\n", + " # out-of-scope answers๋Š” ๊ณ ๋ คํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.\n", + " if (\n", + " start_index >= len(offset_mapping)\n", + " or end_index >= len(offset_mapping)\n", + " or offset_mapping[start_index] is None\n", + " or offset_mapping[end_index] is None\n", + " ):\n", + " continue\n", + " # ๊ธธ์ด๊ฐ€ < 0 ๋˜๋Š” > max_answer_length์ธ answer๋„ ๊ณ ๋ คํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.\n", + " if (\n", + " end_index < start_index\n", + " or end_index - start_index + 1 > max_answer_length\n", + " ):\n", + " continue\n", + " # ์ตœ๋Œ€ context๊ฐ€ ์—†๋Š” answer๋„ ๊ณ ๋ คํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.\n", + " if (\n", + " token_is_max_context is not None\n", + " and 
not token_is_max_context.get(str(start_index), False)\n", + " ):\n", + " continue\n", + " prelim_predictions.append(\n", + " {\n", + " \"offsets\": (\n", + " offset_mapping[start_index][0],\n", + " offset_mapping[end_index][1],\n", + " ),\n", + " \"score\": start_logits[start_index] + end_logits[end_index],\n", + " \"start_logit\": start_logits[start_index],\n", + " \"end_logit\": end_logits[end_index],\n", + " }\n", + " )\n", + "\n", + " if version_2_with_negative:\n", + " # minimum null prediction์„ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค.\n", + " prelim_predictions.append(min_null_prediction)\n", + " null_score = min_null_prediction[\"score\"]\n", + "\n", + " # ๊ฐ€์žฅ ์ข‹์€ `n_best_size` predictions๋งŒ ์œ ์ง€ํ•ฉ๋‹ˆ๋‹ค.\n", + " predictions = sorted(\n", + " prelim_predictions, key=lambda x: x[\"score\"], reverse=True\n", + " )[:n_best_size]\n", + "\n", + " # ๋‚ฎ์€ ์ ์ˆ˜๋กœ ์ธํ•ด ์ œ๊ฑฐ๋œ ๊ฒฝ์šฐ minimum null prediction์„ ๋‹ค์‹œ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค.\n", + " if version_2_with_negative and not any(\n", + " p[\"offsets\"] == (0, 0) for p in predictions\n", + " ):\n", + " predictions.append(min_null_prediction)\n", + "\n", + " # offset์„ ์‚ฌ์šฉํ•˜์—ฌ original context์—์„œ answer text๋ฅผ ์ˆ˜์ง‘ํ•ฉ๋‹ˆ๋‹ค.\n", + " context = example[\"context\"]\n", + " for pred in predictions:\n", + " offsets = pred.pop(\"offsets\")\n", + " pred[\"text\"] = context[offsets[0] : offsets[1]]\n", + "\n", + " # rare edge case์—๋Š” null์ด ์•„๋‹Œ ์˜ˆ์ธก์ด ํ•˜๋‚˜๋„ ์—†์œผ๋ฉฐ failure๋ฅผ ํ”ผํ•˜๊ธฐ ์œ„ํ•ด fake prediction์„ ๋งŒ๋“ญ๋‹ˆ๋‹ค.\n", + " if len(predictions) == 0 or (\n", + " len(predictions) == 1 and predictions[0][\"text\"] == \"\"\n", + " ):\n", + "\n", + " predictions.insert(\n", + " 0, {\"text\": \"empty\", \"start_logit\": 0.0, \"end_logit\": 0.0, \"score\": 0.0}\n", + " )\n", + "\n", + " # ๋ชจ๋“  ์ ์ˆ˜์˜ ์†Œํ”„ํŠธ๋งฅ์Šค๋ฅผ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค(we do it with numpy to stay independent from torch/tf in this file, using the LogSumExp trick).\n", + " scores = np.array([pred.pop(\"score\") for pred in predictions])\n", + " exp_scores = np.exp(scores - np.max(scores))\n", + " probs = exp_scores / exp_scores.sum()\n", + "\n", + " # ์˜ˆ์ธก๊ฐ’์— ํ™•๋ฅ ์„ ํฌํ•จํ•ฉ๋‹ˆ๋‹ค.\n", + " for prob, pred in zip(probs, predictions):\n", + " pred[\"probability\"] = prob\n", + "\n", + " # best prediction์„ ์„ ํƒํ•ฉ๋‹ˆ๋‹ค.\n", + " if not version_2_with_negative:\n", + " all_predictions[example[\"id\"]] = predictions[0][\"text\"]\n", + " else:\n", + " # else case : ๋จผ์ € ๋น„์–ด ์žˆ์ง€ ์•Š์€ ์ตœ์ƒ์˜ ์˜ˆ์ธก์„ ์ฐพ์•„์•ผ ํ•ฉ๋‹ˆ๋‹ค\n", + " i = 0\n", + " while predictions[i][\"text\"] == \"\":\n", + " i += 1\n", + " best_non_null_pred = predictions[i]\n", + "\n", + " # threshold๋ฅผ ์‚ฌ์šฉํ•ด์„œ null prediction์„ ๋น„๊ตํ•ฉ๋‹ˆ๋‹ค.\n", + " score_diff = (\n", + " null_score\n", + " - best_non_null_pred[\"start_logit\"]\n", + " - best_non_null_pred[\"end_logit\"]\n", + " )\n", + " scores_diff_json[example[\"id\"]] = float(score_diff) # JSON-serializable ๊ฐ€๋Šฅ\n", + " if score_diff > null_score_diff_threshold:\n", + " all_predictions[example[\"id\"]] = \"\"\n", + " else:\n", + " all_predictions[example[\"id\"]] = best_non_null_pred[\"text\"]\n", + "\n", + " # np.float๋ฅผ ๋‹ค์‹œ float๋กœ casting -> `predictions`์€ JSON-serializable ๊ฐ€๋Šฅ\n", + " all_nbest_json[example[\"id\"]] = [\n", + " {\n", + " k: (\n", + " float(v)\n", + " if isinstance(v, (np.float16, np.float32, np.float64))\n", + " else v\n", + " )\n", + " for k, v in pred.items()\n", + " }\n", + " for pred in predictions\n", + " ]\n", + "\n", + " # output_dir์ด ์žˆ์œผ๋ฉด ๋ชจ๋“  dicts๋ฅผ ์ 
€์žฅํ•ฉ๋‹ˆ๋‹ค.\n", + " if output_dir is not None:\n", + " assert os.path.isdir(output_dir), f\"{output_dir} is not a directory.\"\n", + "\n", + " prediction_file = os.path.join(\n", + " output_dir,\n", + " \"predictions.json\" if prefix is None else f\"predictions_{prefix}.json\",\n", + " )\n", + " nbest_file = os.path.join(\n", + " output_dir,\n", + " \"nbest_predictions.json\"\n", + " if prefix is None\n", + " else f\"nbest_predictions_{prefix}.json\",\n", + " )\n", + " if version_2_with_negative:\n", + " null_odds_file = os.path.join(\n", + " output_dir,\n", + " \"null_odds.json\" if prefix is None else f\"null_odds_{prefix}.json\",\n", + " )\n", + "\n", + " logger.info(f\"Saving predictions to {prediction_file}.\")\n", + " with open(prediction_file, \"w\", encoding=\"utf-8\") as writer:\n", + " writer.write(\n", + " json.dumps(all_predictions, indent=4, ensure_ascii=False) + \"\\n\"\n", + " )\n", + " logger.info(f\"Saving nbest_preds to {nbest_file}.\")\n", + " with open(nbest_file, \"w\", encoding=\"utf-8\") as writer:\n", + " writer.write(\n", + " json.dumps(all_nbest_json, indent=4, ensure_ascii=False) + \"\\n\"\n", + " )\n", + " if version_2_with_negative:\n", + " logger.info(f\"Saving null_odds to {null_odds_file}.\")\n", + " with open(null_odds_file, \"w\", encoding=\"utf-8\") as writer:\n", + " writer.write(\n", + " json.dumps(scores_diff_json, indent=4, ensure_ascii=False) + \"\\n\"\n", + " )\n", + "\n", + " return all_predictions\n", + "\n", + "\n", + "def check_no_error(\n", + " data_args: DataTrainingArguments,\n", + " training_args: TrainingArguments,\n", + " datasets: DatasetDict,\n", + " tokenizer,\n", + ") -> Tuple[Any, int]:\n", + "\n", + " # last checkpoint ์ฐพ๊ธฐ.\n", + " last_checkpoint = None\n", + " if (\n", + " os.path.isdir(training_args.output_dir)\n", + " and training_args.do_train\n", + " and not training_args.overwrite_output_dir\n", + " ):\n", + " last_checkpoint = get_last_checkpoint(training_args.output_dir)\n", + " if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:\n", + " raise ValueError(\n", + " f\"Output directory ({training_args.output_dir}) already exists and is not empty. \"\n", + " \"Use --overwrite_output_dir to overcome.\"\n", + " )\n", + " elif last_checkpoint is not None:\n", + " logger.info(\n", + " f\"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change \"\n", + " \"the `--output_dir` or add `--overwrite_output_dir` to train from scratch.\"\n", + " )\n", + "\n", + " # Tokenizer check: ํ•ด๋‹น script๋Š” Fast tokenizer๋ฅผ ํ•„์š”๋กœํ•ฉ๋‹ˆ๋‹ค.\n", + " if not isinstance(tokenizer, PreTrainedTokenizerFast):\n", + " raise ValueError(\n", + " \"This example script only works for models that have a fast tokenizer. Checkout the big table of models \"\n", + " \"at https://huggingface.co/transformers/index.html#bigtable to find the model types that meet this \"\n", + " \"requirement\"\n", + " )\n", + "\n", + " if data_args.max_seq_length > tokenizer.model_max_length:\n", + " logger.warn(\n", + " f\"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the\"\n", + " f\"model ({tokenizer.model_max_length}). 
Using max_seq_length={tokenizer.model_max_length}.\"\n", + " )\n", + " max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)\n", + "\n", + " if \"validation\" not in datasets:\n", + " raise ValueError(\"--do_eval requires a validation dataset\")\n", + " return last_checkpoint, max_seq_length\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### ํ›„์ฒ˜๋ฆฌ ํ•จ์ˆ˜ ์ •์˜" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "metadata": {}, + "outputs": [], + "source": [ + "# ๋ชจ๋ธ์ด ์ดํ•ดํ•˜๋Š” ํ˜•ํƒœ์—์„œ ์‚ฌ๋žŒ์ด ์ดํ•ดํ•˜๋Š” ํ˜•ํƒœ๋กœ ๋‹ต๋ณ€ ๋งค์นญ\n", + "def post_processing_function(examples, features, predictions):\n", + " # Post-processing: we match the start logits and end logits to answers in the original context.\n", + " predictions = postprocess_qa_predictions(\n", + " examples=examples,\n", + " features=features,\n", + " predictions=predictions,\n", + " version_2_with_negative=False,\n", + " n_best_size=n_best_size,\n", + " max_answer_length=max_answer_length,\n", + " null_score_diff_threshold=0.0,\n", + " output_dir=training_args.output_dir,\n", + " is_world_process_zero=trainer.is_world_process_zero(),\n", + " )\n", + "\n", + " # Format the result to the format the metric expects.\n", + " formatted_predictions = [{\"id\": k, \"prediction_text\": v} for k, v in predictions.items()]\n", + " references = [{\"id\": ex[\"id\"], \"answers\": ex[\"answers\"]} for ex in datasets[\"validation\"]]\n", + " return EvalPrediction(predictions=formatted_predictions, label_ids=references)" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "metadata": {}, + "outputs": [], + "source": [ + "def compute_metrics(p: EvalPrediction):\n", + " return metric.compute(predictions=p.predictions, references=p.label_ids)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train 2์ฐจ" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.18.3" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /data/ephemeral/home/level2-mrc-nlp-15/src/wandb/run-20241015_184122-0y11f1su" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run electric-serenity-822 to Weights & Biases (docs)
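The run name `electric-serenity-822` above was auto-generated by W&B; `src/main.py` later in this diff pins an explicit timestamped name instead. A minimal sketch of the same idea, reusing the project/entity from this notebook (the `name` value is hypothetical):

```python
import wandb

# Sketch, not the original cell: an explicit run name is easier to find
# later than auto-generated ones like "electric-serenity-822".
wandb.init(
    project="odqa",                   # project used throughout this repo
    entity="nlp15",                   # team entity used throughout this repo
    name="roberta-large-cnn-ft-01",   # hypothetical run name
)
```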
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/nlp15/odqa" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/nlp15/odqa/runs/0y11f1su" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 98, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import wandb\n", + "\n", + "wandb.login()\n", + "wandb.init(project='odqa', # ์‹คํ—˜๊ธฐ๋ก์„ ๊ด€๋ฆฌํ•œ ํ”„๋กœ์ ํŠธ ์ด๋ฆ„\n", + " entity='nlp15', # ์‚ฌ์šฉ์ž๋ช… ๋˜๋Š” ํŒ€ ์ด๋ฆ„ \n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 99, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "\n", + "#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "torch.cuda.is_available()" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "PyTorch: setting up devices\n", + "The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).\n" + ] + } + ], + "source": [ + "training_args = TrainingArguments(\n", + " output_dir=\"outputs\",\n", + " do_train=True,\n", + " do_eval=True,\n", + " per_device_train_batch_size=batch_size,\n", + " per_device_eval_batch_size=batch_size,\n", + " num_train_epochs=num_train_epochs,\n", + " save_total_limit=1, # ์ €์žฅํ•  ์ฒดํฌํฌ์ธํŠธ์˜ ์ตœ๋Œ€ ์ˆ˜\n", + " #evaluation_strategy=\"steps\",\n", + " #eval_steps=500, # ๋ช‡ ์Šคํ…๋งˆ๋‹ค ํ‰๊ฐ€ํ• ์ง€ ์„ค์ •\n", + " #logging_steps=500, # ๋ช‡ ์Šคํ…๋งˆ๋‹ค ๋กœ๊น…ํ• ์ง€ ์„ค์ •,\n", + " #load_best_model_at_end=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 101, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = QuestionAnsweringTrainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=train_dataset,\n", + " eval_dataset=eval_dataset,\n", + " eval_examples=val_dataset,\n", + " tokenizer=tokenizer,\n", + " data_collator=default_data_collator, # ๋ณดํ†ต default\n", + " post_process_function=post_processing_function, # function์„ input์œผ๋กœ ๋ฐ›์Œ!\n", + " compute_metrics=compute_metrics\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 102, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", + " warnings.warn(\n", + "***** Running training *****\n", + " Num examples = 5769\n", + " Num Epochs = 4\n", + " Instantaneous batch size per device = 16\n", + " Total train batch size (w. 
parallel, distributed & accumulation) = 16\n", + " Gradient Accumulation steps = 1\n", + " Total optimization steps = 1444\n", + " Number of trainable parameters = 335608834\n", + "Automatic Weights & Biases logging enabled, to disable set os.environ[\"WANDB_DISABLED\"] = \"true\"\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [1444/1444 35:59, Epoch 4/4]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining Loss
5000.790100
10000.258400

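The Step / Training Loss table above is the Trainer's live progress widget; the same numbers stay available after training through the trainer state. A minimal sketch, assuming the `trainer` object built in the cells above:

```python
# Sketch: trainer.state.log_history keeps one dict per logging event,
# e.g. {"loss": 0.7901, "learning_rate": ..., "epoch": ..., "step": 500}.
loss_by_step = [
    (entry["step"], entry["loss"])
    for entry in trainer.state.log_history
    if "loss" in entry  # skips the final summary entry (train_loss, runtime, ...)
]
print(loss_by_step)  # for the run above: [(500, 0.7901), (1000, 0.2584)]
```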
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Saving model checkpoint to outputs/checkpoint-500\n", + "Configuration saved in outputs/checkpoint-500/config.json\n", + "Model weights saved in outputs/checkpoint-500/pytorch_model.bin\n", + "tokenizer config file saved in outputs/checkpoint-500/tokenizer_config.json\n", + "Special tokens file saved in outputs/checkpoint-500/special_tokens_map.json\n", + "Deleting older checkpoint [outputs/checkpoint-3500] due to args.save_total_limit\n", + "Saving model checkpoint to outputs/checkpoint-1000\n", + "Configuration saved in outputs/checkpoint-1000/config.json\n", + "Model weights saved in outputs/checkpoint-1000/pytorch_model.bin\n", + "tokenizer config file saved in outputs/checkpoint-1000/tokenizer_config.json\n", + "Special tokens file saved in outputs/checkpoint-1000/special_tokens_map.json\n", + "Deleting older checkpoint [outputs/checkpoint-500] due to args.save_total_limit\n", + "\n", + "\n", + "Training completed. Do not forget to share your model on huggingface.co/models =)\n", + "\n", + "\n" + ] + } + ], + "source": [ + "train_result = trainer.train()" + ] + }, + { + "cell_type": "code", + "execution_count": 103, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "

Run history:

| Metric                         | Trend |
|--------------------------------|-------|
| train/epoch                    | ▁▅█   |
| train/global_step              | ▁▅█   |
| train/learning_rate            | █▁    |
| train/loss                     | █▁    |
| train/total_flos               | ▁     |
| train/train_loss               | ▁     |
| train/train_runtime            | ▁     |
| train/train_samples_per_second | ▁     |
| train/train_steps_per_second   | ▁     |

Run summary:

| Metric                         | Value                 |
|--------------------------------|-----------------------|
| train/epoch                    | 4                     |
| train/global_step              | 1444                  |
| train/learning_rate            | 2e-05                 |
| train/loss                     | 0.2584                |
| train/total_flos               | 2.143084255034573e+16 |
| train/train_loss               | 0.38528               |
| train/train_runtime            | 2161.033              |
| train/train_samples_per_second | 10.678                |
| train/train_steps_per_second   | 0.668                 |
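The history/summary block above is printed by `wandb.finish()`; the same values can be read back later through the public W&B API. A minimal sketch, with the run path taken from the run links in this notebook:

```python
import wandb

# Sketch: fetch the finished run and read back its summary metrics.
api = wandb.Api()
run = api.run("nlp15/odqa/0y11f1su")       # entity/project/run_id from this notebook
print(run.summary["train/train_loss"])     # 0.38528 for the run above
print(run.summary["train/train_runtime"])  # wall-clock seconds (2161.033)
```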

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run electric-serenity-822 at: https://wandb.ai/nlp15/odqa/runs/0y11f1su
View project at: https://wandb.ai/nlp15/odqa
Synced 5 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find logs at: ./wandb/run-20241015_184122-0y11f1su/logs" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "wandb.finish()" + ] + }, + { + "cell_type": "code", + "execution_count": 104, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RobertaForQuestionAnswering(\n", + " (roberta): RobertaModel(\n", + " (embeddings): RobertaEmbeddings(\n", + " (word_embeddings): Embedding(32000, 1024, padding_idx=1)\n", + " (position_embeddings): Embedding(514, 1024, padding_idx=1)\n", + " (token_type_embeddings): Embedding(1, 1024)\n", + " (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (encoder): RobertaEncoder(\n", + " (layer): ModuleList(\n", + " (0-23): 24 x RobertaLayer(\n", + " (attention): RobertaAttention(\n", + " (self): RobertaSelfAttention(\n", + " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): RobertaSelfOutput(\n", + " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): RobertaIntermediate(\n", + " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (intermediate_act_fn): GELUActivation()\n", + " )\n", + " (output): RobertaOutput(\n", + " (dense): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (qa_outputs): Linear(in_features=1024, out_features=2, bias=True)\n", + ")" + ] + }, + "execution_count": 104, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## ํ—ˆ๊น…ํŽ˜์ด์Šค์— ๋ชจ๋ธ ์—…๋กœ๋“œ" + ] + }, + { + "cell_type": "code", + "execution_count": 105, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. 
Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "/bin/bash: sudo: command not found\n" + ] + } + ], + "source": [ + "!sudo apt-get install git-lfs" + ] + }, + { + "cell_type": "code", + "execution_count": 106, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Configuration saved in /tmp/tmpvxgja89o/config.json\n", + "Model weights saved in /tmp/tmpvxgja89o/pytorch_model.bin\n", + "Uploading the following files to ssunbear/klue_roberta_large_finetuned_korquad_v2: pytorch_model.bin,config.json\n", + "pytorch_model.bin: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 1.34G/1.34G [00:56<00:00, 23.7MB/s] \n", + "tokenizer config file saved in /tmp/tmp1q6nsung/tokenizer_config.json\n", + "Special tokens file saved in /tmp/tmp1q6nsung/special_tokens_map.json\n", + "Uploading the following files to ssunbear/klue_roberta_large_finetuned_korquad_v2: tokenizer_config.json,vocab.txt,tokenizer.json,special_tokens_map.json\n" + ] + }, + { + "data": { + "text/plain": [ + "CommitInfo(commit_url='https://huggingface.co/ssunbear/klue_roberta_large_finetuned_korquad_v2/commit/32b49af3b113769f65b0ab2392cfb81d4c159962', commit_message='Upload tokenizer', commit_description='', oid='32b49af3b113769f65b0ab2392cfb81d4c159962', pr_url=None, repo_url=RepoUrl('https://huggingface.co/ssunbear/klue_roberta_large_finetuned_korquad_v2', endpoint='https://huggingface.co', repo_type='model', repo_id='ssunbear/klue_roberta_large_finetuned_korquad_v2'), pr_revision=None, pr_num=None)" + ] + }, + "execution_count": 106, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from transformers import AutoModel\n", + "from transformers import AutoTokenizer\n", + "\n", + "\n", + "# Huggingface Access Token\n", + "ACCESS_TOKEN = #ํ† ํฐ์•„์ด๋”” ์ž…๋ ฅํ•˜์‹œ๋ฉด ๋ฉ๋‹ˆ๋‹ค.\n", + "\n", + "# Upload to Huggingface\n", + "model.push_to_hub('klue_roberta_large_finetuned_korquad_v2', use_temp_dir=True, use_auth_token=ACCESS_TOKEN)\n", + "tokenizer.push_to_hub('klue_roberta_large_finetuned_korquad_v2', use_temp_dir=True, use_auth_token=ACCESS_TOKEN)\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "import os\n", + "import sys\n", + "import wandb\n", + "\n", + "from datasets import DatasetDict\n", + "import evaluate\n", + "import argparse\n", + "from trainer_qa import QuestionAnsweringTrainer\n", + "from transformers import (\n", + " AutoConfig,\n", + " AutoModelForQuestionAnswering,\n", + " AutoTokenizer,\n", + " DataCollatorWithPadding,\n", + " EvalPrediction,\n", + " TrainingArguments,\n", + ")\n", + "from utils_qa import set_seed, check_no_error, postprocess_qa_predictions\n", + "from omegaconf import OmegaConf\n", + "from omegaconf import DictConfig\n", + "from utils.naming import wandb_naming\n", + "from prepare_dataset import prepare_dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## FINISH\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "/opt/conda/lib/python3.10/site-packages/huggingface_hub/file_download.py:1142: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n", + "/opt/conda/lib/python3.10/site-packages/huggingface_hub/file_download.py:1142: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "from CNN_layer_model import CNN_RobertaForQuestionAnswering\n", + "from transformers import (\n", + " AutoConfig,\n", + " AutoModelForQuestionAnswering,\n", + " AutoTokenizer\n", + ")\n", + "\n", + "model_name = \"CurtisJeon/klue-roberta-large-korquad_v1_qa\"\n", + "\n", + "config = AutoConfig.from_pretrained(\n", + " model_name\n", + ")\n", + "tokenizer = AutoTokenizer.from_pretrained(\n", + " model_name,\n", + " use_fast=True\n", + ")\n", + "model = CNN_RobertaForQuestionAnswering.from_pretrained(\n", + " model_name,\n", + " config=config)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import (\n", + " AutoConfig,\n", + " AutoTokenizer\n", + ")\n", + "from custom_model_copy import CNN_RobertaForQuestionAnswering\n" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/site-packages/huggingface_hub/file_download.py:1142: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. 
If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n", + "Some weights of CNN_RobertaForQuestionAnswering were not initialized from the model checkpoint at CurtisJeon/klue-roberta-large-korquad_v1_qa and are newly initialized: ['cnn_block2.layer_norm.bias', 'cnn_block4.conv2.bias', 'cnn_block5.conv2.weight', 'cnn_block2.conv1.weight', 'cnn_block2.conv1.bias', 'cnn_block4.conv2.weight', 'cnn_block4.layer_norm.bias', 'cnn_block1.layer_norm.bias', 'cnn_block2.layer_norm.weight', 'cnn_block4.layer_norm.weight', 'cnn_block4.conv1.weight', 'cnn_block5.layer_norm.bias', 'cnn_block3.layer_norm.bias', 'cnn_block4.conv1.bias', 'cnn_block1.conv2.weight', 'cnn_block1.layer_norm.weight', 'cnn_block5.layer_norm.weight', 'cnn_block3.conv2.weight', 'cnn_block3.conv2.bias', 'cnn_block5.conv1.bias', 'cnn_block3.conv1.weight', 'cnn_block2.conv2.bias', 'cnn_block3.layer_norm.weight', 'cnn_block1.conv1.weight', 'cnn_block1.conv1.bias', 'cnn_block2.conv2.weight', 'cnn_block5.conv1.weight', 'cnn_block3.conv1.bias', 'cnn_block1.conv2.bias', 'cnn_block5.conv2.bias']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "model_name = \"CurtisJeon/klue-roberta-large-korquad_v1_qa\"\n", + "\n", + "config = AutoConfig.from_pretrained(\n", + " model_name\n", + ")\n", + "tokenizer = AutoTokenizer.from_pretrained(\n", + " model_name,\n", + " use_fast=True\n", + ")\n", + "model = CNN_RobertaForQuestionAnswering.from_pretrained(\n", + " model_name)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "CNN_RobertaForQuestionAnswering(\n", + " (roberta): RobertaModel(\n", + " (embeddings): RobertaEmbeddings(\n", + " (word_embeddings): Embedding(32000, 1024, padding_idx=1)\n", + " (position_embeddings): Embedding(514, 1024, padding_idx=1)\n", + " (token_type_embeddings): Embedding(1, 1024)\n", + " (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (encoder): RobertaEncoder(\n", + " (layer): ModuleList(\n", + " (0-23): 24 x RobertaLayer(\n", + " (attention): RobertaAttention(\n", + " (self): RobertaSelfAttention(\n", + " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): RobertaSelfOutput(\n", + " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): RobertaIntermediate(\n", + " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (intermediate_act_fn): GELUActivation()\n", + " )\n", + " (output): RobertaOutput(\n", + " (dense): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (cnn_block1): CNN_block(\n", + " (conv1): Conv1d(514, 1028, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (conv2): Conv1d(1028, 514, kernel_size=(1,), stride=(1,))\n", + " (relu): ReLU()\n", + " (layer_norm): LayerNorm((1024,), eps=1e-05, 
elementwise_affine=True)\n", + " )\n", + " (cnn_block2): CNN_block(\n", + " (conv1): Conv1d(514, 1028, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (conv2): Conv1d(1028, 514, kernel_size=(1,), stride=(1,))\n", + " (relu): ReLU()\n", + " (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (cnn_block3): CNN_block(\n", + " (conv1): Conv1d(514, 1028, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (conv2): Conv1d(1028, 514, kernel_size=(1,), stride=(1,))\n", + " (relu): ReLU()\n", + " (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (cnn_block4): CNN_block(\n", + " (conv1): Conv1d(514, 1028, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (conv2): Conv1d(1028, 514, kernel_size=(1,), stride=(1,))\n", + " (relu): ReLU()\n", + " (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (cnn_block5): CNN_block(\n", + " (conv1): Conv1d(514, 1028, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (conv2): Conv1d(1028, 514, kernel_size=(1,), stride=(1,))\n", + " (relu): ReLU()\n", + " (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (qa_outputs): Linear(in_features=1024, out_features=2, bias=True)\n", + ")" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000..a19f56f --- /dev/null +++ b/src/main.py @@ -0,0 +1,434 @@ +import datetime +import logging +import os +import sys +import numpy as np +from typing import Callable, List, NoReturn, Tuple + +from arguments import DataTrainingArguments, ModelArguments +from datasets import ( + Dataset, + DatasetDict, + Features, + Value, + Sequence, + load_from_disk, + load_metric +) +from qa_trainer import QATrainer +from retrieval_BM25 import BM25SparseRetrieval +from retrieval_hybridsearch import HybridSearch +from retrieval_Dense import DenseRetrieval +from retrieval_2s_rerank import TwoStageReranker + +from transformers import ( + AutoConfig, + #AutoModelForQuestionAnswering + AutoTokenizer, + DataCollatorWithPadding, + EvalPrediction, + HfArgumentParser, + TrainingArguments, +) +from utils import set_seed, check_no_error, postprocess_qa_predictions +import wandb +from CNN_layer_model import CNN_RobertaForQuestionAnswering + +logger = logging.getLogger(__name__) + + +def main(): + parser = HfArgumentParser( + (ModelArguments, DataTrainingArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + training_args.save_steps = 0 + training_args.logging_steps = 10 + + project_prefix = "[train]" if training_args.do_train else "[eval]" if training_args.do_eval else "[pred]" + wandb.init( + project="odqa", + entity="nlp15", + name=f"{project_prefix} {model_args.model_name_or_path.split('/')[0]}_{(datetime.datetime.now() + datetime.timedelta(hours=9)).strftime('%Y%m%d_%H%M%S')}", + save_code=True, + ) + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + 
datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + logging.info(f"model is from {model_args.model_name_or_path}") + logging.info(f"data is from {data_args.dataset_name}") + + logger.info("Training/evaluation parameters %s", training_args) + + set_seed(training_args.seed) + print(">>> seed:", training_args.seed) + + + datasets = load_from_disk(data_args.dataset_name) + print(datasets) + + config = AutoConfig.from_pretrained( + model_args.config_name + if model_args.config_name is not None + else model_args.model_name_or_path, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name + if model_args.tokenizer_name is not None + else model_args.model_name_or_path, + use_fast=True, + ) + #AutoModelForQuestionAnswering -> CNN_RobertaForQuestionAnswering : CNNlayer์ถ”๊ฐ€ํ•˜์—ฌ ์„ฑ๋Šฅํ–ฅ์ƒ + model = CNN_RobertaForQuestionAnswering.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + ) + + if training_args.do_predict and data_args.eval_retrieval: + datasets = run_sparse_retrieval( + tokenizer.tokenize, datasets, training_args, data_args, + ) + + run_mrc(data_args, training_args, model_args, datasets, tokenizer, model) + + +def prepare_train_features(examples, tokenizer, question_column_name, pad_on_right, context_column_name, max_seq_length, data_args, answer_column_name): + tokenized_examples = tokenizer( + examples[question_column_name if pad_on_right else context_column_name], + examples[context_column_name if pad_on_right else question_column_name], + truncation="only_second" if pad_on_right else "only_first", + max_length=max_seq_length, + stride=data_args.doc_stride, + return_overflowing_tokens=True, + return_offsets_mapping=True, + return_token_type_ids=False, # True if bert, False if roberta + padding="max_length" if data_args.pad_to_max_length else False, + ) + + sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping") + offset_mapping = tokenized_examples.pop("offset_mapping") + + tokenized_examples["start_positions"] = [] + tokenized_examples["end_positions"] = [] + + for i, offsets in enumerate(offset_mapping): + input_ids = tokenized_examples["input_ids"][i] + cls_index = input_ids.index(tokenizer.cls_token_id) + + sequence_ids = tokenized_examples.sequence_ids(i) + + sample_index = sample_mapping[i] + answers = examples[answer_column_name][sample_index] + + if len(answers["answer_start"]) == 0: + tokenized_examples["start_positions"].append(cls_index) + tokenized_examples["end_positions"].append(cls_index) + else: + start_char = answers["answer_start"][0] + end_char = start_char + len(answers["text"][0]) + + token_start_index = 0 + while sequence_ids[token_start_index] != (1 if pad_on_right else 0): + token_start_index += 1 + + token_end_index = len(input_ids) - 1 + while sequence_ids[token_end_index] != (1 if pad_on_right else 0): + token_end_index -= 1 + + if not ( + offsets[token_start_index][0] <= start_char + and offsets[token_end_index][1] >= end_char + ): + tokenized_examples["start_positions"].append(cls_index) + tokenized_examples["end_positions"].append(cls_index) + else: + while ( + token_start_index < len(offsets) + and offsets[token_start_index][0] <= start_char + ): + token_start_index += 1 + tokenized_examples["start_positions"].append(token_start_index - 1) + while offsets[token_end_index][1] >= end_char: + token_end_index -= 1 + tokenized_examples["end_positions"].append(token_end_index + 1) + + return tokenized_examples + + 
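`run_sparse_retrieval` below delegates to `HybridSearch`, which blends sparse (BM25-family) and dense scores with a weight `alpha` (the call further down uses `topk=10, alpha=0.7`). The internals of `HybridSearch` are not part of this diff, so the following is only a minimal sketch of the usual convex-combination scheme, with hypothetical names:

```python
import numpy as np

def hybrid_scores(sparse_scores: np.ndarray,
                  dense_scores: np.ndarray,
                  alpha: float = 0.7) -> np.ndarray:
    """Sketch of a hybrid retrieval score over one query's candidates:
    min-max normalize each score vector, then mix them; `alpha` weights
    the sparse (BM25Plus-style) side, `1 - alpha` the dense side."""
    def minmax(x: np.ndarray) -> np.ndarray:
        span = x.max() - x.min()
        return (x - x.min()) / span if span > 0 else np.zeros_like(x)

    return alpha * minmax(sparse_scores) + (1 - alpha) * minmax(dense_scores)

# Usage sketch: rank three candidate passages for one query.
sparse = np.array([12.3, 7.1, 9.8])    # e.g. BM25Plus scores
dense = np.array([0.82, 0.79, 0.91])   # e.g. cosine similarities
ranking = np.argsort(hybrid_scores(sparse, dense, alpha=0.7))[::-1]
```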
+def run_sparse_retrieval( + tokenize_fn: Callable[[str], List[str]], + datasets: DatasetDict, + training_args: TrainingArguments, + data_args: DataTrainingArguments, + data_path: str = "../data", + context_path: str = "wikipedia_documents.json", +) -> DatasetDict: + + retriever = HybridSearch( + tokenize_fn=tokenize_fn, + # args=data_args, # args๋ฅผ ์ „๋‹ฌ + data_path=data_path, + context_path=context_path + ) + retriever.get_sparse_embedding() + retriever.get_dense_embedding() + + # retriever = BM25SparseRetrieval( + # tokenize_fn=tokenize_fn, + # data_path=data_path, + # context_path=context_path + # ) + # retriever.get_sparse_embedding() + + # retriever = TwoStageReranker( + # tokenize_fn=tokenize_fn, + # args=data_args, # args๋ฅผ ์ „๋‹ฌ + # data_path=data_path, + # context_path=context_path + # ) + + # retriever = DenseRetrieval( + # data_path=data_path, + # context_path=context_path + # ) + # retriever.get_dense_embedding() + + # if data_args.use_faiss: + # retriever.build_faiss(num_clusters=data_args.num_clusters) + # df = retriever.retrieve_faiss( + # datasets["validation"], topk=data_args.top_k_retrieval + # ) + # else: + + # df = retriever.retrieve(datasets["validation"], topk=data_args.top_k_retrieval) + # df = retriever.retrieve(datasets["validation"], topk=data_args.top_k_retrieval, alpha=data_args.alpha_retrieval) + df = retriever.retrieve(datasets["validation"], topk=10, alpha=0.7) + + + if training_args.do_predict: + f = Features( + { + "context": Value(dtype="string", id=None), + "id": Value(dtype="string", id=None), + "question": Value(dtype="string", id=None), + } + ) + + elif training_args.do_eval: + f = Features( + { + "answers": Sequence( + feature={ + "text": Value(dtype="string", id=None), + "answer_start": Value(dtype="int32", id=None), + }, + length=-1, + id=None, + ), + "context": Value(dtype="string", id=None), + "id": Value(dtype="string", id=None), + "question": Value(dtype="string", id=None), + } + ) + datasets = DatasetDict({"validation": Dataset.from_pandas(df, features=f)}) + return datasets + + +def prepare_validation_features(examples, tokenizer, question_column_name, pad_on_right, context_column_name, max_seq_length, data_args, answer_column_name): + tokenized_examples = tokenizer( + examples[question_column_name if pad_on_right else context_column_name], + examples[context_column_name if pad_on_right else question_column_name], + truncation="only_second" if pad_on_right else "only_first", + max_length=max_seq_length, + stride=data_args.doc_stride, + return_overflowing_tokens=True, + return_offsets_mapping=True, + return_token_type_ids=False, # True if bert, False if roberta + padding="max_length" if data_args.pad_to_max_length else False, + ) + + sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping") + + tokenized_examples["example_id"] = [] + + for i in range(len(tokenized_examples["input_ids"])): + sequence_ids = tokenized_examples.sequence_ids(i) + context_index = 1 if pad_on_right else 0 + + sample_index = sample_mapping[i] + tokenized_examples["example_id"].append(examples["id"][sample_index]) + + tokenized_examples["offset_mapping"][i] = [ + (o if sequence_ids[k] == context_index else None) + for k, o in enumerate(tokenized_examples["offset_mapping"][i]) + ] + return tokenized_examples + + +def run_mrc( + data_args: DataTrainingArguments, + training_args: TrainingArguments, + model_args: ModelArguments, + datasets: DatasetDict, + tokenizer, + model, +) -> NoReturn: + if training_args.do_train: + column_names = 
datasets["train"].column_names + else: + column_names = datasets["validation"].column_names + + question_column_name = "question" if "question" in column_names else column_names[0] + context_column_name = "context" if "context" in column_names else column_names[1] + answer_column_name = "answers" if "answers" in column_names else column_names[2] + + pad_on_right = tokenizer.padding_side == "right" + + last_checkpoint, max_seq_length = check_no_error( + data_args, training_args, datasets, tokenizer + ) + + if training_args.do_train: + if "train" not in datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = datasets["train"] + + train_dataset = train_dataset.map( + prepare_train_features, + fn_kwargs={ + 'tokenizer': tokenizer, + 'pad_on_right': pad_on_right, + 'max_seq_length': max_seq_length, + 'data_args': data_args, + 'question_column_name': question_column_name, + 'context_column_name': context_column_name, + 'answer_column_name': answer_column_name + }, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + ) + + if training_args.do_eval or training_args.do_predict: + eval_dataset = datasets["validation"] + + eval_dataset = eval_dataset.map( + prepare_validation_features, + fn_kwargs={ + 'tokenizer': tokenizer, + 'pad_on_right': pad_on_right, + 'max_seq_length': max_seq_length, + 'data_args': data_args, + 'question_column_name': question_column_name, + 'context_column_name': context_column_name, + 'answer_column_name': answer_column_name + }, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + ) + + data_collator = DataCollatorWithPadding( + tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None + ) + + metric = load_metric("squad") + + def post_processing_function( + examples, + features, + predictions: Tuple[np.ndarray, np.ndarray], + training_args: TrainingArguments + ) -> EvalPrediction: + predictions = postprocess_qa_predictions( + examples=examples, + features=features, + predictions=predictions, + max_answer_length=data_args.max_answer_length, + output_dir=training_args.output_dir, + ) + formatted_predictions = [ + {"id": k, "prediction_text": v} for k, v in predictions.items() + ] + + if training_args.do_predict: + return formatted_predictions + + elif training_args.do_eval: + references = [ + {"id": ex["id"], "answers": ex[answer_column_name]} + for ex in datasets["validation"] + ] + return EvalPrediction( + predictions=formatted_predictions, label_ids=references + ) + + trainer = QATrainer( + model=model, + args=training_args, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + eval_examples=datasets["validation"] if training_args.do_eval else None, + tokenizer=tokenizer, + data_collator=data_collator, + post_process_function=post_processing_function, + compute_metrics=lambda x: metric.compute(predictions=x.predictions, references=x.label_ids), + ) + + if training_args.do_train: + if last_checkpoint is not None: + checkpoint = last_checkpoint + elif os.path.isdir(model_args.model_name_or_path): + checkpoint = model_args.model_name_or_path + else: + checkpoint = None + train_result = trainer.train(resume_from_checkpoint=checkpoint) + trainer.save_model() + + metrics = train_result.metrics + metrics["train_samples"] = len(train_dataset) + + 
trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + output_train_file = os.path.join(training_args.output_dir, "train_results.txt") + + with open(output_train_file, "w") as writer: + logger.info("***** Train results *****") + for key, value in sorted(train_result.metrics.items()): + logger.info(f" {key} = {value}") + writer.write(f"{key} = {value}\n") + + trainer.state.save_to_json( + os.path.join(training_args.output_dir, "trainer_state.json") + ) + + logger.info("*** Evaluate ***") + + if training_args.do_eval: + logger.info("*** Evaluate ***") + metrics = trainer.evaluate() + + metrics["eval_samples"] = len(eval_dataset) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + if training_args.do_predict: + predictions = trainer.predict( + test_dataset=eval_dataset, test_examples=datasets["validation"] + ) + logger.info("No metric can be presented because there is no correct answer given. Job done!") + + +if __name__ == "__main__": + main() + diff --git a/src/optimize_retriever.py b/src/optimize_retriever.py new file mode 100644 index 0000000..9889a84 --- /dev/null +++ b/src/optimize_retriever.py @@ -0,0 +1,124 @@ +from rank_bm25 import BM25Plus +import optuna +from sklearn.metrics import ndcg_score +import numpy as np +from datasets import load_from_disk +from transformers import AutoTokenizer, AutoModel +from torch.nn.functional import normalize +import torch +from tqdm import tqdm + +import logging +import datetime +import wandb + +from retrieval_hybridsearch import HybridSearch # HybridSearch ํด๋ž˜์Šค๊ฐ€ ์ •์˜๋œ ํŒŒ์ผ์—์„œ ์ž„ํฌํŠธ +from retrieval_2s_rerank import TwoStageReranker + +# ๋กœ๊ทธ ์„ค์ • +logger = logging.getLogger(__name__) +wandb.init(project="odqa", + name="run_" + (datetime.datetime.now() + datetime.timedelta(hours=9)).strftime("%Y%m%d_%H%M%S"), + entity="nlp15" + ) + +datasets = load_from_disk("../data/train_dataset") + +documents = datasets['train']['context'] +queries = datasets['train']['question'] + +tokenizer = AutoTokenizer.from_pretrained("HANTAEK/klue-roberta-large-korquad-v1-qa-finetuned") +dense_tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-large-instruct') +dense_embeder = AutoModel.from_pretrained( + 'intfloat/multilingual-e5-large-instruct' + ) + +dense_embeds = [] +batch_size = 64 +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +dense_embeder.to(device) + +def mean_pooling(model_output, attention_mask): + token_embeddings = model_output[0] #First element of model_output contains all token embeddings + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) + +for i in tqdm(range(0, len(documents), batch_size), desc="Encoding passages"): + batch_contexts = documents[i:i+batch_size] + encoded_input = dense_tokenizer( + batch_contexts, padding=True, truncation=True, return_tensors='pt' + ).to(device) + with torch.no_grad(): + model_output = dense_embeder(**encoded_input) + sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask']) + sentence_embeddings = normalize(sentence_embeddings, p=2, dim=1) + dense_embeds.append(sentence_embeddings.cpu()) + del encoded_input, model_output, sentence_embeddings + torch.cuda.empty_cache() + +dense_embeds = torch.cat(dense_embeds, dim=0) + +retriever = HybridSearch( + tokenize_fn=tokenizer.tokenize, + 
data_path="../data", + context_path="wikipedia_documents.json" +) +# retriever = TwoStageReranker( +# tokenize_fn=tokenizer.tokenize, +# data_path="../data", +# context_path="wikipedia_documents.json" +# ) + +retriever.get_dense_embedding() +retriever.get_sparse_embedding() + +true_relevance_scores = np.eye(len(documents), dtype=int).tolist() + +retriever.dense_embeds = dense_embeds +# retriever.dense_embeder.dense_embeds = dense_embeds + +def objective(trial): + alpha = trial.suggest_float("alpha", 0.4, 1.0) + k1 = trial.suggest_float("k1", 0.5, 2.0) + b = trial.suggest_float("b", 0.0, 1.0) + delta = trial.suggest_float("delta", 0.0, 1.0) + + retriever.sparse_embeder = BM25Plus([tokenizer.tokenize(doc) for doc in documents], k1=k1, b=b, delta=delta) + + all_scores = [] + all_doc_indices = [] + + for idx, query in enumerate(queries): + scores, contexts, doc_indices = retriever.retrieve(query, topk=20, alpha=alpha) + all_scores.append(scores) + all_doc_indices.append(doc_indices) + + true_relevance_scores = [] + for idx, doc_indices in enumerate(all_doc_indices): + relevance = [1 if doc_idx == idx else 0 for doc_idx in doc_indices] + true_relevance_scores.append(relevance) + + all_scores = np.array(all_scores) + true_relevance_scores = np.array(true_relevance_scores) + + avg_ndcg = ndcg_score(true_relevance_scores, all_scores) + + return avg_ndcg + +class TQDMProgressBar: + def __init__(self, n_trials): + self.pbar = tqdm(total=n_trials) + + def __call__(self, study, trial): + self.pbar.update(1) + +n_trials = 30 # ์ด ์‹œ๋„ ํšŸ์ˆ˜ +progress_bar = TQDMProgressBar(n_trials) +study = optuna.create_study(direction="maximize") +study.optimize(objective, n_trials=n_trials, callbacks=[progress_bar]) + +# ์ตœ์ ์˜ ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ ์ถœ๋ ฅ +print("Best alpha:", study.best_params["alpha"]) +print("Best k1:", study.best_params["k1"]) +print("Best b:", study.best_params["b"]) +print("Best delta:", study.best_params["delta"]) \ No newline at end of file diff --git a/src/preprocess_answer.ipynb b/src/preprocess_answer.ipynb new file mode 100644 index 0000000..0251a41 --- /dev/null +++ b/src/preprocess_answer.ipynb @@ -0,0 +1,162 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "import torch\n", + "\n", + "from transformers import AutoTokenizer, AutoModelForCausalLM\n", + "from datasets import load_from_disk" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "base_dir = \"/\"" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ\n", + "dataset_dict = load_from_disk(os.path.join(base_dir, \"data\", \"test_dataset\"))\n", + "\n", + "# 'test' ๋ฐ์ดํ„ฐ์…‹์„ ์„ ํƒํ•˜๊ณ  Pandas ๋ฐ์ดํ„ฐํ”„๋ ˆ์ž„์œผ๋กœ ๋ณ€ํ™˜\n", + "test_dataset = dataset_dict['validation']\n", + "df = test_dataset.to_pandas()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_id = \"rtzr/ko-gemma-2-9b-it\"\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(model_id)\n", + "model = AutoModelForCausalLM.from_pretrained(\n", + " model_id,\n", + " torch_dtype=torch.bfloat16,\n", + " device_map=\"auto\",\n", + ")\n", + "\n", + "model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "instruction = \"\"\"๋‹ค์Œ์˜ '์ œ์‹œ ๋‹ต๋ณ€'์—์„œ ํ•„์š”ํ•œ 
๋‹ต๋งŒ ๋™์ผํ•˜๊ฒŒ ์ถ”์ถœํ•˜์—ฌ ์ œ์‹œํ•˜์‹œ์˜ค. ์ถ”๊ฐ€์ ์ธ ์ •๋ณด๋‚˜ ์ˆ˜์ • ์—†์ด, '์ œ์‹œ ๋‹ต๋ณ€'์˜ ๋‚ด์šฉ์„ ์‚ฌ์šฉํ•˜์‹œ์˜ค.\n", + "\n", + "\n", + "๋‹ค์Œ์„ ๋ฐ˜๋“œ์‹œ ์ง€ํ‚ค์‹œ์˜ค: ํ•„์š”์‹œ ์กฐ์‚ฌ, ์–ด๋ฏธ๋ฅผ ์ œ๊ฑฐํ•  ์ˆ˜ ์žˆ์œผ๋‚˜, '์ œ์‹œ ๋‹ต๋ณ€' ์ด์™ธ์˜ ๋‚ด์šฉ์„ ์ƒ์„ฑ, ์ถ”๊ฐ€ํ•˜๋ฉด ์ ˆ๋Œ€ ์•ˆ๋จ. ์ค‘์š” ํ•ต์‹ฌ ๋‚ด์šฉ์„ ๋‚จ๊ธฐ๊ณ  ์ œ๊ฑฐํ•  ์ˆ˜ ์žˆ๋‹ค.\n", + "\n", + "### ์˜ˆ์‹œ ###\n", + "์งˆ๋ฌธ: ์ œ2์ฐจ ์„ธ๊ณ„ ๋Œ€์ „์€ ๋ช‡ ๋…„์— ๋ฐœ๋ฐœํ•˜์˜€๋Š”๊ฐ€? ์ œ์‹œ ๋‹ต๋ณ€: 1939๋…„ 9์›” 1์ผ์— ๋ฐœ๋ฐœํ•˜์˜€๋‹ค.\n", + "์ƒ์„ฑ ๋‹ต๋ณ€: 1939๋…„\n", + "\n", + "### ์งˆ๋ฌธ ###\n", + "{}\n", + "\n", + "### ์ œ์‹œ ๋‹ต๋ณ€ ###\n", + "{}\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"predictions.json\", \"r\") as f:\n", + " predictions = json.load(f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "new_pred = {}\n", + "\n", + "for i in range(len(df)):\n", + " question = df.loc[i, 'question']\n", + "\n", + " messages = [\n", + " {\"role\": \"user\", \"content\": instruction.format(question, predictions[df.loc[i, \"id\"]])},\n", + " ]\n", + "\n", + " input_ids = tokenizer.apply_chat_template(\n", + " messages,\n", + " add_generation_prompt=True,\n", + " return_tensors=\"pt\"\n", + " ).to(model.device)\n", + "\n", + " terminators = [\n", + " tokenizer.eos_token_id,\n", + " tokenizer.convert_tokens_to_ids(\"\")\n", + " ]\n", + "\n", + " outputs = model.generate(\n", + " input_ids,\n", + " max_new_tokens=2048,\n", + " eos_token_id=terminators,\n", + " do_sample=True,\n", + " temperature=0.1,\n", + " top_p=0.9,\n", + " )\n", + "\n", + " pred = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)\n", + "\n", + " new_pred[df.loc[i, \"id\"]] = pred.strip()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "with open(os.path.join(base_dir, \"data\", \"predictions_fixed.json\"), \"w\") as f:\n", + " json.dump(new_pred, f, ensure_ascii=False, indent=4)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/qa_trainer.py b/src/qa_trainer.py new file mode 100644 index 0000000..a74543f --- /dev/null +++ b/src/qa_trainer.py @@ -0,0 +1,78 @@ +from transformers import Trainer, is_datasets_available + +if is_datasets_available(): + import datasets + + +class QATrainer(Trainer): + def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs): + super().__init__(*args, **kwargs) + self.eval_examples = eval_examples + self.post_process_function = post_process_function + + def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None): + eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset + eval_dataloader = self.get_eval_dataloader(eval_dataset) + eval_examples = self.eval_examples if eval_examples is None else eval_examples + + compute_metrics = self.compute_metrics + self.compute_metrics = None + try: + output = self.prediction_loop( + eval_dataloader, + description="Evaluation", + prediction_loss_only=True if compute_metrics is None else 
+    def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None):
+        eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
+        eval_dataloader = self.get_eval_dataloader(eval_dataset)
+        eval_examples = self.eval_examples if eval_examples is None else eval_examples
+
+        compute_metrics = self.compute_metrics
+        self.compute_metrics = None
+        try:
+            output = self.prediction_loop(
+                eval_dataloader,
+                description="Evaluation",
+                prediction_loss_only=True if compute_metrics is None else None,
+                ignore_keys=ignore_keys,
+            )
+        finally:
+            self.compute_metrics = compute_metrics
+
+        if isinstance(eval_dataset, datasets.Dataset):
+            eval_dataset.set_format(
+                type=eval_dataset.format["type"],
+                columns=list(eval_dataset.features.keys()),
+            )
+
+        if self.post_process_function is not None and self.compute_metrics is not None:
+            eval_preds = self.post_process_function(
+                eval_examples, eval_dataset, output.predictions, self.args
+            )
+            metrics = self.compute_metrics(eval_preds)
+
+            self.log(metrics)
+        else:
+            metrics = {}
+
+        self.control = self.callback_handler.on_evaluate(
+            self.args, self.state, self.control, metrics
+        )
+        return metrics
+
+    def predict(self, test_dataset, test_examples, ignore_keys=None):
+        test_dataloader = self.get_test_dataloader(test_dataset)
+
+        compute_metrics = self.compute_metrics
+        self.compute_metrics = None
+        try:
+            output = self.prediction_loop(
+                test_dataloader,
+                description="Prediction",
+                prediction_loss_only=True if compute_metrics is None else None,
+                ignore_keys=ignore_keys,
+            )
+        finally:
+            self.compute_metrics = compute_metrics
+
+        if self.post_process_function is None or self.compute_metrics is None:
+            return output
+
+        if isinstance(test_dataset, datasets.Dataset):
+            test_dataset.set_format(
+                type=test_dataset.format["type"],
+                columns=list(test_dataset.features.keys()),
+            )
+
+        predictions = self.post_process_function(
+            test_examples, test_dataset, output.predictions, self.args
+        )
+        return predictions
\ No newline at end of file
diff --git a/src/retrieval.py b/src/retrieval.py
new file mode 100644
index 0000000..9425feb
--- /dev/null
+++ b/src/retrieval.py
@@ -0,0 +1,86 @@
+import datetime
+import logging
+import os
+import sys
+import numpy as np
+from typing import Callable, List, NoReturn, Optional, Tuple
+
+from arguments import DataTrainingArguments, ModelArguments
+from datasets import (
+    Dataset,
+    DatasetDict,
+    Features,
+    Value,
+    Sequence,
+    load_from_disk,
+    load_metric
+)
+from qa_trainer import QATrainer
+from retrieval_BM25 import BM25SparseRetrieval
+from retrieval_hybridsearch import HybridSearch
+from retrieval_Dense import DenseRetrieval
+from retrieval_2s_rerank import TwoStageReranker
+from retrieval_tfidf import TFIDFRetrieval
+
+from transformers import (
+    AutoConfig,
+    AutoModelForQuestionAnswering,
+    AutoTokenizer,
+    DataCollatorWithPadding,
+    EvalPrediction,
+    HfArgumentParser,
+    TrainingArguments,
+)
+from utils import set_seed, check_no_error, postprocess_qa_predictions
+
+
+class Retriever:
+    def __init__(
+        self,
+        tokenize_fn,
+        args,
+        name: str,
+        data_path: Optional[str] = "../data/",
+        context_path: Optional[str] = "wikipedia_documents.json",
+    ):
+        self.retriever = None
+        if name == "2s_rerank":
+            self.retriever = TwoStageReranker(
+                tokenize_fn=tokenize_fn,
+                data_path=data_path,
+                context_path=context_path
+            )
+        elif name == "BM25":
+            self.retriever = BM25SparseRetrieval(
+                tokenize_fn=tokenize_fn,
+                args=args,
+                data_path=data_path,
+                context_path=context_path
+            )
+            self.retriever.get_sparse_embedding()
+        elif name == "Dense":
+            self.retriever = DenseRetrieval(
+                data_path=data_path,
+                context_path=context_path
+            )
+            self.retriever.get_dense_embedding()
+        elif name == "hybridsearch":
+            self.retriever = HybridSearch(
+                tokenize_fn=tokenize_fn,
+                data_path=data_path,
+                context_path=context_path
+            )
+            self.retriever.get_sparse_embedding()
+            self.retriever.get_dense_embedding()
+        elif name == "tfidf":
+            self.retriever = TFIDFRetrieval(
+                tokenize_fn=tokenize_fn,
+                data_path=data_path,
+                context_path=context_path
+            )
+            self.retriever.get_sparse_embedding()
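+    # A hedged usage sketch of this factory (tokenizer/args objects are
+    # placeholders; `name` selects one of the backends wired up above):
+    #
+    #   tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
+    #   retriever = Retriever(tokenize_fn=tokenizer.tokenize, args=data_args, name="BM25")
+    #   df = retriever.retrieve(datasets["validation"])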
+
+    def retrieve(self, query_or_dataset):
+        return self.retriever.retrieve(query_or_dataset)
\ No newline at end of file
diff --git a/src/retrieval_2s_rerank.py b/src/retrieval_2s_rerank.py
new file mode 100644
index 0000000..834b32a
--- /dev/null
+++ b/src/retrieval_2s_rerank.py
@@ -0,0 +1,146 @@
+import json
+import os
+import pickle
+import time
+import torch
+import logging
+import scipy
+from contextlib import contextmanager
+from typing import List, Optional, Tuple, Union, NoReturn
+from tqdm.auto import tqdm
+
+import argparse
+import numpy as np
+import pandas as pd
+from datasets import Dataset, concatenate_datasets, load_from_disk
+from torch.nn.functional import normalize
+
+from retrieval_BM25 import BM25SparseRetrieval
+from retrieval_Dense import DenseRetrieval
+
+from sklearn.feature_extraction.text import TfidfVectorizer
+from rank_bm25 import BM25Okapi, BM25Plus
+from transformers import AutoTokenizer, AutoModel
+from utils import set_seed
+
+set_seed(42)
+logger = logging.getLogger(__name__)
+
+@contextmanager
+def timer(name):
+    t0 = time.time()
+    yield
+    logging.info(f"[{name}] done in {time.time() - t0:.3f} s")
+
+
+class TwoStageReranker:
+    def __init__(
+        self,
+        tokenize_fn,
+        # args,
+        data_path: Optional[str] = "../data/",
+        context_path: Optional[str] = "wikipedia_documents.json",
+    ) -> NoReturn:
+        self.data_path = data_path
+        with open(os.path.join(data_path, context_path), "r", encoding="utf-8") as f:
+            wiki = json.load(f)
+
+        self.contexts = list(dict.fromkeys([v["text"] for v in wiki.values()]))
+        logging.info(f"Lengths of contexts : {len(self.contexts)}")
+        self.ids = list(range(len(self.contexts)))
+
+        self.sparse_embeder = BM25SparseRetrieval(
+            tokenize_fn=tokenize_fn,
+            # args=args,
+            data_path=data_path,
+            context_path=context_path
+        )
+        self.dense_embeder = DenseRetrieval(
+            data_path=data_path,
+            context_path=context_path
+        )
+        self.sparse_embeds_bool = False
+        self.dense_embeds_bool = False
+
+    def retrieve_first(self, queries, topk: Optional[int] = 1):
+        if not self.sparse_embeds_bool:
+            self.sparse_embeder.get_sparse_embedding()
+            self.sparse_embeds_bool = True
+        f_df = self.sparse_embeder.retrieve(queries, topk=topk)
+        return f_df
+
+    def retrieve_second(self, queries, topk: Optional[int] = 1, contexts=None):
+        self.dense_embeder.get_dense_embedding(contexts=contexts)
+        s_df = self.dense_embeder.retrieve(queries, topk=topk)
+        # self.sparse_embeder.get_sparse_embedding(contexts=contexts)
+        # s_df = self.sparse_embeder.retrieve(queries, topk=topk)
+        return s_df
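+    # NOTE: the pipeline below is recall-then-precision: BM25 first fetches
+    # `topk` candidate passages cheaply, then the dense retriever re-encodes
+    # and re-scores only those candidates, keeping roughly a third of them.
+    # Re-encoding candidates per query is the dominant cost here; caching
+    # their dense embeddings would be the natural optimization.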
example["question"], + "id": example["id"], + "context": " ".join(doc_indices), + } + second_df.append(tmp) + second_df = pd.DataFrame(second_df) + return second_df + +if __name__ == "__main__": + import argparse + + logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') + + parser = argparse.ArgumentParser(description="") + parser.add_argument("--dataset_name", default="../data/train_dataset", type=str) + parser.add_argument("--data_path", default="../data", type=str) + parser.add_argument("--context_path", default="wikipedia_documents.json", type=str) + + args = parser.parse_args() + logging.info(args.__dict__) + + org_dataset = load_from_disk(args.dataset_name) + full_ds = concatenate_datasets( + [ + org_dataset["train"].flatten_indices(), + org_dataset["validation"].flatten_indices(), + ] + ) + logging.info("*" * 40 + " query dataset " + "*" * 40) + logging.info(f"Full dataset: {full_ds}") + + tokenizer = AutoTokenizer.from_pretrained("HANTAEK/klue-roberta-large-korquad-v1-qa-finetuned") + + # query = "๋Œ€ํ†ต๋ น์„ ํฌํ•จํ•œ ๋ฏธ๊ตญ์˜ ํ–‰์ •๋ถ€ ๊ฒฌ์ œ๊ถŒ์„ ๊ฐ–๋Š” ๊ตญ๊ฐ€ ๊ธฐ๊ด€์€?" + query = "์œ ๋ น์€ ์–ด๋А ํ–‰์„ฑ์—์„œ ์ง€๊ตฌ๋กœ ์™”๋Š”๊ฐ€?" + + retriever = TwoStageReranker( + tokenize_fn=tokenizer.tokenize, + args=args, + data_path=args.data_path, + context_path=args.context_path, + ) + + with timer("single query by exhaustive search"): + doc_scores, doc_indices = retriever.retrieve(query, topk=5) \ No newline at end of file diff --git a/src/retrieval_BM25.py b/src/retrieval_BM25.py new file mode 100644 index 0000000..7d64a4b --- /dev/null +++ b/src/retrieval_BM25.py @@ -0,0 +1,181 @@ +import json +import os +import pickle +import time +from contextlib import contextmanager +from typing import List, Optional, Tuple, Union + +import argparse +import numpy as np +import pandas as pd +from datasets import Dataset, concatenate_datasets, load_from_disk +from rank_bm25 import BM25Plus +from tqdm.auto import tqdm +from transformers import AutoTokenizer + +from utils import set_seed + +set_seed(42) + +@contextmanager +def timer(name): + t0 = time.time() + yield + print(f"[{name}] done in {time.time() - t0:.3f} s") + + +class BM25SparseRetrieval: + def __init__(self, tokenize_fn, args, data_path: Optional[str] = "../data/", context_path: Optional[str] = "wikipedia_documents.json") -> None: + set_seed(42) + def __init__( + self, + tokenize_fn, + data_path: Optional[str] = "../data/", + context_path: Optional[str] = "wikipedia_documents.json", + corpus: Optional[pd.DataFrame] = None + ) -> None: + self.tokenizer = tokenize_fn + self.data_path = data_path + self.args = args + + # ์œ„ํ‚ค ๋ฌธ์„œ ๋กœ๋“œ + with open(os.path.join(data_path, context_path), "r", encoding="utf-8") as f: + wiki = json.load(f) + + self.contexts = list(dict.fromkeys([v["text"] for v in wiki.values()])) + print(f"Lengths of unique contexts : {len(self.contexts)}") + + self.bm25 = None + + + def get_sparse_embedding(self, contexts=None) -> None: + """BM25+๋กœ Passage Embedding์„ ๋งŒ๋“ค๊ณ  ์ดˆ๊ธฐํ™”ํ•ฉ๋‹ˆ๋‹ค.""" + pickle_name = "bm25plus_sparse_embedding_optuna.bin" + emd_path = os.path.join(self.data_path, pickle_name) + if contexts is not None: + self.contexts = contexts + tokenized_corpus = [self.tokenizer(doc) for doc in self.contexts] + self.bm25 = BM25Plus(tokenized_corpus, k1=1.7595, b=0.9172, delta=1.1490) + else: + if os.path.isfile(emd_path): + with open(emd_path, "rb") as file: + self.bm25 = pickle.load(file) + print("Embedding 
pickle load.") + else: + print("Build passage embedding") + tokenized_corpus = [self.tokenizer(doc) for doc in self.contexts] + self.bm25 = BM25Plus(tokenized_corpus, k1=1.7595, b=0.9172, delta=1.1490) # BM25Plus๋กœ ๋ณ€๊ฒฝ ํ›„ ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ Optuna test1 ์ ์šฉ + with open(emd_path, "wb") as file: + pickle.dump(self.bm25, file) + print("Embedding pickle saved.") + + + def retrieve(self, query_or_dataset: Union[str, Dataset], topk: Optional[int] = 1) -> Union[Tuple[List, List], pd.DataFrame]: + assert self.bm25 is not None, "get_sparse_embedding() ๋ฉ”์†Œ๋“œ๋ฅผ ๋จผ์ € ์ˆ˜ํ–‰ํ•ด์ค˜์•ผํ•ฉ๋‹ˆ๋‹ค." + + if isinstance(query_or_dataset, str): + doc_scores, doc_indices = self.get_relevant_doc(query_or_dataset, k=topk) + # print("[Search query]\n", query_or_dataset, "\n") + + # for i in range(topk): + # print(f"Top-{i+1} passage with score {doc_scores[0][i]:4f}") + # print(self.contexts[doc_indices[0][i]]) + + return (doc_scores, [self.contexts[doc_indices[0][i]] for i in range(topk)]) + + elif isinstance(query_or_dataset, Dataset): + # Retrieveํ•œ Passage๋ฅผ pd.DataFrame์œผ๋กœ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค. + with timer("query exhaustive search"): + doc_scores, doc_indices = self.get_relevant_doc_bulk(query_or_dataset["question"], k=topk) + + total = [] + for idx, example in enumerate(tqdm(query_or_dataset, desc="Sparse retrieval: ")): + tmp = { + "question": example["question"], + "id": example["id"], + "context": " ".join([self.contexts[pid] for pid in doc_indices[idx]]), + } + if "context" in example.keys() and "answers" in example.keys(): + tmp["original_context"] = example["context"] + tmp["answers"] = example["answers"] + total.append(tmp) + + return pd.DataFrame(total) + + + def get_relevant_doc(self, query: str, k: Optional[int] = 1) -> Tuple[List, List]: + """๊ฐœ๋ณ„ ์งˆ์˜์— ๋Œ€ํ•œ ์ƒ์œ„ k๊ฐœ์˜ Passage ๊ฒ€์ƒ‰""" + tokenized_query = [self.tokenizer(query)] + result = np.array([self.bm25.get_scores(query) for query in tokenized_query]) + doc_scores = [] + doc_indices = [] + + for scores in result: + sorted_result = np.argsort(scores)[-k:][::-1] + doc_scores.append(scores[sorted_result].tolist()) + doc_indices.append(sorted_result.tolist()) + + return doc_scores, doc_indices + + + def get_relevant_doc_bulk(self, queries: List, k: Optional[int] = 1) -> Tuple[List, List]: + """์—ฌ๋Ÿฌ ๊ฐœ์˜ Query๋ฅผ ๋ฐ›์•„ ์ƒ์œ„ k๊ฐœ์˜ Passage ๊ฒ€์ƒ‰""" + tokenized_queries = [self.tokenizer(query) for query in queries] + result = np.array([self.bm25.get_scores(query) for query in tokenized_queries]) + doc_scores = [] + doc_indices = [] + + for scores in result: + sorted_result = np.argsort(scores)[-k:][::-1] + doc_scores.append(scores[sorted_result].tolist()) + doc_indices.append(sorted_result.tolist()) + + return doc_scores, doc_indices + + +if __name__ == "__main__": + import argparse + import logging + + logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') + + parser = argparse.ArgumentParser(description="") + parser.add_argument("--dataset_name", default="../data/train_dataset", type=str) + parser.add_argument("--data_path", default="../data", type=str) + parser.add_argument("--context_path", default="wikipedia_documents.json", type=str) + + args = parser.parse_args() + logging.info(args.__dict__) + + org_dataset = load_from_disk(args.dataset_name) + full_ds = concatenate_datasets( + [ + org_dataset["train"].flatten_indices(), + org_dataset["validation"].flatten_indices(), + ] + ) + logging.info("*" * 40 + " query dataset " + 
"*" * 40) + logging.info(f"Full dataset: {full_ds}") + + tokenizer = AutoTokenizer.from_pretrained("HANTAEK/klue-roberta-large-korquad-v1-qa-finetuned") + + retriever = BM25SparseRetrieval( + tokenize_fn=tokenizer.tokenize, + args=args, + data_path=args.data_path, + context_path=args.context_path, + ) + retriever.get_sparse_embedding() + + # query = "๋Œ€ํ†ต๋ น์„ ํฌํ•จํ•œ ๋ฏธ๊ตญ์˜ ํ–‰์ •๋ถ€ ๊ฒฌ์ œ๊ถŒ์„ ๊ฐ–๋Š” ๊ตญ๊ฐ€ ๊ธฐ๊ด€์€?" + query = "์œ ๋ น์€ ์–ด๋А ํ–‰์„ฑ์—์„œ ์ง€๊ตฌ๋กœ ์™”๋Š”๊ฐ€?" + + # test single query + with timer("single query by exhaustive search using bm25"): + scores, indices = retriever.retrieve(query, 20) + for i, context in enumerate(indices): + print(f"Top-{i} ์˜ ๋ฌธ์„œ์ž…๋‹ˆ๋‹ค. ") + print("---------------------------------------------") + print(context) \ No newline at end of file diff --git a/src/retrieval_Dense.py b/src/retrieval_Dense.py new file mode 100644 index 0000000..ebe0d0d --- /dev/null +++ b/src/retrieval_Dense.py @@ -0,0 +1,229 @@ +import json +import os +import time +import logging +from contextlib import contextmanager +from typing import List, NoReturn, Optional, Tuple, Union + +import numpy as np +import pandas as pd +from datasets import Dataset, concatenate_datasets, load_from_disk +from tqdm.auto import tqdm + +import torch +import torch.nn.functional as F +from transformers import AutoTokenizer, AutoModel + +from utils import set_seed + +set_seed(42) +logger = logging.getLogger(__name__) + +@contextmanager +def timer(name): + t0 = time.time() + yield + logging.info(f"[{name}] done in {time.time() - t0:.3f} s") + +def mean_pooling(model_output, attention_mask): + token_embeddings = model_output[0] #First element of model_output contains all token embeddings + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) + +class DenseRetrieval: + def __init__( + self, + data_path: Optional[str] = "../data/", + context_path: Optional[str] = "wikipedia_documents.json", + corpus: Optional[pd.DataFrame] = None + ) -> NoReturn: + self.data_path = data_path + with open(os.path.join(data_path, context_path), "r", encoding="utf-8") as f: + wiki = json.load(f) + + self.contexts = list(dict.fromkeys([v["text"] for v in wiki.values()])) + logging.info(f"Lengths of contexts : {len(self.contexts)}") + self.ids = list(range(len(self.contexts))) + + self.tokenize_fn = AutoTokenizer.from_pretrained( + 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2' + ) + self.dense_embeder = AutoModel.from_pretrained( + 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2' + ) + self.dense_embeds = None + + def get_dense_embedding(self, question=None, contexts=None): + if contexts is not None: + self.contexts = contexts + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.dense_embeder.to(device) + encoded_input = self.tokenize_fn( + self.contexts, padding=True, truncation=True, return_tensors='pt' + ).to(device) + with torch.no_grad(): + model_output = self.dense_embeder(**encoded_input) + sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask']) + self.dense_embeds = sentence_embeddings.cpu() + + if question is None and contexts is None: + pickle_name = "dense_without_normalize_embedding.bin" + emd_path = os.path.join(self.data_path, pickle_name) + + if os.path.isfile(emd_path): + self.dense_embeds = torch.load(emd_path) + print("Dense embedding loaded.") + else: + 
print("Building passage dense embeddings in batches.") + self.dense_embeds = [] + batch_size = 64 + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.dense_embeder.to(device) + + for i in tqdm(range(0, len(self.contexts), batch_size), desc="Encoding passages"): + batch_contexts = self.contexts[i:i+batch_size] + encoded_input = self.tokenize_fn( + batch_contexts, padding=True, truncation=True, return_tensors='pt' + ).to(device) + with torch.no_grad(): + model_output = self.dense_embeder(**encoded_input) + sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask']) + self.dense_embeds.append(sentence_embeddings.cpu()) + del encoded_input, model_output, sentence_embeddings + torch.cuda.empty_cache() + + self.dense_embeds = torch.cat(self.dense_embeds, dim=0) + torch.save(self.dense_embeds, emd_path) + print("Dense embeddings saved.") + elif question is not None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.dense_embeder.to(device) + encoded_input = self.tokenize_fn( + question, padding=True, truncation=True, return_tensors='pt' + ).to(device) + with torch.no_grad(): + model_output = self.dense_embeder(**encoded_input) + sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask']) + return sentence_embeddings.cpu() + + def get_similarity_score(self, q_vec, c_vec): + if isinstance(q_vec, scipy.sparse.spmatrix): + q_vec = q_vec.toarray() + if isinstance(c_vec, scipy.sparse.spmatrix): + c_vec = c_vec.toarray() + + q_vec = torch.tensor(q_vec) + c_vec = torch.tensor(c_vec) + return q_vec.matmul(c_vec.T).numpy() + + def get_cosine_score(self, q_vec, c_vec): + q_vec = q_vec / q_vec.norm(dim=1, keepdim=True) + c_vec = c_vec / c_vec.norm(dim=1, keepdim=True) + return torch.mm(q_vec, c_vec.T).numpy() + + def retrieve( + self, query_or_dataset: Union[str, Dataset], topk: Optional[int] = 1 + ) -> Union[Tuple[List[float], List[str]], pd.DataFrame]: + assert self.dense_embeds is not None, "You should first execute `get_sparse_embedding()`" + + if isinstance(query_or_dataset, str): + doc_scores, doc_indices = self.get_relevant_doc(query_or_dataset, k=topk) + logging.info(f"[Search query] {query_or_dataset}") + + for i in range(topk): + logging.info(f"Top-{i+1} passage with score {doc_scores[i]:.6f}") + logging.info(self.contexts[doc_indices[i]]) + + return (doc_scores, [self.contexts[doc_indices[i]] for i in range(topk)]) + + elif isinstance(query_or_dataset, Dataset): + total = [] + with timer("query exhaustive search"): + doc_scores, doc_indices = self.get_relevant_doc_bulk( + query_or_dataset["question"], k=topk + ) + for idx, example in enumerate(tqdm(query_or_dataset, desc="[Sparse retrieval] ")): + retrieved_contexts = [self.contexts[pid] for pid in doc_indices[idx]] + tmp = { + "question": example["question"], + "id": example["id"], + "context": " ".join(retrieved_contexts), + } + if "context" in example.keys() and "answers" in example.keys(): + tmp["original_context"] = example["context"] + tmp["answers"] = example["answers"] + total.append(tmp) + + cqas = pd.DataFrame(total) + return cqas + + def get_relevant_doc(self, query: str, k: Optional[int] = 1) -> Tuple[List, List]: + with timer("transform"): + dense_qvec = self.get_dense_embedding(question=[query]) + + with timer("query ex search"): + result = self.get_cosine_score(dense_qvec, self.dense_embeds) + # result = self.get_similarity_score(dense_qvec, self.dense_embeds) + sorted_result = np.argsort(result.squeeze())[::-1] + doc_score = 
result.squeeze()[sorted_result].tolist()[:k] + doc_indices = sorted_result.tolist()[:k] + return doc_score, doc_indices + + def get_relevant_doc_bulk( + self, queries: List[str], k: Optional[int] = 1 + ) -> Tuple[List, List]: + dense_qvec = self.get_dense_embedding(question=queries) + + result = self.get_cosine_score(dense_qvec, self.dense_embeds) + # result = self.get_similarity_score(dense_qvec, self.dense_embeds) + doc_scores = [] + doc_indices = [] + for i in range(result.shape[0]): + sorted_result = np.argsort(result[i, :])[::-1] + doc_scores.append(result[i, :][sorted_result].tolist()[:k]) + doc_indices.append(sorted_result.tolist()[:k]) + return doc_scores, doc_indices + +if __name__ == "__main__": + import argparse + + logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') + + parser = argparse.ArgumentParser(description="") + parser.add_argument("--dataset_name", default="../data/train_dataset", type=str) + parser.add_argument("--data_path", default="../data", type=str) + parser.add_argument("--context_path", default="wikipedia_documents.json", type=str) + + args = parser.parse_args() + logging.info(args.__dict__) + + org_dataset = load_from_disk(args.dataset_name) + full_ds = concatenate_datasets( + [ + org_dataset["train"].flatten_indices(), + org_dataset["validation"].flatten_indices(), + ] + ) + logging.info("*" * 40 + " query dataset " + "*" * 40) + logging.info(f"Full dataset: {full_ds}") + + retriever = DenseRetrieval( + data_path=args.data_path, + context_path=args.context_path, + ) + + retriever.get_dense_embedding() + + # query = "๋Œ€ํ†ต๋ น์„ ํฌํ•จํ•œ ๋ฏธ๊ตญ์˜ ํ–‰์ •๋ถ€ ๊ฒฌ์ œ๊ถŒ์„ ๊ฐ–๋Š” ๊ตญ๊ฐ€ ๊ธฐ๊ด€์€?" + query = "์œ ๋ น์€ ์–ด๋А ํ–‰์„ฑ์—์„œ ์ง€๊ตฌ๋กœ ์™”๋Š”๊ฐ€?" 
+ + with timer("single query by exhaustive search"): + scores, contexts = retriever.retrieve(query, topk=5) + + with timer("bulk query by exhaustive search"): + df = retriever.retrieve(full_ds, topk=1) + if "original_context" in df.columns: + df["correct"] = df["original_context"] == df["context"] + logging.info(f'correct retrieval result by exhaustive search: {df["correct"].sum() / len(df)}') \ No newline at end of file diff --git a/src/retrieval_SPLADE.py b/src/retrieval_SPLADE.py new file mode 100644 index 0000000..a2fe9c7 --- /dev/null +++ b/src/retrieval_SPLADE.py @@ -0,0 +1,129 @@ +import json +import os +import time +import logging +from contextlib import contextmanager +from typing import List, NoReturn, Optional, Tuple, Union + +import numpy as np +import pandas as pd +from datasets import Dataset, concatenate_datasets, load_from_disk +from tqdm.auto import tqdm + +import torch +import torch.nn.functional as F +from sparsembed import model, retrieve +from transformers import AutoModelForMaskedLM, AutoTokenizer + +from utils import set_seed + +set_seed(42) +logger = logging.getLogger(__name__) +device = 'cuda' if torch.cuda.is_available() else 'cpu' + +@contextmanager +def timer(name): + t0 = time.time() + yield + logging.info(f"[{name}] done in {time.time() - t0:.3f} s") + + +class SpldRetrieval: + def __init__( + self, + data_path: Optional[str] = "../data/", + context_path: Optional[str] = "wikipedia_documents.json", + corpus: Optional[pd.DataFrame] = None + ) -> NoReturn: + self.data_path = data_path + with open(os.path.join(data_path, context_path), "r", encoding="utf-8") as f: + wiki = json.load(f) + + self.contexts = list(dict.fromkeys([v["text"] for v in wiki.values()])) + logging.info(f"Lengths of contexts : {len(self.contexts)}") + self.ids = list(range(len(self.contexts))) + self.df = None + + self.model = model.Splade( + model=AutoModelForMaskedLM.from_pretrained("naver/splade_v2_max").to(device), + tokenizer=AutoTokenizer.from_pretrained("naver/splade_v2_max"), + device=device + ) + self.retriever = retrieve.SpladeRetriever( + key="id", + on=["text"], + model=self.model + ) + + def get_sparse_embedding(self, df=None) -> NoReturn: + if df is None: + self.df = [{'id': i, 'text': c} for i, c in zip(self.ids, self.contexts)] + else: + self.df = df + + def retrieve( + self, query_or_dataset: Union[str, Dataset], topk: Optional[int] = 1 + ) -> Union[Tuple[List[float], List[str]], pd.DataFrame]: + assert self.df is not None, "You should first execute `get_sparse_embedding()`" + + # print(self.df.to_dict()) + self.retriever = self.retriever.add( + documents=self.df, + batch_size=10, + k_tokens=256, + ) + + result = self.retriever( + query_or_dataset, + k_tokens=20, # Maximum number of activated tokens. + k=100, # Number of documents to retrieve. 
+ batch_size=10 + ) + + doc_scores = [] + doc_indices = [] + for i in range(len(result)): + doc_scores.append([dic['similarity'] for k, dic in enumerate(result[i]) if k < topk]) + doc_indices.append([dic['id'] for k, dic in enumerate(result[i]) if k < topk]) + for i in range(topk): + logging.info(f"Top-{i+1} passage with score {doc_scores[0][i]:4f}") + logging.info(self.contexts[doc_indices[0][i]]) + return doc_scores, doc_indices + + +if __name__ == "__main__": + import argparse + + logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') + + parser = argparse.ArgumentParser(description="") + parser.add_argument("--dataset_name", default="../data/train_dataset", type=str) + parser.add_argument("--data_path", default="../data", type=str) + parser.add_argument("--context_path", default="wikipedia_documents.json", type=str) + + args = parser.parse_args() + logging.info(args.__dict__) + + org_dataset = load_from_disk(args.dataset_name) + full_ds = concatenate_datasets( + [ + org_dataset["train"].flatten_indices(), + org_dataset["validation"].flatten_indices(), + ] + ) + logging.info("*" * 40 + " query dataset " + "*" * 40) + logging.info(f"Full dataset: {full_ds}") + + retriever = SpldRetrieval( + data_path=args.data_path, + context_path=args.context_path, + ) + + retriever.get_sparse_embedding() + + query = "๋Œ€ํ†ต๋ น์„ ํฌํ•จํ•œ ๋ฏธ๊ตญ์˜ ํ–‰์ •๋ถ€ ๊ฒฌ์ œ๊ถŒ์„ ๊ฐ–๋Š” ๊ตญ๊ฐ€ ๊ธฐ๊ด€์€?" + + with timer("single query by exhaustive search"): + doc_scores, doc_indices = retriever.retrieve(query, topk=1) \ No newline at end of file diff --git a/src/retrieval_hybridsearch.py b/src/retrieval_hybridsearch.py new file mode 100644 index 0000000..02201ca --- /dev/null +++ b/src/retrieval_hybridsearch.py @@ -0,0 +1,334 @@ +import json +import os +import pickle +import time +import torch +import logging +import scipy +import scipy.sparse +from contextlib import contextmanager +from typing import List, Optional, Tuple, Union, NoReturn +from tqdm.auto import tqdm + +import argparse +import numpy as np +import pandas as pd +from datasets import Dataset, concatenate_datasets, load_from_disk +from torch.nn.functional import normalize + +from sklearn.feature_extraction.text import TfidfVectorizer +from rank_bm25 import BM25Okapi, BM25Plus +from transformers import AutoTokenizer, AutoModel +from utils import set_seed + +set_seed(42) +logger = logging.getLogger(__name__) + +@contextmanager +def timer(name): + t0 = time.time() + yield + logging.info(f"[{name}] done in {time.time() - t0:.3f} s") + +def mean_pooling(model_output, attention_mask): + token_embeddings = model_output[0] + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) + +class HybridSearch: + def __init__( + self, + tokenize_fn, + data_path: Optional[str] = "../data/", + context_path: Optional[str] = "wikipedia_documents.json", + corpus: Optional[pd.DataFrame] = None + ) -> NoReturn: + self.data_path = data_path + with open(os.path.join(data_path, context_path), "r", encoding="utf-8") as f: + wiki = json.load(f) + + self.contexts = list(dict.fromkeys([v["text"] for v in wiki.values()])) + logging.info(f"Lengths of contexts : {len(self.contexts)}") + self.ids = list(range(len(self.contexts))) + + self.tokenize_fn=tokenize_fn + + self.dense_model_name= 'intfloat/multilingual-e5-large-instruct' 
#'sentence-transformers/paraphrase-multilingual-mpnet-base-v2' #'BM-K/KoSimCSE-roberta' #'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
+        self.dense_tokenize_fn = AutoTokenizer.from_pretrained(
+            self.dense_model_name
+        )
+        # self.sparse_embeder = TfidfVectorizer(
+        #     tokenizer=self.tokenize_fn, ngram_range=(1, 2), max_features=50000,
+        # )
+        self.sparse_embeder = None
+        self.dense_embeder = AutoModel.from_pretrained(
+            self.dense_model_name
+        )
+        self.sparse_embeds = None
+        self.dense_embeds = None
+
+    def get_sparse_embedding(self, question=None):
+        vectorizer_path = os.path.join(self.data_path, "BM25Plus_sparse_vectorizer.bin")
+        embeddings_path = os.path.join(self.data_path, "BM25Plus_sparse_embedding.bin")
+        # vectorizer_path = os.path.join(self.data_path, "sparse_vectorizer.bin")
+        # embeddings_path = os.path.join(self.data_path, "sparse_embedding.bin")
+
+        if question is None:
+            # NOTE: in the BM25Plus path self.sparse_embeds stays None, so the
+            # embeddings file is never written and this load branch only fires
+            # for vectorizer-style embedders (e.g. the TF-IDF variant).
+            if os.path.isfile(vectorizer_path) and os.path.isfile(embeddings_path):
+                with open(vectorizer_path, "rb") as f:
+                    self.sparse_embeder = pickle.load(f)
+                with open(embeddings_path, "rb") as f:
+                    self.sparse_embeds = pickle.load(f)
+                print("Sparse vectorizer and embeddings loaded.")
+            else:
+                print("Fitting sparse vectorizer and building embeddings.")
+                self.sparse_embeder = BM25Plus([self.tokenize_fn(doc) for doc in self.contexts], k1=1.837782128608009, b=0.587622663072072, delta=1.1490)
+                # self.sparse_embeds = self.sparse_embeder.fit_transform(self.contexts)
+                with open(vectorizer_path, "wb") as f:
+                    pickle.dump(self.sparse_embeder, f)
+                if self.sparse_embeds is not None:
+                    with open(embeddings_path, "wb") as f:
+                        pickle.dump(self.sparse_embeds, f)
+                print("Sparse vectorizer and embeddings saved.")
+        else:
+            # Only used when self.sparse_embeder is a CountVectorizer /
+            # TfidfVectorizer-style object.
+            if not hasattr(self.sparse_embeder, 'vocabulary_'):
+                vectorizer_path = os.path.join(self.data_path, "sparse_vectorizer.bin")
+                if os.path.isfile(vectorizer_path):
+                    with open(vectorizer_path, "rb") as f:
+                        self.sparse_embeder = pickle.load(f)
+                    print("Sparse vectorizer loaded for transforming the query.")
+                else:
+                    raise ValueError("The Sparse vectorizer is not fitted. 
Please run get_sparse_embedding() first.") + return self.sparse_embeder.transform(question) + + + def get_dense_embedding(self, question=None): + if question is None: + model_n = self.dense_model_name.split('/')[1] + pickle_name = f"{model_n}_dense_embedding.bin" + emd_path = os.path.join(self.data_path, pickle_name) + + if os.path.isfile(emd_path): + self.dense_embeds = torch.load(emd_path) + print("Dense embedding loaded.") + else: + print("Building passage dense embeddings in batches.") + self.dense_embeds = [] + batch_size = 64 + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.dense_embeder.to(device) + + for i in tqdm(range(0, len(self.contexts), batch_size), desc="Encoding passages"): + batch_contexts = self.contexts[i:i+batch_size] + encoded_input = self.dense_tokenize_fn( + batch_contexts, padding=True, truncation=True, return_tensors='pt' + ).to(device) + with torch.no_grad(): + model_output = self.dense_embeder(**encoded_input) + sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask']) + sentence_embeddings = normalize(sentence_embeddings, p=2, dim=1) + self.dense_embeds.append(sentence_embeddings.cpu()) + del encoded_input, model_output, sentence_embeddings + torch.cuda.empty_cache() + + self.dense_embeds = torch.cat(self.dense_embeds, dim=0) + torch.save(self.dense_embeds, emd_path) + print("Dense embeddings saved.") + else: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.dense_embeder.to(device) + encoded_input = self.dense_tokenize_fn( + question, padding=True, truncation=True, return_tensors='pt' + ).to(device) + with torch.no_grad(): + model_output = self.dense_embeder(**encoded_input) + sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask']) + sentence_embeddings = normalize(sentence_embeddings, p=2, dim=1) + return sentence_embeddings.cpu() + + def hybrid_scale(self, dense_score, sparse_score, alpha: float): + if alpha < 0 or alpha > 1: + raise ValueError("Alpha must be between 0 and 1") + + if isinstance(dense_score, torch.Tensor): + dense_score = dense_score.detach().numpy() + if isinstance(sparse_score, torch.Tensor): + sparse_score = sparse_score.detach().numpy() + + def min_max_normalize(score): + return (score - np.min(score)) / (np.max(score) - np.min(score)) + def z_score_normalize(score): + return (score - np.mean(score)) / np.std(score) + + # dense_score_normalized = min_max_normalize(dense_score) + # sparse_score_normalized = min_max_normalize(sparse_score) + # dense_score_normalized = z_score_normalize(dense_score) + # sparse_score_normalized = z_score_normalize(sparse_score) + + # result = (1 - alpha) * dense_score_normalized + alpha * sparse_score_normalized + result = (1 - alpha) * dense_score + alpha * sparse_score + return result + + def get_similarity_score(self, q_vec, c_vec): + # if isinstance(q_vec, scipy.sparse.spmatrix): + # q_vec = q_vec.toarray() + # if isinstance(c_vec, scipy.sparse.spmatrix): + # c_vec = c_vec.toarray() + + # q_vec = torch.tensor(q_vec) + # c_vec = torch.tensor(c_vec) + # return q_vec.matmul(c_vec.T) + + if isinstance(q_vec, scipy.sparse.spmatrix): + q_vec = q_vec.toarray() + if isinstance(c_vec, scipy.sparse.spmatrix): + c_vec = c_vec.toarray() + + q_vec = torch.tensor(q_vec, dtype=torch.float32) + c_vec = torch.tensor(c_vec, dtype=torch.float32) + + if q_vec.ndim == 1: + q_vec = q_vec.unsqueeze(0) + if c_vec.ndim == 1: + c_vec = c_vec.unsqueeze(0) + + similarity_score = torch.matmul(q_vec, c_vec.T) + + return 
similarity_score
+
+    def get_cosine_score(self, q_vec, c_vec):
+        q_vec = q_vec / q_vec.norm(dim=1, keepdim=True)
+        c_vec = c_vec / c_vec.norm(dim=1, keepdim=True)
+        return torch.mm(q_vec, c_vec.T)
+
+    def retrieve(self, query_or_dataset, topk: Optional[int] = 1, alpha: Optional[float] = 0.7):
+        # assert self.sparse_embeds is not None, "You should first execute `get_sparse_embedding()`"
+        assert self.dense_embeds is not None, "You should first execute `get_dense_embedding()`"
+
+        if isinstance(query_or_dataset, str):
+            doc_scores, doc_indices = self.get_relevant_doc(query_or_dataset, alpha, k=topk)
+            logging.info(f"[Search query] {query_or_dataset}")
+
+            for i in range(topk):
+                logging.info(f"Top-{i+1} passage with score {doc_scores[i]:.4f}")
+                logging.info(self.contexts[doc_indices[i]])
+
+            # return (doc_scores, [self.contexts[doc_indices[i]] for i in range(topk)], doc_indices)  # temporary three-value return used by the Optuna script; remove later
+            return (doc_scores, [self.contexts[doc_indices[i]] for i in range(topk)])
+
+        elif isinstance(query_or_dataset, Dataset):
+            total = []
+            with timer("query exhaustive search"):
+                doc_scores, doc_indices = self.get_relevant_doc_bulk(
+                    query_or_dataset["question"], alpha, k=topk
+                )
+            for idx, example in enumerate(tqdm(query_or_dataset, desc="[Hybrid retrieval] ")):
+                tmp = {
+                    "question": example["question"],
+                    "id": example["id"],
+                    "context": " ".join([self.contexts[pid] for pid in doc_indices[idx]]),
+                }
+                if "context" in example.keys() and "answers" in example.keys():
+                    tmp["original_context"] = example["context"]
+                    tmp["answers"] = example["answers"]
+                total.append(tmp)
+
+            cqas = pd.DataFrame(total)
+            return cqas
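+    # NOTE: result = (1 - alpha) * dense + alpha * sparse combines raw,
+    # unnormalized scores (the min-max/z-score normalizations above are
+    # commented out). BM25 scores are typically an order of magnitude larger
+    # than dot-product similarities, which is presumably why the Optuna-tuned
+    # alpha used in __main__ below is so small (~0.006): it rebalances scales.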
+    def get_relevant_doc(self, query: str, alpha: float, k: Optional[int] = 1) -> Tuple[List, List]:
+        with timer("transform"):
+            # sparse_qvec = self.get_sparse_embedding([query])
+            dense_qvec = self.get_dense_embedding([query])
+            # assert sparse_qvec.nnz != 0, "Error: query contains no words in vocab."
+
+        with timer("query ex search"):
+            tokenized_query = [self.tokenize_fn(query)]
+            sparse_score = np.array([self.sparse_embeder.get_scores(query) for query in tokenized_query])
+            # sparse_score = self.get_similarity_score(sparse_qvec, self.sparse_embeds)
+            # dense_score = self.get_cosine_score(dense_qvec, self.dense_embeds)
+            dense_score = self.get_similarity_score(dense_qvec, self.dense_embeds)
+            result = self.hybrid_scale(dense_score.numpy(), sparse_score, alpha)
+            sorted_result = np.argsort(result.squeeze())[::-1]
+            doc_score = result.squeeze()[sorted_result].tolist()[:k]
+            doc_indices = sorted_result.tolist()[:k]
+        return doc_score, doc_indices
+
+    def get_relevant_doc_bulk(
+        self, queries: List[str], alpha: float, k: Optional[int] = 1
+    ) -> Tuple[List, List]:
+        # sparse_qvec = self.get_sparse_embedding(queries)
+        dense_qvec = self.get_dense_embedding(queries)
+        # assert sparse_qvec.nnz != 0, "Error: query contains no words in vocab."
+
+        tokenized_queries = [self.tokenize_fn(query) for query in queries]
+        sparse_score = np.array([self.sparse_embeder.get_scores(query) for query in tokenized_queries])
+        # sparse_score = self.get_similarity_score(sparse_qvec, self.sparse_embeds)
+        # dense_score = self.get_cosine_score(dense_qvec, self.dense_embeds)
+        dense_score = self.get_similarity_score(dense_qvec, self.dense_embeds)
+        result = self.hybrid_scale(dense_score.numpy(), sparse_score, alpha)
+        doc_scores = []
+        doc_indices = []
+        for i in range(result.shape[0]):
+            sorted_result = np.argsort(result[i, :])[::-1]
+            doc_scores.append(result[i, :][sorted_result].tolist()[:k])
+            doc_indices.append(sorted_result.tolist()[:k])
+        return doc_scores, doc_indices
+
+
+if __name__ == "__main__":
+    import argparse
+    from transformers import AutoTokenizer
+
+    logging.basicConfig(
+        level=logging.INFO,
+        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+        datefmt='%Y-%m-%d %H:%M:%S'
+    )
+
+    parser = argparse.ArgumentParser(description="")
+    parser.add_argument("--dataset_name", default="../data/train_dataset", type=str)
+    parser.add_argument("--model_name_or_path", default="bert-base-multilingual-cased", type=str)
+    parser.add_argument("--data_path", default="../data", type=str)
+    parser.add_argument("--context_path", default="wikipedia_documents.json", type=str)
+    parser.add_argument("--use_faiss", default=False, type=bool)
+
+    args = parser.parse_args()
+    logging.info(args.__dict__)
+
+    org_dataset = load_from_disk(args.dataset_name)
+    if 'train' in org_dataset and 'validation' in org_dataset:
+        full_ds = concatenate_datasets(
+            [
+                org_dataset["train"].flatten_indices(),
+                org_dataset["validation"].flatten_indices(),
+            ]
+        )
+    else:
+        full_ds = org_dataset
+    logging.info("*" * 40 + " query dataset " + "*" * 40)
+    logging.info(f"Full dataset: {full_ds}")
+
+    tokenizer = AutoTokenizer.from_pretrained("HANTAEK/klue-roberta-large-korquad-v1-qa-finetuned")
+
+    retriever = HybridSearch(
+        tokenize_fn=tokenizer.tokenize,
+        data_path=args.data_path,
+        context_path=args.context_path,
+    )
+    retriever.get_dense_embedding()
+    retriever.get_sparse_embedding()
+
+    # query = "๋Œ€ํ†ต๋ น์„ ํฌํ•จํ•œ ๋ฏธ๊ตญ์˜ ํ–‰์ •๋ถ€ ๊ฒฌ์ œ๊ถŒ์„ ๊ฐ–๋Š” ๊ตญ๊ฐ€ ๊ธฐ๊ด€์€?"
+    query = "์œ ๋ น์€ ์–ด๋А ํ–‰์„ฑ์—์„œ ์ง€๊ตฌ๋กœ ์™”๋Š”๊ฐ€?"
+
+    with timer("single query by exhaustive search using hybrid search"):
+        scores, contexts = retriever.retrieve(query, topk=20, alpha=0.0060115995634538455)
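+    # A hedged sketch of the same bulk retrieval sanity check the other
+    # retrievers run in their __main__ blocks (reuses full_ds from above):
+    #
+    #   df = retriever.retrieve(full_ds, topk=1, alpha=0.0060115995634538455)
+    #   if "original_context" in df.columns:
+    #       df["correct"] = df["original_context"] == df["context"]
+    #       logging.info(f'correct retrieval result: {df["correct"].mean():.4f}')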
") + print("---------------------------------------------") + print(context) + + diff --git a/src/retrieval_tfidf.py b/src/retrieval_tfidf.py new file mode 100644 index 0000000..ce95be1 --- /dev/null +++ b/src/retrieval_tfidf.py @@ -0,0 +1,283 @@ +import json +import os +import pickle +import time +import logging +from contextlib import contextmanager +from typing import List, NoReturn, Optional, Tuple, Union + +import faiss +import numpy as np +import pandas as pd +from datasets import Dataset, concatenate_datasets, load_from_disk +from sklearn.feature_extraction.text import TfidfVectorizer +from tqdm.auto import tqdm + +from utils import set_seed + +set_seed(42) +logger = logging.getLogger(__name__) + +@contextmanager +def timer(name): + t0 = time.time() + yield + logging.info(f"[{name}] done in {time.time() - t0:.3f} s") + + +class TFIDFRetrieval: + def __init__( + self, + tokenize_fn, + data_path: Optional[str] = "../data/", + context_path: Optional[str] = "wikipedia_documents.json", + corpus: Optional[pd.DataFrame] = None + ) -> NoReturn: + set_seed(42) + self.data_path = data_path + with open(os.path.join(data_path, context_path), "r", encoding="utf-8") as f: + wiki = json.load(f) + + self.contexts = list(dict.fromkeys([v["text"] for v in wiki.values()])) + logging.info(f"Lengths of contexts : {len(self.contexts)}") + self.ids = list(range(len(self.contexts))) + + self.tfidfv = TfidfVectorizer( + tokenizer=tokenize_fn, ngram_range=(1, 2), max_features=50000, + ) + + self.p_embedding = None + self.indexer = None + + def get_sparse_embedding(self) -> NoReturn: + emd_path = os.path.join(self.data_path, "sparse_embedding.bin") + tfidfv_path = os.path.join(self.data_path, "tfidv.bin") + + if os.path.isfile(emd_path) and os.path.isfile(tfidfv_path): + with open(emd_path, "rb") as file: + self.p_embedding = pickle.load(file) + with open(tfidfv_path, "rb") as file: + self.tfidfv = pickle.load(file) + logging.info(f"Embedding pickle shape: {self.p_embedding.shape}") + else: + logging.info("Embedding pickle not found. 
Building passage embedding...") + self.p_embedding = self.tfidfv.fit_transform(self.contexts) + logging.info(f"Embedding built with shape: {self.p_embedding.shape}") + with open(emd_path, "wb") as file: + pickle.dump(self.p_embedding, file) + with open(tfidfv_path, "wb") as file: + pickle.dump(self.tfidfv, file) + logging.info("Embedding pickle saved.") + + def build_faiss(self, num_clusters=64) -> NoReturn: + indexer_path = os.path.join(self.data_path, f"faiss_clusters{num_clusters}.index") + if os.path.isfile(indexer_path): + logging.info("Load Saved Faiss Indexer.") + self.indexer = faiss.read_index(indexer_path) + + else: + p_emb = self.p_embedding.astype(np.float32).toarray() + emb_dim = p_emb.shape[-1] + + num_clusters = num_clusters + quantizer = faiss.IndexFlatL2(emb_dim) + + self.indexer = faiss.IndexIVFScalarQuantizer( + quantizer, quantizer.d, num_clusters, faiss.METRIC_L2 + ) + self.indexer.train(p_emb) + self.indexer.add(p_emb) + faiss.write_index(self.indexer, indexer_path) + logging.info("Faiss Indexer Saved.") + + def retrieve( + self, query_or_dataset: Union[str, Dataset], topk: Optional[int] = 1 + ) -> Union[Tuple[List, List], pd.DataFrame]: + assert self.p_embedding is not None, "You should first execute `get_sparse_embedding()`" + + if isinstance(query_or_dataset, str): + doc_scores, doc_indices = self.get_relevant_doc(query_or_dataset, k=topk) + logging.info(f"[Search query] {query_or_dataset}") + + for i in range(topk): + logging.info(f"Top-{i+1} passage with score {doc_scores[i]:4f}") + logging.info(self.contexts[doc_indices[i]]) + + return (doc_scores, [self.contexts[doc_indices[i]] for i in range(topk)]) + + elif isinstance(query_or_dataset, Dataset): + total = [] + with timer("query exhaustive search"): + doc_scores, doc_indices = self.get_relevant_doc_bulk( + query_or_dataset["question"], k=topk + ) + for idx, example in enumerate(tqdm(query_or_dataset, desc="[Sparse retrieval] ")): + tmp = { + "question": example["question"], + "id": example["id"], + "context": " ".join([self.contexts[pid] for pid in doc_indices[idx]]), + } + if "context" in example.keys() and "answers" in example.keys(): + tmp["original_context"] = example["context"] + tmp["answers"] = example["answers"] + total.append(tmp) + + cqas = pd.DataFrame(total) + return cqas + + def get_relevant_doc(self, query: str, k: Optional[int] = 1) -> Tuple[List, List]: + with timer("transform"): + query_vec = self.tfidfv.transform([query]) + assert np.sum(query_vec) != 0, "Error: query contains no words in vocab." + + with timer("query ex search"): + result = query_vec * self.p_embedding.T + if not isinstance(result, np.ndarray): + result = result.toarray() + + sorted_result = np.argsort(result.squeeze())[::-1] + doc_score = result.squeeze()[sorted_result].tolist()[:k] + doc_indices = sorted_result.tolist()[:k] + return doc_score, doc_indices + + def get_relevant_doc_bulk( + self, queries: List, k: Optional[int] = 1 + ) -> Tuple[List, List]: + query_vec = self.tfidfv.transform(queries) + assert np.sum(query_vec) != 0, "Error: query contains no words in vocab." 
+ + result = query_vec * self.p_embedding.T + if not isinstance(result, np.ndarray): + result = result.toarray() + doc_scores = [] + doc_indices = [] + for i in range(result.shape[0]): + sorted_result = np.argsort(result[i, :])[::-1] + doc_scores.append(result[i, :][sorted_result].tolist()[:k]) + doc_indices.append(sorted_result.tolist()[:k]) + return doc_scores, doc_indices + + def retrieve_faiss( + self, query_or_dataset: Union[str, Dataset], topk: Optional[int] = 1 + ) -> Union[Tuple[List, List], pd.DataFrame]: + assert self.indexer is not None, "You should first execute `build_faiss().`" + + if isinstance(query_or_dataset, str): + doc_scores, doc_indices = self.get_relevant_doc_faiss( + query_or_dataset, k=topk + ) + logging.info(f"[Search query] {query_or_dataset}") + + for i in range(topk): + logging.info(f"Top-{i+1} passage with score {doc_scores[i]:4f}") + logging.info(self.contexts[doc_indices[i]]) + + return (doc_scores, [self.contexts[doc_indices[i]] for i in range(topk)]) + + elif isinstance(query_or_dataset, Dataset): + queries = query_or_dataset["question"] + total = [] + + with timer("query faiss search"): + doc_scores, doc_indices = self.get_relevant_doc_bulk_faiss( + queries, k=topk + ) + for idx, example in enumerate( + tqdm(query_or_dataset, desc="Sparse retrieval: ") + ): + tmp = { + "question": example["question"], + "id": example["id"], + "context": " ".join( + [self.contexts[pid] for pid in doc_indices[idx]] + ), + } + if "context" in example.keys() and "answers" in example.keys(): + tmp["original_context"] = example["context"] + tmp["answers"] = example["answers"] + total.append(tmp) + + return pd.DataFrame(total) + + def get_relevant_doc_faiss( + self, query: str, k: Optional[int] = 1 + ) -> Tuple[List, List]: + query_vec = self.tfidfv.transform([query]) + assert np.sum(query_vec) != 0, "Error: query contains no words in vocab." + + q_emb = query_vec.toarray().astype(np.float32) + with timer("query faiss search"): + D, I = self.indexer.search(q_emb, k) + + return D.tolist()[0], I.tolist()[0] + + def get_relevant_doc_bulk_faiss( + self, queries: List, k: Optional[int] = 1 + ) -> Tuple[List, List]: + query_vecs = self.tfidfv.transform(queries) + assert np.sum(query_vecs) != 0, "Error: query contains no words in vocab." 
+        q_embs = query_vecs.toarray().astype(np.float32)
+        D, I = self.indexer.search(q_embs, k)
+
+        return D.tolist(), I.tolist()
+
+
+if __name__ == "__main__":
+    import argparse
+    from transformers import AutoTokenizer
+
+    logging.basicConfig(level=logging.INFO,
+                        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+                        datefmt='%Y-%m-%d %H:%M:%S')
+
+    parser = argparse.ArgumentParser(description="")
+    parser.add_argument("--dataset_name", default="../data/train_dataset", type=str)
+    parser.add_argument("--model_name_or_path", default="bert-base-multilingual-cased", type=str)
+    parser.add_argument("--data_path", default="../data", type=str)
+    parser.add_argument("--context_path", default="wikipedia_documents.json", type=str)
+    parser.add_argument("--use_faiss", default=False, type=bool)
+
+    args = parser.parse_args()
+    logging.info(args.__dict__)
+
+    org_dataset = load_from_disk(args.dataset_name)
+    full_ds = concatenate_datasets(
+        [
+            org_dataset["train"].flatten_indices(),
+            org_dataset["validation"].flatten_indices(),
+        ]
+    )
+    logging.info("*" * 40 + " query dataset " + "*" * 40)
+    logging.info(f"Full dataset: {full_ds}")
+
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False,)
+
+    retriever = TFIDFRetrieval(
+        tokenize_fn=tokenizer.tokenize,
+        data_path=args.data_path,
+        context_path=args.context_path,
+    )
+
+    # query = "๋Œ€ํ†ต๋ น์„ ํฌํ•จํ•œ ๋ฏธ๊ตญ์˜ ํ–‰์ •๋ถ€ ๊ฒฌ์ œ๊ถŒ์„ ๊ฐ–๋Š” ๊ตญ๊ฐ€ ๊ธฐ๊ด€์€?"
+    query = "์œ ๋ น์€ ์–ด๋А ํ–‰์„ฑ์—์„œ ์ง€๊ตฌ๋กœ ์™”๋Š”๊ฐ€?"
+
+    if args.use_faiss:
+        # The FAISS helpers assert on self.indexer / self.tfidfv, so build the
+        # TF-IDF matrix and the index before querying.
+        retriever.get_sparse_embedding()
+        retriever.build_faiss()
+
+        with timer("single query by faiss"):
+            scores, indices = retriever.retrieve_faiss(query)
+
+        with timer("bulk query by exhaustive search"):
+            df = retriever.retrieve_faiss(full_ds)
+            df["correct"] = df["original_context"] == df["context"]
+
+            logging.info(f'correct retrieval result by faiss {df["correct"].sum() / len(df)}')
+
+    else:
+        retriever.get_sparse_embedding()
+        with timer("bulk query by exhaustive search"):
+            df = retriever.retrieve(full_ds)
+            df["correct"] = df["original_context"] == df["context"]
+            logging.info(f'correct retrieval result by exhaustive search: {df["correct"].sum() / len(df)}')
+
+        with timer("single query by exhaustive search"):
+            scores, indices = retriever.retrieve(query)
\ No newline at end of file
diff --git a/src/utils.py b/src/utils.py
new file mode 100644
index 0000000..d6a586d
--- /dev/null
+++ b/src/utils.py
@@ -0,0 +1,252 @@
+import collections
+import json
+import logging
+import os
+import random
+from typing import Any, Optional, Tuple
+
+import numpy as np
+import torch
+from arguments import DataTrainingArguments
+from datasets import DatasetDict
+from tqdm.auto import tqdm
+from transformers import PreTrainedTokenizerFast, TrainingArguments
+from transformers.trainer_utils import get_last_checkpoint
+
+logger = logging.getLogger(__name__)
+
+def set_seed(seed: int = 42):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    torch.use_deterministic_algorithms(True)
+
+
+def postprocess_qa_predictions(
+    examples,
+    features,
+    predictions: Tuple[np.ndarray, np.ndarray],
+    version_2_with_negative: bool = False,
+    n_best_size: int = 20,
+    max_answer_length: int = 30,
+    null_score_diff_threshold: float = 0.0,
+    output_dir: Optional[str] = None,
+    prefix: Optional[str] = None,
+    is_world_process_zero: bool = True,
+):
+    assert (len(predictions) == 2), "`predictions` 
should be a tuple with two elements (start_logits, end_logits)." + all_start_logits, all_end_logits = predictions + + assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features." + + example_id_to_index = {k: i for i, k in enumerate(examples["id"])} + features_per_example = collections.defaultdict(list) + for i, feature in enumerate(features): + features_per_example[example_id_to_index[feature["example_id"]]].append(i) + + all_predictions = collections.OrderedDict() + all_nbest_json = collections.OrderedDict() + if version_2_with_negative: + scores_diff_json = collections.OrderedDict() + + logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN) + logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") + + for example_index, example in enumerate(tqdm(examples)): + feature_indices = features_per_example[example_index] + + min_null_prediction = None + prelim_predictions = [] + + for feature_index in feature_indices: + start_logits = all_start_logits[feature_index] + end_logits = all_end_logits[feature_index] + offset_mapping = features[feature_index]["offset_mapping"] + # `token_is_max_context`, # Optional: delete the answer that is not available in the current function + token_is_max_context = features[feature_index].get("token_is_max_context", None) + + feature_null_score = start_logits[0] + end_logits[0] + if ( + min_null_prediction is None + or min_null_prediction["score"] > feature_null_score + ): + min_null_prediction = { + "offsets": (0, 0), + "score": feature_null_score, + "start_logit": start_logits[0], + "end_logit": end_logits[0], + } + + start_indexes = np.argsort(start_logits)[-1:-n_best_size - 1:-1].tolist() + end_indexes = np.argsort(end_logits)[-1:-n_best_size - 1:-1].tolist() + + for start_index in start_indexes: + for end_index in end_indexes: + if ( + start_index >= len(offset_mapping) + or end_index >= len(offset_mapping) + or offset_mapping[start_index] is None + or offset_mapping[end_index] is None + ): + continue + if ( + end_index < start_index + or end_index - start_index + 1 > max_answer_length + ): + continue + if ( + token_is_max_context is not None + and not token_is_max_context.get(str(start_index), False) + ): + continue + prelim_predictions.append( + { + "offsets": ( + offset_mapping[start_index][0], + offset_mapping[end_index][1], + ), + "score": start_logits[start_index] + end_logits[end_index], + "start_logit": start_logits[start_index], + "end_logit": end_logits[end_index], + } + ) + + if version_2_with_negative: + prelim_predictions.append(min_null_prediction) + null_score = min_null_prediction["score"] + + predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size] + + if version_2_with_negative and not any( + p["offsets"] == (0, 0) for p in predictions + ): + predictions.append(min_null_prediction) + + context = example["context"] + for pred in predictions: + offsets = pred.pop("offsets") + pred["text"] = context[offsets[0] : offsets[1]] + + if len(predictions) == 0 or ( + len(predictions) == 1 and predictions[0]["text"] == "" + ): + + predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0}) + + scores = np.array([pred.pop("score") for pred in predictions]) + exp_scores = np.exp(scores - np.max(scores)) + probs = exp_scores / exp_scores.sum() + + for prob, pred in zip(probs, predictions): + pred["probability"] = prob + + if not version_2_with_negative: + 
all_predictions[example["id"]] = predictions[0]["text"] + else: + i = 0 + while predictions[i]["text"] == "": + i += 1 + best_non_null_pred = predictions[i] + + score_diff = ( + null_score + - best_non_null_pred["start_logit"] + - best_non_null_pred["end_logit"] + ) + scores_diff_json[example["id"]] = float(score_diff) + if score_diff > null_score_diff_threshold: + all_predictions[example["id"]] = "" + else: + all_predictions[example["id"]] = best_non_null_pred["text"] + + all_nbest_json[example["id"]] = [ + { + k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items() + } + for pred in predictions + ] + + if output_dir is not None: + assert os.path.isdir(output_dir), f"{output_dir} is not a directory." + + prediction_file = os.path.join( + output_dir, + "predictions.json" if prefix is None else f"predictions_{prefix}".json, + ) + nbest_file = os.path.join( + output_dir, + "nbest_predictions.json" + if prefix is None + else f"nbest_predictions_{prefix}".json, + ) + if version_2_with_negative: + null_odds_file = os.path.join( + output_dir, + "null_odds.json" if prefix is None else f"null_odds_{prefix}".json, + ) + + logger.info(f"Saving predictions to {prediction_file}.") + with open(prediction_file, "w", encoding="utf-8") as writer: + writer.write( + json.dumps(all_predictions, indent=4, ensure_ascii=False) + "\n" + ) + logger.info(f"Saving nbest_preds to {nbest_file}.") + with open(nbest_file, "w", encoding="utf-8") as writer: + writer.write( + json.dumps(all_nbest_json, indent=4, ensure_ascii=False) + "\n" + ) + if version_2_with_negative: + logger.info(f"Saving null_odds to {null_odds_file}.") + with open(null_odds_file, "w", encoding="utf-8") as writer: + writer.write( + json.dumps(scores_diff_json, indent=4, ensure_ascii=False) + "\n" + ) + + return all_predictions + + +def check_no_error( + data_args: DataTrainingArguments, + training_args: TrainingArguments, + datasets: DatasetDict, + tokenizer, +) -> Tuple[Any, int]: + last_checkpoint = None + if ( + os.path.isdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + if not isinstance(tokenizer, PreTrainedTokenizerFast): + raise ValueError( + "This example script only works for models that have a fast tokenizer. Checkout the big table of models " + "at https://huggingface.co/transformers/index.html#bigtable to find the model types that meet this " + "requirement" + ) + + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + + if "validation" not in datasets: + raise ValueError("--do_eval requires a validation dataset") + return last_checkpoint, max_seq_length