Skip to content

Commit 99a2f4d

Browse files
committed
torch_dce module
1 parent a9cd8be commit 99a2f4d

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed

sigpy/nn/torch_dce.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import torch
2+
3+
import numpy as np
4+
import torch.nn as nn
5+
import torch.optim as optim
6+
7+
from sigpy.mri import dce
8+
9+
# %%
10+
class DCE(nn.Module):
11+
def __init__(self,
12+
ishape,
13+
sample_time,
14+
R1 = 1.,
15+
M0 = 5.,
16+
R1CA = 4.39,
17+
FA = 15.,
18+
TR = 0.006):
19+
super(DCE, self).__init__()
20+
21+
self.ishape = list(ishape)
22+
23+
self.sample_time = torch.tensor(np.squeeze(sample_time), dtype=torch.float32)
24+
25+
self.R1 = torch.tensor(np.array(R1), dtype=torch.float32)
26+
self.M0 = torch.tensor(np.array(M0), dtype=torch.float32)
27+
self.R1CA = torch.tensor(np.array(R1CA), dtype=torch.float32)
28+
self.FA = torch.tensor(np.array(FA), dtype=torch.float32)
29+
self.TR = torch.tensor(np.array(TR), dtype=torch.float32)
30+
31+
self.FA_radian = self.FA * np.pi / 180.
32+
self.M0_trans = self.M0 * torch.sin(self.FA_radian)
33+
34+
E1 = torch.exp(-self.TR * self.R1)
35+
self.M_steady = self.M0_trans * (1 - E1) / (1 - E1 * torch.cos(self.FA_radian))
36+
37+
Cp = dce.arterial_input_function(sample_time)
38+
self.Cp = torch.tensor(Cp, dtype=torch.float32)
39+
40+
def _check_ishape(self, input):
41+
for i1, i2 in zip(input.shape, self.ishape):
42+
if i1 != i2:
43+
raise ValueError(
44+
'input shape mismatch for {s}, got {input_shape}'.format(s=self, input_shape=input.shape))
45+
46+
def _param_to_conc(self, x):
47+
t1_idx = torch.nonzero(self.sample_time)
48+
t1 = self.sample_time[t1_idx]
49+
dt = torch.diff(t1, dim=0)
50+
K_time = torch.cumsum(self.Cp, dim=0) * dt[-1]
51+
52+
mult = torch.stack((K_time, self.Cp), 1)
53+
54+
xr = torch.reshape(x, (self.ishape[0], np.prod(self.ishape[1:])))
55+
56+
yr = torch.matmul(mult, xr)
57+
58+
oshape = [len(self.sample_time)] + self.ishape[1:]
59+
yr = torch.reshape(yr, tuple(oshape))
60+
61+
return yr
62+
63+
def forward(self, x):
64+
65+
if torch.is_tensor(x) is not True:
66+
x = torch.tensor(x, dtype=torch.float32)
67+
68+
self._check_ishape(x)
69+
70+
# parameters (k_trans, v_p) to concentration
71+
CA = self._param_to_conc(x)
72+
x0 = CA[0, ...] # baseline image
73+
74+
# concentration to MR signal
75+
E1CA = torch.exp(-self.TR * (self.R1 + self.R1CA * CA))
76+
77+
CA_trans = self.M0_trans * (1 - E1CA) / (1 - E1CA * torch.cos(self.FA_radian))
78+
79+
y = CA_trans + x0 - self.M_steady
80+
81+
return y
82+
83+
# %%
84+
if torch.cuda.is_available():
85+
device = "cuda:0"
86+
else:
87+
device = "cpu"
88+
89+
model = DCE()
90+
91+
92+
# for epoch in range(20):

0 commit comments

Comments
 (0)