Skip to content

Commit ac7da29

Browse files
authored
KoGPT2 finetuning
1 parent 60dcf10 commit ac7da29

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

main.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,6 @@ def main(epoch, save_path, load_path, samples, data_file_path, batch_size):
150150
for data in data_loader:
151151
optimizer.zero_grad()
152152
data = torch.stack(data[0]) # list of Tensor로 구성되어 있기 때문에 list를 stack을 통해 변환해준다.
153-
# 여기 계속 data[0]으로 해도 괜찮은지 확인하기
154153
data = data.transpose(1,0)
155154
data = data.to(ctx) # 해당 tensor를 GPU에 loading
156155
model = model.to(ctx)
@@ -167,7 +166,7 @@ def main(epoch, save_path, load_path, samples, data_file_path, batch_size):
167166
summary.add_scalar('loss/loss', loss, count)
168167

169168
# generator 진행
170-
if (count > 0 and count % 10000 == 0) or (len(data) < batch_size):
169+
if (count > 0 and count % 1000 == 0) or (len(data) < batch_size):
171170
sent = sample_sequence(model.to("cpu"), tok, vocab, sent="우리", text_size=100, temperature=0.7, top_p=0.8, top_k=40)
172171
sent = sent.replace("<unused0>", "\n") # 비효율적이지만 엔터를 위해서 등장
173172
sent = auto_enter(sent)
@@ -184,7 +183,7 @@ def main(epoch, save_path, load_path, samples, data_file_path, batch_size):
184183
#########################################
185184
count += 1
186185

187-
if (count > 0 and count % 20000 == 0) or (len(data) < batch_size):
186+
if (count > 0 and count % 10000 == 0) or (len(data) < batch_size):
188187
# 모델 저장
189188
try:
190189
torch.save({
@@ -198,4 +197,4 @@ def main(epoch, save_path, load_path, samples, data_file_path, batch_size):
198197
pass
199198

200199
if __name__ == "__main__":
201-
main(args.epoch, args.save_path, args.load_path, args.samples, args.data_file_path, args.batch_size)
200+
main(args.epoch, args.save_path, args.load_path, args.samples, args.data_file_path, args.batch_size)

0 commit comments

Comments
 (0)