# -*- coding: utf-8 -*-
"""dna-classification-vcu.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/13BLlKCJc5EiKMHP7KEH5hOXyLEIUdX02
# DNA Multi Class Classification
"""
"""## Prepare Google Drive"""
# Run this cell to mount your Google Drive.
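# A minimal sketch of the mount step this comment refers to (an addition,
# not in the original file; google.colab is only available inside Colab):
# from google.colab import drive
# drive.mount('/content/gdrive')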
local_path = './'  # project root; point this at the Drive mount if you use one
"""## Prepare fastai"""
# fastai v1 API; the star imports also pull in os, numpy, and pandas (as pd).
from fastai import *
from fastai.text import *
"""## Prepare Dataset"""
local_project_path = local_path + 'dna-10class/'
if not os.path.exists(local_project_path):
    os.makedirs(local_project_path)
print('local_project_path:', local_project_path)
"""## Create Language Model"""
# df = pd.read_csv(local_project_path + 'combined_sample.csv')
# example_text = df.iloc[1]; print(example_text)
class dna_tokenizer(BaseTokenizer):
    """Character-level tokenizer: every base (a/c/g/t) becomes its own token."""
    def tokenizer(self, t):
        return list(t)

tokenizer = Tokenizer(tok_func=dna_tokenizer, pre_rules=[], post_rules=[], special_cases=[])
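# Illustrative sanity check (an addition, assuming fastai v1's BaseTokenizer(lang)
# signature): the tokenizer should split a sequence into individual bases.
assert dna_tokenizer(lang='en').tokenizer('atggcag') == list('atggcag')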
# batch size
bs = 96

data_lm = TextLMDataBunch.from_csv(local_project_path, 'combined.csv',
                                   text_cols='Text', valid_pct=0.01, tokenizer=tokenizer,
                                   include_bos=False, include_eos=False, bs=bs)
# data_lm.train_ds[0][0].text
# data_lm.train_ds[0][0].data
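# Optional caching sketch (an addition, assuming fastai v1's DataBunch.save and
# load_data): persist the tokenized databunch so reruns can skip re-tokenizing the CSV.
# data_lm.save('data_lm_chars.pkl')
# data_lm = load_data(local_project_path, 'data_lm_chars.pkl', bs=bs)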
"""## Create Language Model Learner"""
# Train a character-level Transformer-XL from scratch; pretrained=False because
# the stock English weights do not apply to DNA sequences.
learn = language_model_learner(data_lm, TransformerXL, drop_mult=0.3, pretrained=False)  # .to_fp16() enables mixed precision
# from fastai.callbacks.misc import StopAfterNBatches
# learn.callbacks.append(StopAfterNBatches(n_batches=2))
# print(learn)
# learn.lr_find()
# learn.recorder.plot(skip_end = 15)
learn.fit_one_cycle(1, 1e-4, moms=(0.8,0.7))
learn.save('lm-first-transformer-a4')
learn.unfreeze()
learn.fit_one_cycle(1, 1e-4, moms=(0.8,0.7))
learn.save('lm-fine-tuned-transformer-10-1-1_a4')
learn.fit_one_cycle(8, 1e-4, moms=(0.8,0.7))
learn.save('lm-fine-tuned-transformer-10-1-2_a4')
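# Optional (an addition, assuming fastai v1's Recorder API): plot the loss
# curves accumulated over the runs above.
# learn.recorder.plot_losses()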
# learn.fit_one_cycle(7, 1e-3, moms=(0.8,0.7))
# learn.save('lm-fine-tuned-transformer-10-1-3')
# learn.fit_one_cycle(10, 1e-3, moms=(0.8,0.7))
# learn.save('lm-fine-tuned-transformer-10-2')
# learn.fit_one_cycle(10, 1e-4, moms=(0.8,0.7))
# learn.save('lm-fine-tuned-transformer-10-3')
# learn.fit_one_cycle(10, 1e-4, moms=(0.8,0.7))
# learn.save('lm-fine-tuned-transformer-10-4')
# learn.fit_one_cycle(10, 1e-5, moms=(0.8,0.7))
# learn.save('lm-fine-tuned-transformer-10-5')
# learn.fit_one_cycle(10, 1e-5, moms=(0.8,0.7))
# learn.save('lm-fine-tuned-transformer-10-6')
# learn.fit_one_cycle(10, 1e-6, moms=(0.8,0.7))
# learn.save('lm-fine-tuned-transformer-10-7')
# learn.fit_one_cycle(10, 1e-6, moms=(0.8,0.7))
# learn.save('lm-fine-tuned-transformer-10-8')
# learn.load('lm-fine-tuned-transformer-10-1-2')
"""## Sample From the Language Model"""

# Generate continuations of a seed sequence. Because of the character-level
# tokenizer, each predicted "word" is a single base; a temperature below 1
# makes the sampling more conservative.
TEXT = "atggcag"
N_WORDS = 40
N_SENTENCES = 2
print("\n".join(learn.predict(TEXT, N_WORDS, temperature=0.75) for _ in range(N_SENTENCES)))
# Save only the encoder so a downstream classifier learner can load it.
learn.save_encoder('fine_tuned_enc_tranformer_a2')
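# With fastai v1's default language-model metric this should report
# [valid_loss, accuracy] on the held-out 1% validation split.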
print(learn.validate())