Skip to content

Commit f9316e3

Browse files
authored
gpu fourierembedding (mathLab#313)
Adding gpu to fourier embedding
1 parent ba79984 commit f9316e3

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pina/model/layers/embedding.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def forward(self, x):
246246
:rtype: torch.Tensor
247247
"""
248248
# compute random matrix multiplication
249-
out = torch.mm(x, self._matrix)
249+
out = torch.mm(x, self._matrix.to(device=x.device, dtype=x.dtype))
250250
# return embedding
251251
return torch.cat(
252252
[torch.cos(2 * torch.pi * out), torch.sin(2 * torch.pi * out)],

0 commit comments

Comments
 (0)