1
+ import torch
2
+ import nibabel as nb
3
+ import os
4
+ import numpy as np
5
+
6
+ #Dataset
7
+ class brats_dataset (torch .utils .data .Dataset ):
8
+ def __init__ (self ,data_folders ):
9
+ self .data_list = []
10
+
11
+ #Perform necessary input data preparation in this function
12
+ #add each input example into the data_last function
13
+ #takes in a list of folders and processes the data contained
14
+
15
+ # U net requires all dimensions be divisible by 8 (by default)
16
+ # or we'd have to manually do the padding in the U-net model
17
+ # no padding="valid" exists in Pytorch for... reasons?
18
+ for i , folder in enumerate (data_folders ):
19
+ i_str = folder [- 3 :]
20
+
21
+ f_flair = nb .load (os .path .join (folder ,'BraTS20_Training_%s_flair.nii' % i_str ),mmap = False ).get_fdata ()
22
+ f_seg = nb .load (os .path .join (folder ,'BraTS20_Training_%s_seg.nii' % i_str ),mmap = False ).get_fdata ()
23
+ f_t1ce = nb .load (os .path .join (folder ,'BraTS20_Training_%s_t1ce.nii' % i_str ),mmap = False ).get_fdata ()
24
+ f_t1 = nb .load (os .path .join (folder ,'BraTS20_Training_%s_t1.nii' % i_str ),mmap = False ).get_fdata ()
25
+ f_t2 = nb .load (os .path .join (folder ,'BraTS20_Training_%s_t2.nii' % i_str ),mmap = False ).get_fdata ()
26
+
27
+ f_flair = torch .as_tensor (np .expand_dims (np .pad (f_flair , [(0 , 0 ), (0 , 0 ), (2 , 3 )]), axis = 0 )).half ()
28
+ f_t1 = torch .as_tensor (np .expand_dims (np .pad (f_t1 , [(0 , 0 ), (0 , 0 ), (2 , 3 )]), axis = 0 )).half ()
29
+ f_t2 = torch .as_tensor (np .expand_dims (np .pad (f_t2 , [(0 , 0 ), (0 , 0 ), (2 , 3 )]), axis = 0 )).half ()
30
+ f_seg = torch .as_tensor (np .expand_dims (np .pad (f_seg , [(0 , 0 ), (0 , 0 ), (2 , 3 )]), axis = 0 )).half ()
31
+ f_t1ce = torch .as_tensor (np .expand_dims (np .pad (f_t1ce , [(0 , 0 ), (0 , 0 ), (2 , 3 )]), axis = 0 )).half ()
32
+
33
+
34
+ concat = torch .cat ([f_t1 , f_t1ce , f_t2 , f_flair ], axis = 0 )
35
+
36
+ self .data_list .append ([concat , f_seg ])
37
+ def __len__ (self ):
38
+ return len (self .data_list )
39
+ def __getitem__ (self , index ):
40
+ return self .data_list [index ]
0 commit comments