Changes from all commits
80 commits
d4ca644
Setting up GitHub Classroom Feedback
github-classroom[bot] Sep 30, 2024
0598084
add: README
ssunbear Sep 30, 2024
b1a210e
add: data folder (structure only)
jin-jae Oct 1, 2024
9b19845
add: baseline init
jin-jae Oct 6, 2024
e03f52d
add: baseline source code
jin-jae Oct 6, 2024
458a33d
Merge pull request #3 from boostcampaitech7/feature/baseline_code
jin-jae Oct 6, 2024
5d083d7
change default=klue/roberta-large
ssunbear Oct 6, 2024
f789d47
Merge pull request #8 from boostcampaitech7/feature/baseline_code
jin-jae Oct 6, 2024
047ef44
feat: wandb integration
jin-jae Oct 7, 2024
3432fc9
Merge pull request #9 from boostcampaitech7/feature/baseline_code
jin-jae Oct 7, 2024
cc631cb
Add retrieval BM25
ssunbear Oct 9, 2024
2b78008
Add retrieval BM25
ssunbear Oct 9, 2024
eb5f411
Merge pull request #13 from boostcampaitech7/feature/bm25
jin-jae Oct 9, 2024
c8724ea
fix: error with bm25
ssunbear Oct 10, 2024
3aa6993
Merge pull request #18 from boostcampaitech7/feature/bm25
jin-jae Oct 10, 2024
1d6aa5b
chore: requirements.txt, .gitignore
doraemon500 Oct 10, 2024
e10ee40
modify: for window version
simigami Oct 15, 2024
e037f64
Add retrieval_BM25Plus
LHANTAEK Oct 16, 2024
5d8a2de
hantaek_pre_baseline_code
LHANTAEK Oct 17, 2024
23cab5b
Update main.sh
LHANTAEK Oct 17, 2024
5a229e3
Update requirements.txt
LHANTAEK Oct 17, 2024
15cdd77
Update main.py
LHANTAEK Oct 17, 2024
1b8b944
Delete src/requirements.txt
LHANTAEK Oct 17, 2024
c7a5a10
Delete src/retrieval_BM25Plus.py
LHANTAEK Oct 17, 2024
c7f1e0d
Delete src/main.sh
LHANTAEK Oct 17, 2024
35f1672
Add files via upload
LHANTAEK Oct 17, 2024
9c8fa75
Update main.py
LHANTAEK Oct 17, 2024
7d66050
Add CNN_layer_mode.py and Modify main.py
ssunbear Oct 17, 2024
6c6bcd0
Add korquad_finetuning
ssunbear Oct 17, 2024
3aef7f7
add:: retrieval_hybridsearch.py
doraemon500 Oct 18, 2024
2be70e0
add:: retrieval 2s_rerank, Dense, Splade & del:: BM25Plus & rename:: …
doraemon500 Oct 20, 2024
e6994f8
chore:: retriever tasks
doraemon500 Oct 21, 2024
d04a505
add: basline with seed fixed
jin-jae Oct 22, 2024
6f6a4bf
add:: optuna for retriever
doraemon500 Oct 22, 2024
3f95b73
add: wandb options, run python script
jin-jae Oct 22, 2024
9e2d92c
Merge pull request #26 from boostcampaitech7/feature/baseline_seed_fix
jin-jae Oct 22, 2024
2b44929
Merge pull request #26 from boostcampaitech7/feature/baseline_seed_fix
jin-jae Oct 22, 2024
c4f3e90
Merge branch 'develop' of https://github.com/boostcampaitech7/level2-…
simigami Oct 23, 2024
3439495
Merge branch 'develop' of https://github.com/boostcampaitech7/level2-…
simigami Oct 23, 2024
9ddb885
modify: for window version
simigami Oct 23, 2024
43a913b
modify: for window version
simigami Oct 23, 2024
46e8aa7
modify: gitignore
simigami Oct 23, 2024
6b751e0
Merge branch 'develop' into feature/HybridSearch
simigami Oct 23, 2024
e27d82f
Merge branch 'develop' into feature/HybridSearch
simigami Oct 23, 2024
d65871c
Add ensemble file
ssunbear Oct 23, 2024
e9edf80
Add ensemble file
ssunbear Oct 23, 2024
f531f0f
modify: window feat
simigami Oct 23, 2024
f277b13
modify: window feat
simigami Oct 23, 2024
d540a3e
modify: working on mecab tokenizer
simigami Oct 23, 2024
56ed50f
modify: working on mecab tokenizer
simigami Oct 23, 2024
26a1800
fix: requirements error
jin-jae Oct 27, 2024
5befc83
fix: requirements error
jin-jae Oct 27, 2024
a59bad7
Merge pull request #28 from boostcampaitech7/feature/jungmin
jin-jae Oct 27, 2024
c9905f7
Merge pull request #28 from boostcampaitech7/feature/jungmin
jin-jae Oct 27, 2024
7d3bec6
Merge pull request #27 from boostcampaitech7/feature/ensemble
jin-jae Oct 27, 2024
9d84400
Merge pull request #27 from boostcampaitech7/feature/ensemble
jin-jae Oct 27, 2024
b4854b2
add: llm answer post-preprocessing
jin-jae Oct 27, 2024
fda22a9
add: llm answer post-preprocessing
jin-jae Oct 27, 2024
995565c
Merge pull request #29 from boostcampaitech7/feature/answer_preprocess
jin-jae Oct 27, 2024
ea86708
Merge pull request #29 from boostcampaitech7/feature/answer_preprocess
jin-jae Oct 27, 2024
d27140b
add: README contents
jin-jae Oct 27, 2024
6f4ac9a
add: README contents
jin-jae Oct 27, 2024
e6a20d7
fix: merge conflict
jin-jae Oct 27, 2024
9709ffb
Merge branch 'develop' into feature/korquad_finetuning
jin-jae Oct 27, 2024
963621f
Merge pull request #30 from boostcampaitech7/feature/korquad_finetuning
jin-jae Oct 27, 2024
fe99fb5
Update README.md
LHANTAEK Oct 28, 2024
2aa653b
Add:: role assignments (역할 분담)
LHANTAEK Oct 28, 2024
571b102
Update README.md
simigami Oct 28, 2024
6759383
Update README.md
ssunbear Oct 28, 2024
849885c
Update README.md
simigami Oct 28, 2024
5c59003
Add files via upload
simigami Oct 28, 2024
bbdce0f
Update README.md
simigami Oct 28, 2024
0d35434
fix: source code structure
jin-jae Oct 28, 2024
16ed97c
Merge pull request #31 from boostcampaitech7/develop
jin-jae Oct 28, 2024
92ab3f1
fix: leaderboard rank
jin-jae Oct 29, 2024
8519258
Merge pull request #32 from boostcampaitech7/develop
jin-jae Oct 29, 2024
20d37f3
Update README.md
doraemon500 Oct 31, 2024
85b789c
Update README.md
ssunbear Oct 31, 2024
7f618b6
Update README.md
ssunbear Nov 7, 2024
9dc9e9a
Add wrap-up report
ssunbear Nov 7, 2024
19 changes: 19 additions & 0 deletions .gitignore
@@ -0,0 +1,19 @@
# python
__pycache__
.idea

# dataset
data
code
EDA

# outputs
models
output

# src
src/test
src/wandb

# wandb
wandb
163 changes: 163 additions & 0 deletions README.md
@@ -0,0 +1,163 @@
<div align='center'>

# 🏆 LV.2 NLP Project: Open-Domain Question Answering

</div>

## ✏️ Competition Overview

| Item | Description |
|:------:| --- |
| Topic | The Level 2 domain-fundamentals competition of the Naver BoostCamp AI Tech 7th NLP Track: 'Open-Domain Question Answering (Machine Reading Comprehension)'. |
| Task | Given a question, find the exact answer to it within the provided documents. |
| Data | A document collection drawn mostly from Wikipedia, together with question-answer pairs. |
| Metric | EM (Exact Match) is used to check whether the answer is extracted exactly. |


## 🎖️ Leader Board
The team finished 2nd on both the public and the private leaderboard.
### 🥈 Public Leader Board (2nd place)
![leaderboard_mid](./docs/leaderboard_mid.png)

### 🥈 Private Leader Board (2nd place)
![leaderboard_final](./docs/leaderboard_final.png)

## 👨‍💻 Team 15 (15조가십오조) Members
<div align='center'>

| 김진재 [<img src="./docs/github_official_logo.png" width=20 style="vertical-align:middle;" />](https://github.com/jin-jae) | 박규태 [<img src="./docs/github_official_logo.png" width=20 style="vertical-align:middle;" />](https://github.com/doraemon500) | 윤선웅 [<img src="./docs/github_official_logo.png" width=20 style="vertical-align:middle;" />](https://github.com/ssunbear) | 이정민 [<img src="./docs/github_official_logo.png" width=20 style="vertical-align:middle;" />](https://github.com/simigami) | 임한택 [<img src="./docs/github_official_logo.png" width=20 style="vertical-align:middle;" />](https://github.com/LHANTAEK)
|:-:|:-:|:-:|:-:|:-:|
| ![김진재](https://avatars.githubusercontent.com/u/97018331) | ![박규태](https://avatars.githubusercontent.com/u/64678476) | ![윤선웅](https://avatars.githubusercontent.com/u/117508164) | ![이정민](https://avatars.githubusercontent.com/u/46891822) | ![임한택](https://avatars.githubusercontent.com/u/143519383) |

</div>


## 👼 Role Assignments
<div align='center'>

| Member | Role |
|------| --- |
| 김진재 | (Team lead) Wrote and improved the baseline code, project management and environment setup, developed the particle-preprocessing algorithm, proposed new approaches, ensembling |
| 박규태 | Data characteristics analysis, EDA, retrieval implementation, comparative experiments and improvements (hybrid search, re-ranking, Dense, SPLADE, etc.), reader model fine-tuning |
| 윤선웅 | KorQuAD 1.0 data augmentation, model fine-tuning, reader model improvement (added CNN layers), retrieval model implementation (BM25), ensembling |
| 이정민 | Data augmentation (AEDA, truncation, etc.), question dataset tuning, KorQuAD dataset tuning |
| 임한택 | EDA, retrieval model improvements (BM25Plus, re-ranking hyperparameter optimization), reader model improvements (PLM selection and Trainer parameter optimization) |

</div>


## 🏃 Project
### 🖥️ Project Overview
| Item | Description |
|:------:| --- |
| Topic | 'Open-Domain Question Answering', an MRC (Machine Reading Comprehension) task: retrieve documents relevant to a given question, then find or generate an appropriate answer within those documents |
| Structure | Two-stage pipeline consisting of a Retrieval stage and a Reader stage |
| Metric | EM Score (Exact Match Score): a prediction earns credit only when the predicted text matches the ground-truth text character for character (see the sketch below) |
| Environment | `GPU`: 4 Tesla V100 servers, `IDE`: VSCode, Jupyter Notebook |
| Collaboration | Notion (progress tracking), GitHub (code and data sharing), Slack (real-time communication), W&B (visualization, hyperparameter tuning) |
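
For reference, a minimal sketch of the EM metric described above (the whitespace trimming is illustrative; the competition's official scorer may normalize differently):

```python
def exact_match(prediction: str, ground_truth: str) -> int:
    """Score 1 only if the two strings are identical after trimming whitespace."""
    return int(prediction.strip() == ground_truth.strip())

# Dataset-level EM is the percentage of exact matches.
preds, golds = ["한강", "세종대왕"], ["한강", "세종"]
em = 100.0 * sum(exact_match(p, g) for p, g in zip(preds, golds)) / len(golds)
print(em)  # 50.0 — "세종대왕" vs "세종" gets no partial credit under EM
```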

### 📅 Project Timeline
- The project ran from 2024-09-30 to 2024-10-25.
![타임라인](./docs/타임라인.png)

### 🕵️ Project Progress
- The techniques we experimented with and applied at each stage of the project are summarized below.


| Process | Description |
|:-----------------:| --- |
| Data processing | AEDA, sentence swapping, truncation, question emphasis using Mecab, LLM-based particle removal |
| Model fine-tuning | Added the KorQuAD dataset; fine-tuned a KorQuAD 1.0 PLM on the KorQuAD 2.0 dataset |
| Retriever improvements | BM25Plus, DPR, Hybrid Search, Re-rank (2-stage) |
| Reader improvements | Added CNN layers, head customization, dropout, learning-rate tuning |
| Ensembling | Soft voting: average each answer's probability across the per-run nbest_predictions.json files, then pick the highest-scoring answer (see the sketch after this table) |
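
Below is a minimal sketch of this soft-voting scheme, assuming each run produces an nbest_predictions.json mapping question IDs to candidate lists with "text" and "probability" fields (the layout used by the Hugging Face QA examples); the file paths and weights are illustrative:

```python
import json
from collections import defaultdict

def soft_vote(nbest_paths, weights=None):
    """Accumulate (optionally weighted) per-answer probabilities across runs and
    pick the highest-scoring answer per question. Dividing by the total weight
    would not change the argmax, so the sum stands in for the average."""
    weights = weights or [1.0] * len(nbest_paths)
    scores = defaultdict(lambda: defaultdict(float))
    for path, w in zip(nbest_paths, weights):
        with open(path, encoding="utf-8") as f:
            nbest = json.load(f)
        for qid, candidates in nbest.items():
            for cand in candidates:
                scores[qid][cand["text"]] += w * cand["probability"]
    return {qid: max(texts, key=texts.get) for qid, texts in scores.items()}

# e.g. a 1:2:3 weighting like those in the results table below:
# best = soft_vote(["run_a/nbest_predictions.json", "run_b/nbest_predictions.json",
#                   "run_c/nbest_predictions.json"], weights=[1, 2, 3])
```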


### 🤖 Ensemble
| No. | Model + Technique | EM (Public) |
|------|--------------------------------------------------------|------------|
| 1 | uomnf97+BM25+CNN | 66.67 |
| 7 | Curtis+CNN+dropout(only_FC_0.05)+BM25Plus | 66.25 |
| 8 | Curtis+Truncation | 66.25 |
| 9 | HANTAEK_hybrid_optuna_topk20(k1=1.84) | 63.15 |
| 10 | HANTAEK_hybrid_optuna_topk20(k1=0.73) | 63.75 |
| 11 | HANTAEK_hybrid_optuna_topk10(k1=0.73) | 63.75 |
| 12 | uomnf97+BM25 | 67.08 |
| 13 | uomnf97+CNN+Re_rank500_20+Cosine | 67.08 |
| 14 | curtis+CNN+Re_rank_500_20 | 65.42 |
| 15 | nlp04_finetuned+CNN+BM25Plus+epoch1_predictions | 67.5 |

### 📃 Results
| Final Submission | Ensemble | EM (Public) | EM (Private) |
|----------|----------------------------------------------------|------------|--------------|
| O | Models 7,8,9,10,11,12,13, 1:1:1:1:1:2:3 ensemble + particle-removal LLM | **77.08** | 71.11 |
| O | Models 14,8,15,10,11,12,13, 1:1:1:1:2:3 ensemble + particle-removal LLM | **77.08** | 71.67 |
| | Models 1,7,8,9,10,11,12, uniform-weight ensemble + particle-removal LLM | 76.67 | 71.67 |
| | Models 7,8,9,10,11,12,13, 1:1:1:2:2:2 ensemble + particle-removal LLM | 76.67 | 70.83 |
| 1st SOTA | Models 15,9,10,11,12,13, 1:1:1:1:3:3 ensemble + particle-removal LLM | 75.42 | **74.17** |
| | Models 7,8,9,10,11,12,13, 1:1:1:1:2:3 ensemble + particle-removal LLM (n=5) | 75.42 | 71.67 |
| | Models 7,8,9,10,11,12,13,14, 1:1:1:1:2:2 ensemble + particle-removal LLM | 74.58 | 71.11 |
| 2nd SOTA | Models 14,8,9,10,11,12,13, 1:1:1:1:2:3 ensemble + particle-removal LLM | 74.58 | **72.22** |



## 📁 Project Structure
The project directory layout is as follows.
```
level2-mrc-nlp-15
├── data
│ ├── test_dataset
│ ├── train_dataset
│ └── wikipedia_documents.json
├── docs
│ ├── github_official_logo.png
│ ├── leaderboard_final.png
│ └── leaderboard_mid.png
├── models
├── output
├── README.md
├── requirements.txt
├── run.py
└── src
├── arguments.py
├── CNN_layer_model.py
├── data_analysis.py
├── ensemble
│ ├── probs_voting_ensemble_n.py
│ ├── probs_voting_ensemble.py
│ └── scores_voting_ensemble.py
├── korquad_finetuning_v2.ipynb
├── main.py
├── optimize_retriever.py
├── preprocess_answer.ipynb
├── qa_trainer.py
├── retrieval_2s_rerank.py
├── retrieval_BM25.py
├── retrieval_Dense.py
├── retrieval_hybridsearch.py
├── retrieval.py
├── retrieval_SPLADE.py
├── retrieval_tfidf.py
├── utils.py
└── wandb
```

### 📦 src Directory Details
- arguments.py : data augmentation file
- CNN_layer_model.py : class that adds CNN layers on top of a PLM
- data_analysis.py : dataset analysis file
- ensemble : model-ensembling scripts (soft and hard voting supported)
- main.py : runs model training, evaluation, and prediction
- optimize_retriever.py : hyperparameter optimization for the retriever
- qa_trainer.py : custom Trainer class for the MRC task
- retrieval_2s_rerank.py : re-ranking retriever
- retrieval_BM25.py : BM25 retriever (see the sketch after this list)
- retrieval_Dense.py : DPR retriever
- retrieval_hybridsearch.py : hybrid-search retriever
- retrieval_SPLADE.py : SPLADE retriever
- retrieval_tfidf.py : TF-IDF retriever
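
For reference, a minimal BM25 retrieval sketch in the spirit of retrieval_BM25.py, using the rank-bm25 package pinned in requirements.txt. The corpus layout and whitespace tokenization are illustrative assumptions, not the project's actual code (which also experimented with Mecab tokenization):

```python
import json
from rank_bm25 import BM25Okapi

# Assumed corpus layout: data/wikipedia_documents.json maps ids to {"text": ...}.
with open("data/wikipedia_documents.json", encoding="utf-8") as f:
    wiki = json.load(f)
corpus = [doc["text"] for doc in wiki.values()]

# Whitespace tokenization as a placeholder for a proper Korean tokenizer.
bm25 = BM25Okapi([doc.split() for doc in corpus])

def retrieve(question: str, topk: int = 20) -> list[str]:
    """Return the top-k passages ranked by BM25 score for the question."""
    scores = bm25.get_scores(question.split())
    top_ids = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:topk]
    return [corpus[i] for i in top_ids]
```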


### 💾 Installation
- In a `python=3.10` environment, install the dependencies with pip (```pip install -r requirements.txt```).
- Run `python run.py` to launch the train/eval/predict pipeline.
Binary file added [NLP-15] Wrap-Up Report of ODQA.pdf
Binary file not shown.
Empty file added data/.gitkeep
Empty file.
Binary file added docs/github_official_logo.png
Binary file added docs/leaderboard_final.png
Binary file added docs/leaderboard_mid.png
Binary file added docs/타임라인.png
Empty file added models/.gitkeep
Empty file.
Empty file added output/.gitkeep
Empty file.
17 changes: 17 additions & 0 deletions requirements.txt
@@ -0,0 +1,17 @@
datasets==2.15.0
faiss-gpu==1.7.2
networkx==3.1
rank-bm25==0.2.2
scikit-learn==1.4.0
torchaudio==2.1.0
torchvision==0.16.0
nltk==3.9.1
sentence_transformers==2.2.2
sentencepiece==0.2.0
tokenizers==0.13.0
huggingface_hub==0.24.7
ipykernel==6.29.5
scipy==1.7.3
torch==2.1.0
transformers==4.25.1
wandb==0.18.3
59 changes: 59 additions & 0 deletions run.py
@@ -0,0 +1,59 @@
#!/usr/bin/env python3

import os
import subprocess
from datetime import datetime, timedelta

# Get current time (UTC + 9 hours)
current_time = datetime.utcnow() + timedelta(hours=9)
current_time_str = current_time.strftime('%Y%m%d_%H%M%S')

# Root directory (adjust this if necessary)
root_dir = os.getcwd()
#root_dir = os.path.join(os.sep, 'data', 'ephemeral', 'home', 'level2-mrc-nlp-15')

# Ensure root directory exists
if not os.path.exists(root_dir):
    raise FileNotFoundError(f"The root directory {root_dir} does not exist. Please adjust the path accordingly.")

# Set up directories
train_dir = os.path.join(root_dir, 'models', f'train_{current_time_str}')
predict_dir = os.path.join(root_dir, 'output', f'test_{current_time_str}')
predict_dataset_name = os.path.join(root_dir, 'data', 'test_dataset')

# Change to src directory
src_dir = os.path.join(root_dir, 'src')
if not os.path.exists(src_dir):
    raise FileNotFoundError(f"The source directory {src_dir} does not exist. Please adjust the path accordingly.")
os.chdir(src_dir)

# Perform training
subprocess.run([
    "python", "main.py",
    "--output_dir", train_dir,
    "--do_train",
    "--overwrite_output_dir",
    "--per_device_train_batch_size", "16",
    "--learning_rate", "1e-5",
    "--num_train_epochs", "3"
], check=True)

# Perform evaluation (optional)
eval_dir = os.path.join(root_dir, 'output', f'train_dataset_{current_time_str}')
subprocess.run([
    "python", "main.py",
    "--output_dir", eval_dir,
    "--do_eval"
], check=True)

# Perform prediction (inference)
subprocess.run([
    "python", "main.py",
    "--output_dir", predict_dir,
    "--dataset_name", predict_dataset_name,
    "--model_name_or_path", train_dir,
    "--do_predict"
], check=True)

# Print Done
print(f"All Done. Check the output in {predict_dir}")
117 changes: 117 additions & 0 deletions src/CNN_layer_model.py
@@ -0,0 +1,117 @@
from typing import Optional, Union, Tuple, List

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import RobertaPreTrainedModel
from transformers.modeling_outputs import QuestionAnsweringModelOutput
from transformers.models.roberta.modeling_roberta import RobertaModel


class CNN_block(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(CNN_block, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=input_size, out_channels=input_size, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=input_size, out_channels=input_size, kernel_size=1)
        self.relu = nn.ReLU()
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        # Transpose the input to match Conv1d input shape (batch_size, channels, sequence_length)
        x = x.transpose(1, 2)
        output = self.conv1(x)
        output = self.conv2(output)
        output = x + self.relu(output)
        # Transpose back to original shape (batch_size, sequence_length, channels)
        output = output.transpose(1, 2)
        output = self.layer_norm(output)
        return output

class CNN_RobertaForQuestionAnswering(RobertaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.roberta = RobertaModel(config, add_pooling_layer=False)

        self.cnn_block1 = CNN_block(config.hidden_size, config.hidden_size)
Review comment: Since the duplicated code can hurt readability, it might be better to use nn.ModuleList here:

    self.cnn_blocks = nn.ModuleList([
        CNN_block(config.hidden_size, config.hidden_size) for _ in range(5)
    ])

        self.cnn_block2 = CNN_block(config.hidden_size, config.hidden_size)
        self.cnn_block3 = CNN_block(config.hidden_size, config.hidden_size)
        self.cnn_block4 = CNN_block(config.hidden_size, config.hidden_size)
        self.cnn_block5 = CNN_block(config.hidden_size, config.hidden_size)

        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Apply CNN layers
        sequence_output = self.cnn_block1(sequence_output)
Review comment: If you use nn.ModuleList above, the forward pass here can be rewritten as:

    for cnn_block in self.cnn_blocks:
        sequence_output = cnn_block(sequence_output)

        sequence_output = self.cnn_block2(sequence_output)
        sequence_output = self.cnn_block3(sequence_output)
        sequence_output = self.cnn_block4(sequence_output)
        sequence_output = self.cnn_block5(sequence_output)

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split adds a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
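
A minimal usage sketch for the class above. The checkpoint follows the baseline's klue/roberta-large default from the commit history, but the question/context strings are illustrative, and the QA/CNN heads remain randomly initialized until the model is fine-tuned:

```python
import torch
from transformers import AutoConfig, AutoTokenizer
from CNN_layer_model import CNN_RobertaForQuestionAnswering

checkpoint = "klue/roberta-large"  # baseline default PLM
config = AutoConfig.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = CNN_RobertaForQuestionAnswering.from_pretrained(checkpoint, config=config)

question, context = "대한민국의 수도는?", "대한민국의 수도는 서울이다."
# klue/roberta has type_vocab_size=1, so drop token_type_ids from the encoding.
inputs = tokenizer(question, context, return_tensors="pt", return_token_type_ids=False)
with torch.no_grad():
    outputs = model(**inputs)

# Pick the most likely start/end token and decode the predicted span.
start = outputs.start_logits.argmax(dim=-1).item()
end = outputs.end_logits.argmax(dim=-1).item()
print(tokenizer.decode(inputs["input_ids"][0][start : end + 1]))
```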