Skip to content

Commit 90be723

Browse files
committed
rotary embeddings need to be computed in full precision (autocast disabled) to be numerically safe
1 parent bca88e9 commit 90be723

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

setup.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
setup(
77
name = 'vit-pytorch',
88
packages = find_packages(exclude=['examples']),
9-
version = '1.6.8',
9+
version = '1.6.9',
1010
license='MIT',
1111
description = 'Vision Transformer (ViT) - Pytorch',
1212
long_description=long_description,

vit_pytorch/rvt.py

+3
Original file line number | Diff line number | Diff line change
@@ -3,12 +3,14 @@
33
import torch
44
from torch import nn, einsum
55
import torch.nn.functional as F
6+
from torch.cuda.amp import autocast
67

78
from einops import rearrange, repeat
89
from einops.layers.torch import Rearrange
910

1011
# rotary embeddings
1112

13+
@autocast(enabled = False)
1214
def rotate_every_two(x):
1315
x = rearrange(x, '... (d j) -> ... d j', j = 2)
1416
x1, x2 = x.unbind(dim = -1)
@@ -22,6 +24,7 @@ def __init__(self, dim, max_freq = 10):
2224
scales = torch.linspace(1., max_freq / 2, self.dim // 4)
2325
self.register_buffer('scales', scales)
2426

27+
@autocast(enabled = False)
2528
def forward(self, x):
2629
device, dtype, n = x.device, x.dtype, int(sqrt(x.shape[-2]))
2730

0 commit comments

Comments
 (0)