Skip to content

Commit 41f5abc

Browse files
Merge pull request #1 from jefferson-bercaw/first_attempt
First attempt
2 parents d446a41 + e990d76 commit 41f5abc

File tree

1 file changed

+226
-0
lines changed

1 file changed

+226
-0
lines changed

vit-classifier.py

+226
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
from __future__ import print_function
2+
3+
import glob
4+
from itertools import chain
5+
import os
6+
import random
7+
import zipfile
8+
9+
import matplotlib.pyplot as plt
10+
import numpy as np
11+
import pandas as pd
12+
import torch
13+
import torch.nn as nn
14+
import torch.nn.functional as F
15+
import torch.optim as optim
16+
from linformer import Linformer
17+
from PIL import Image
18+
# from sklearn.model_selection import train_test_split
19+
from torch.optim.lr_scheduler import StepLR
20+
from torch.utils.data import DataLoader, Dataset
21+
from torchvision import datasets, transforms
22+
from tqdm.notebook import tqdm
23+
24+
25+
from readDataFromExcel import getDataFromExcelFile
26+
from vit_pytorch.vit_3d import ViT
27+
28+
29+
def seed_everything(seed):
30+
random.seed(seed)
31+
os.environ['PYTHONHASHSEED'] = str(seed)
32+
np.random.seed(seed)
33+
torch.manual_seed(seed)
34+
# torch.cuda.manual_seed(seed)
35+
# torch.cuda.manual_seed_all(seed)
36+
torch.backends.cudnn.deterministic = True
37+
38+
39+
def Img_and_Label(data_obj):
40+
41+
img_list = []
42+
label_list = []
43+
file_folder = data_obj.imgRootPath
44+
45+
data_dict = data_obj.excelData
46+
for idx in range(len(data_dict)):
47+
img_list.append(data_dict[idx]["img"])
48+
49+
cur_label = data_dict[idx]["label"]
50+
if cur_label == '0':
51+
label_float = float(0)
52+
else:
53+
label_float = float(1)
54+
label_list.append(label_float)
55+
56+
uniq_names = []
57+
num_images = []
58+
label = []
59+
label_list_short = []
60+
61+
for ind, name in enumerate(img_list):
62+
split_name = name.split("_")
63+
subj = split_name[0]
64+
label.append(label_list[ind])
65+
66+
if subj not in uniq_names:
67+
uniq_names.append(subj)
68+
num_images.append(1)
69+
label_list_short.append(label[-1])
70+
else:
71+
index = uniq_names.index(subj)
72+
num_images[index] += 1
73+
74+
files = [[]]
75+
labels = []
76+
ind = 0
77+
for idx, subj in enumerate(uniq_names):
78+
if num_images[uniq_names.index(subj)] != 24:
79+
print("Subject {} has only {} images".format(subj, num_images[uniq_names.index(subj)]))
80+
81+
else:
82+
files.append([])
83+
for img in range(24):
84+
if img < 10:
85+
img_str = "000" + str(img)
86+
else:
87+
img_str = "00" + str(img)
88+
files[ind].append(os.path.join(file_folder, (subj + "_" + img_str + ".bmp")).replace("\\", "/"))
89+
labels.append(label_list_short[ind])
90+
ind += 1
91+
92+
files = files[0:-1]
93+
# files = list(filter(None, files))
94+
# labels = list(filter(None, labels))
95+
return files, labels
96+
97+
98+
class MRIDataset(Dataset):
99+
def __init__(self, data_obj, transform=None):
100+
files, labels = Img_and_Label(data_obj)
101+
self.file_list = files
102+
self.label = labels
103+
self.transform = transform
104+
105+
def __len__(self):
106+
self.filelength = len(self.file_list)
107+
return self.filelength
108+
109+
def __getitem__(self, idx):
110+
imgs = self.file_list[idx]
111+
img = np.zeros((224, 224, 24))
112+
for idx, cur_img in enumerate(imgs):
113+
img_here = np.asarray(Image.open(cur_img))
114+
assert img_here.dtype == 'uint8'
115+
img[:, :, idx] = img_here / (2**8)
116+
117+
img = np.float32(img)
118+
label = np.float32(self.label)
119+
120+
img_transformed = self.transform(img)
121+
label = self.label[idx]
122+
123+
return img_transformed, label
124+
125+
126+
if __name__ == '__main__':
127+
128+
batch_size = 12
129+
epochs = 100
130+
lr = 3e-5
131+
gamma = 0.7
132+
seed = 42
133+
134+
seed_everything(seed)
135+
136+
device = 'cpu'
137+
138+
n_folds = 10
139+
cur_dir = os.getcwd()
140+
print(f"Current Directory: {cur_dir}")
141+
os.makedirs(os.path.join(cur_dir, "saved_models"), exist_ok=True)
142+
143+
excelFilePath = os.path.join(cur_dir,'Fold_Split.xlsx')
144+
imgRootPath = "C:/Users/jrb187/PycharmProjects/FITNet/subset_data/2D_Images"
145+
146+
# Transforms to data
147+
train_transforms = transforms.Compose(
148+
[
149+
transforms.ToTensor(),
150+
]
151+
)
152+
153+
val_transforms = transforms.Compose(
154+
[
155+
transforms.ToTensor(),
156+
]
157+
)
158+
159+
for fold in range(n_folds):
160+
161+
excel_sheet_name_train = 'train_fold' + str(fold)
162+
excel_sheet_name_test = 'valid_fold' + str(fold)
163+
164+
train_obj = getDataFromExcelFile(excelFilePath=excelFilePath, imgRootPath=imgRootPath, excelSheetName=excel_sheet_name_train)
165+
test_obj = getDataFromExcelFile(excelFilePath=excelFilePath, imgRootPath=imgRootPath, excelSheetName=excel_sheet_name_test)
166+
167+
train_dataset = MRIDataset(train_obj, transform=train_transforms)
168+
test_dataset = MRIDataset(test_obj, transform=val_transforms)
169+
170+
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False)
171+
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
172+
173+
model = ViT(image_size=224, channels =1, frames=24, image_patch_size=16, frame_patch_size=1, num_classes=2,
174+
dim=14*14*24, depth=6, heads=8, mlp_dim=2048, dropout=0.1, emb_dropout=0.1)
175+
176+
# Training
177+
criterion = nn.CrossEntropyLoss()
178+
optimizer = optim.Adam(model.parameters(), lr=lr)
179+
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
180+
181+
for epoch in range(epochs):
182+
epoch_loss = 0
183+
epoch_accuracy = 0
184+
185+
for data, label in train_loader:
186+
187+
# Add 1 (channel)
188+
data = data.unsqueeze(1)
189+
assert data.shape == (batch_size, 1, 24, 224, 224)
190+
191+
data = data.to(device)
192+
label = label.to(device)
193+
194+
output = model(data)
195+
loss = criterion(output, label)
196+
197+
optimizer.zero_grad()
198+
loss.backward()
199+
optimizer.step()
200+
201+
acc = (output.argmax(dim=1) == label).float().mean()
202+
epoch_accuracy += acc / len(train_loader)
203+
epoch_loss += loss / len(train_loader)
204+
205+
torch.cuda.empty_cache()
206+
207+
with torch.no_grad():
208+
epoch_val_accuracy = 0
209+
epoch_val_loss = 0
210+
for data, label in test_loader:
211+
data = data.to(device)
212+
label = label.to(device)
213+
214+
val_output = model(data)
215+
val_loss = criterion(val_output, label)
216+
217+
acc = (val_output.argmax(dim=1) == label).float().mean()
218+
epoch_val_accuracy += acc / len(test_loader)
219+
epoch_val_loss += val_loss / len(test_loader)
220+
221+
print(
222+
f"Fold : {fold+1} - Epoch : {epoch + 1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
223+
)
224+
225+
torch.save(model.state_dict(), './saved_models/{}.pt'.format("fold" + str(fold+1)))
226+

0 commit comments

Comments
 (0)