Skip to content

Commit 4a0fff1

Browse files
author
tibuch
authored
Merge pull request #16 from juglab/merger
Merger
2 parents 25ec06f + 8596e35 commit 4a0fff1

13 files changed

Lines changed: 437 additions & 102 deletions

File tree

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ Build Python package:
66
`python setup.py bdist_wheel`
77

88
Build singularity recipe:
9-
`neurodocker generate singularity -b nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 -p apt --copy /home/tibuch/Gitrepos/FourierImageTransformer/dist/fourier_image_transformers-0.1.19-py3-none-any.whl /fourier_image_transformers-0.1.19-py3-none-any.whl --miniconda create_env=fit conda_install='python=3.7 astra-toolbox pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch -c astra-toolbox/label/dev' pip_install='/fourier_image_transformers-0.1.19-py3-none-any.whl' activate=true --entrypoint "/neurodocker/startup.sh python" > v0.1.19.Singularity`
9+
`neurodocker generate singularity -b nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 -p apt --copy /home/tibuch/Gitrepos/FourierImageTransformer/dist/fourier_image_transformers-0.1.20-py3-none-any.whl /fourier_image_transformers-0.1.20-py3-none-any.whl --miniconda create_env=fit conda_install='python=3.7 astra-toolbox pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch -c astra-toolbox/label/dev' pip_install='/fourier_image_transformers-0.1.20-py3-none-any.whl' activate=true --entrypoint "/neurodocker/startup.sh python" > v0.1.20.Singularity`
1010

1111
Build singularity container:
12-
`sudo singularity build fit_v0.1.19.simg v0.1.19.Singularity`
12+
`sudo singularity build fit_v0.1.20.simg v0.1.20.Singularity`

fit/datamodules/super_res/SRecDataModule.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717

1818
class MNISTSResFourierTargetDataModule(LightningDataModule):
19-
IMG_SHAPE = 28
19+
IMG_SHAPE = 27
2020

