@@ -208,7 +208,100 @@ def setup(self, stage: Optional[str] = None):
208208
209209 self .gt_ds = get_projection_dataset (
210210 GroundTruthDataset (gt_train , gt_val , gt_test ),
211- num_angles = self .num_angles , im_shape = 450 , impl = 'astra_cpu' , inner_circle = self .inner_circle )
211+ num_angles = self .num_angles , im_shape = self .gt_shape + (self .gt_shape // 2 - 7 ), impl = 'astra_cpu' ,
212+ inner_circle = self .inner_circle )
213+
214+ tmp_fcds = TRecFourierCoefficientDataset (self .gt_ds , mag_min = None , mag_max = None , part = 'train' ,
215+ img_shape = self .gt_shape )
216+ self .mag_min = tmp_fcds .mag_min
217+ self .mag_max = tmp_fcds .mag_max
218+
219+ def train_dataloader (self , * args , ** kwargs ) -> DataLoader :
220+ return DataLoader (
221+ TRecFourierCoefficientDataset (self .gt_ds , mag_min = self .mag_min , mag_max = self .mag_max , part = 'train' ,
222+ img_shape = self .gt_shape ),
223+ batch_size = self .batch_size , num_workers = 1 )
224+
225+ def val_dataloader (self , * args , ** kwargs ) -> Union [DataLoader , List [DataLoader ]]:
226+ return DataLoader (
227+ TRecFourierCoefficientDataset (self .gt_ds , mag_min = self .mag_min , mag_max = self .mag_max , part = 'validation' ,
228+ img_shape = self .gt_shape ),
229+ batch_size = self .batch_size , num_workers = 1 )
230+
231+ def test_dataloader (self , * args , ** kwargs ) -> Union [DataLoader , List [DataLoader ]]:
232+ return DataLoader (
233+ TRecFourierCoefficientDataset (self .gt_ds , mag_min = self .mag_min , mag_max = self .mag_max , part = 'test' ,
234+ img_shape = self .gt_shape ),
235+ batch_size = 1 )
236+
237+
238+ class CropLoDoPaBFourierTargetDataModule (LightningDataModule ):
239+ IMG_SHAPE = 361
240+
241+ def __init__ (self , batch_size , gt_shape = 361 , num_angles = 15 ):
242+ """
243+ :param root_dir:
244+ :param batch_size:
245+ :param num_angles:
246+ """
247+ super ().__init__ ()
248+ self .batch_size = batch_size
249+ self .gt_shape = gt_shape
250+ self .num_angles = num_angles
251+ self .inner_circle = True
252+ self .gt_ds = None
253+ self .mean = None
254+ self .std = None
255+
256+ def setup (self , stage : Optional [str ] = None ):
257+ lodopab = dival .get_standard_dataset ('lodopab' , impl = 'astra_cpu' )
258+ assert self .gt_shape <= self .IMG_SHAPE , 'GT is larger than original images.'
259+ if self .gt_shape < self .IMG_SHAPE :
260+ crop_off = (362 - self .gt_shape ) // 2
261+ gt_train = np .array ([lodopab .get_sample (i , part = 'train' , out = (False , True ))[1 ][crop_off :- (crop_off + 1 ),
262+ crop_off :- (crop_off + 1 )] for i in
263+ range (4000 )])
264+ gt_val = np .array ([lodopab .get_sample (i , part = 'validation' , out = (False , True ))[1 ][crop_off :- (crop_off + 1 ),
265+ crop_off :- (crop_off + 1 )] for i in
266+ range (400 )])
267+ gt_test = np .array ([lodopab .get_sample (i , part = 'test' , out = (False , True ))[1 ][crop_off :- (crop_off + 1 ),
268+ crop_off :- (crop_off + 1 )] for i in
269+ range (3553 )])
270+ else :
271+ gt_train = np .array (
272+ [lodopab .get_sample (i , part = 'train' , out = (False , True ))[1 ][1 :, 1 :] for i in range (4000 )])
273+ gt_val = np .array (
274+ [lodopab .get_sample (i , part = 'validation' , out = (False , True ))[1 ][1 :, 1 :] for i in range (400 )])
275+ gt_test = np .array (
276+ [lodopab .get_sample (i , part = 'test' , out = (False , True ))[1 ][1 :, 1 :] for i in range (3553 )])
277+
278+ gt_train = torch .from_numpy (gt_train )
279+ gt_val = torch .from_numpy (gt_val )
280+ gt_test = torch .from_numpy (gt_test )
281+
282+ assert gt_train .shape [1 ] == self .gt_shape
283+ assert gt_train .shape [2 ] == self .gt_shape
284+ x , y = torch .meshgrid (torch .arange (- self .gt_shape // 2 + 1 ,
285+ self .gt_shape // 2 + 1 ),
286+ torch .arange (- self .gt_shape // 2 + 1 ,
287+ self .gt_shape // 2 + 1 ))
288+
289+ self .mean = gt_train .mean ()
290+ self .std = gt_train .std ()
291+
292+ gt_train = normalize (gt_train , self .mean , self .std )
293+ gt_val = normalize (gt_val , self .mean , self .std )
294+ gt_test = normalize (gt_test , self .mean , self .std )
295+
296+ circle = torch .sqrt (x ** 2. + y ** 2. ) <= self .gt_shape // 2
297+ gt_train *= circle
298+ gt_val *= circle
299+ gt_test *= circle
300+
301+ self .gt_ds = get_projection_dataset (
302+ GroundTruthDataset (gt_train , gt_val , gt_test ),
303+ num_angles = self .num_angles , im_shape = self .gt_shape + (self .gt_shape // 2 - 7 ), impl = 'astra_cpu' ,
304+ inner_circle = self .inner_circle )
212305
213306 tmp_fcds = TRecFourierCoefficientDataset (self .gt_ds , mag_min = None , mag_max = None , part = 'train' ,
214307 img_shape = self .gt_shape )
0 commit comments