2121
def __init__(self, root_dir, batch_size):
2222
"""
@@ -38,8 +38,9 @@ def setup(self, stage: Optional[str] = None):
3838
mnist_train_val = MNIST(self.root_dir, train=True, download=True).data.type(torch.float32)
3939
np.random.seed(1612)
4040
perm = np.random.permutation(mnist_train_val.shape[0])
41-
mnist_train = mnist_train_val[perm[:55000]]
42-
mnist_val = mnist_train_val[perm[55000:]]
41+
mnist_train = mnist_train_val[perm[:55000], 1:, 1:]
42+
mnist_val = mnist_train_val[perm[55000:], 1:, 1:]
43+
mnist_test = mnist_test[:, 1:, 1:]
4344

4445
assert mnist_train.shape[1] == MNISTSResFourierTargetDataModule.IMG_SHAPE
4546
assert mnist_train.shape[2] == MNISTSResFourierTargetDataModule.IMG_SHAPE
@@ -77,7 +78,7 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]
7778
return DataLoader(
7879
SResFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='test',
7980
img_shape=MNISTSResFourierTargetDataModule.IMG_SHAPE),
80-
batch_size=1)
81+
batch_size=self.batch_size)
8182

8283

8384
class CelebASResFourierTargetDataModule(LightningDataModule):
@@ -134,4 +135,4 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]
134135
return DataLoader(
135136
SResFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='test',
136137
img_shape=self.gt_shape),
137-
batch_size=1)
138+
batch_size=self.batch_size)

fit/datamodules/super_res/SResFCDataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __getitem__(self, item):
3535

3636
img_mag = 2 * (img_mag - self.mag_min) / (self.mag_max - self.mag_min) - 1
3737

38-
img_phi = 2 * img_phi / (2 * np.pi) - 1
38+
img_phi = img_phi / np.pi
3939

4040
img_fft = torch.stack([img_mag.flatten(), img_phi.flatten()], dim=-1)
4141
return img_fft, (self.mag_min.unsqueeze(-1), self.mag_max.unsqueeze(-1))

fit/datamodules/tomo_rec/TRecDataModule.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,100 @@ def setup(self, stage: Optional[str] = None):
208208

209209
self.gt_ds = get_projection_dataset(
210210
GroundTruthDataset(gt_train, gt_val, gt_test),
211-
num_angles=self.num_angles, im_shape=450, impl='astra_cpu', inner_circle=self.inner_circle)
211+
num_angles=self.num_angles, im_shape=self.gt_shape + (self.gt_shape // 2 - 7), impl='astra_cpu',
212+
inner_circle=self.inner_circle)
213+
214+
tmp_fcds = TRecFourierCoefficientDataset(self.gt_ds, mag_min=None, mag_max=None, part='train',
215+
img_shape=self.gt_shape)
216+
self.mag_min = tmp_fcds.mag_min
217+
self.mag_max = tmp_fcds.mag_max
218+
219+
def train_dataloader(self, *args, **kwargs) -> DataLoader:
220+
return DataLoader(
221+
TRecFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='train',
222+
img_shape=self.gt_shape),
223+
batch_size=self.batch_size, num_workers=1)
224+
225+
def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
226+
return DataLoader(
227+
TRecFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='validation',
228+
img_shape=self.gt_shape),
229+
batch_size=self.batch_size, num_workers=1)
230+
231+
def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
232+
return DataLoader(
233+
TRecFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='test',
234+
img_shape=self.gt_shape),
235+
batch_size=1)
236+
237+
238+
class CropLoDoPaBFourierTargetDataModule(LightningDataModule):
239+
IMG_SHAPE = 361
240+
241+
def __init__(self, batch_size, gt_shape=361, num_angles=15):
242+
"""
243+
:param root_dir:
244+
:param batch_size:
245+
:param num_angles:
246+
"""
247+
super().__init__()
248+
self.batch_size = batch_size
249+
self.gt_shape = gt_shape
250+
self.num_angles = num_angles
251+
self.inner_circle = True
252+
self.gt_ds = None
253+
self.mean = None
254+
self.std = None
255+
256+
def setup(self, stage: Optional[str] = None):
257+
lodopab = dival.get_standard_dataset('lodopab', impl='astra_cpu')
258+
assert self.gt_shape <= self.IMG_SHAPE, 'GT is larger than original images.'
259+
if self.gt_shape < self.IMG_SHAPE:
260+
crop_off = (362 - self.gt_shape) // 2
261+
gt_train = np.array([lodopab.get_sample(i, part='train', out=(False, True))[1][crop_off:-(crop_off + 1),
262+
crop_off:-(crop_off + 1)] for i in
263+
range(4000)])
264+
gt_val = np.array([lodopab.get_sample(i, part='validation', out=(False, True))[1][crop_off:-(crop_off + 1),
265+
crop_off:-(crop_off + 1)] for i in
266+
range(400)])
267+
gt_test = np.array([lodopab.get_sample(i, part='test', out=(False, True))[1][crop_off:-(crop_off + 1),
268+
crop_off:-(crop_off + 1)] for i in
269+
range(3553)])
270+
else:
271+
gt_train = np.array(
272+
[lodopab.get_sample(i, part='train', out=(False, True))[1][1:, 1:] for i in range(4000)])
273+
gt_val = np.array(
274+
[lodopab.get_sample(i, part='validation', out=(False, True))[1][1:, 1:] for i in range(400)])
275+
gt_test = np.array(
276+
[lodopab.get_sample(i, part='test', out=(False, True))[1][1:, 1:] for i in range(3553)])
277+
278+
gt_train = torch.from_numpy(gt_train)
279+
gt_val = torch.from_numpy(gt_val)
280+
gt_test = torch.from_numpy(gt_test)
281+
282+
assert gt_train.shape[1] == self.gt_shape
283+
assert gt_train.shape[2] == self.gt_shape
284+
x, y = torch.meshgrid(torch.arange(-self.gt_shape // 2 + 1,
285+
self.gt_shape // 2 + 1),
286+
torch.arange(-self.gt_shape // 2 + 1,
287+
self.gt_shape // 2 + 1))
288+
289+
self.mean = gt_train.mean()
290+
self.std = gt_train.std()
291+
292+
gt_train = normalize(gt_train, self.mean, self.std)
293+
gt_val = normalize(gt_val, self.mean, self.std)
294+
gt_test = normalize(gt_test, self.mean, self.std)
295+
296+
circle = torch.sqrt(x ** 2. + y ** 2.) <= self.gt_shape // 2
297+
gt_train *= circle
298+
gt_val *= circle
299+
gt_test *= circle
300+
301+
self.gt_ds = get_projection_dataset(
302+
GroundTruthDataset(gt_train, gt_val, gt_test),
303+
num_angles=self.num_angles, im_shape=self.gt_shape + (self.gt_shape // 2 - 7), impl='astra_cpu',
304+
inner_circle=self.inner_circle)
212305

213306
tmp_fcds = TRecFourierCoefficientDataset(self.gt_ds, mag_min=None, mag_max=None, part='train',
214307
img_shape=self.gt_shape)

fit/datamodules/tomo_rec/TRecFCDataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ def __getitem__(self, item):
4343
sino_mag = 2 * (sino_mag - self.mag_min) / (self.mag_max - self.mag_min) - 1
4444
img_mag = 2 * (img_mag - self.mag_min) / (self.mag_max - self.mag_min) - 1
4545

46-
sino_phi = 2 * sino_phi / (2 * np.pi) - 1
47-
img_phi = 2 * img_phi / (2 * np.pi) - 1
46+
sino_phi = sino_phi / np.pi
47+
img_phi = img_phi / np.pi
4848

4949
sino_fft = torch.stack([sino_mag.flatten(), sino_phi.flatten()], dim=-1)
5050
img_fft = torch.stack([img_mag.flatten(), img_phi.flatten()], dim=-1)

0 commit comments

Comments
 (0)