FluxML/Zygote.jl #1044 (Closed)

Description
In addition to the numerical stability differences between Tracker and Zygote described in #876, Zygote performs considerably worse than the equivalent PyTorch code for that example.
Here is the PyTorch code:
```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# dummy data
x = torch.rand(100000, 113)
y = torch.sum(x**2, dim=1, keepdim=True)
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=256)

model = nn.Sequential(
    nn.Linear(113, 1000),
    nn.Linear(1000, 1)
)
criterion = nn.BCEWithLogitsLoss()  # binary cross entropy
optimizer = torch.optim.Adam(
    model.parameters(), lr=1e-4, betas=(0.9, 0.99)
)

model.train()  # set training mode
for (x, y) in dataloader:
    y_hat_logit = model(x)
    loss = criterion(y_hat_logit, y)
    print(loss.float())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
Here is the Flux code:
```julia
using Flux
using Statistics: mean
using MLDataUtils

# dummy data
x = Float32.(rand(113, 100000))
y = sum(x .^ 2, dims=1)
dataset = batchview((x, y), size=256)

model = Chain(
    Dense(113, 1000),
    Dense(1000, 1)
)
criterion(logits, y) = mean(Flux.logitbinarycrossentropy.(logits, y))
optimizer = Flux.ADAM(1e-4, (0.9, 0.99))

for (x, y) in dataset
    θ = params(model)
    loss, back = Flux.Zygote.pullback(() -> criterion(model(x), y), θ)
    println(loss)
    grads = back(1.0)
    Flux.Optimise.update!(optimizer, θ, grads)
end
```
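For a like-for-like speed comparison it helps to wrap the Julia loop in a function, since loops at global scope are not optimized by the Julia compiler and can exaggerate the gap. A minimal timing sketch of the loop above (the wrapper name `train_epoch!` is mine, not part of Flux):

```julia
# Wrap the training loop in a function: code at global scope is not
# specialized by the compiler and would skew the timing.
function train_epoch!(model, dataset, optimizer)
    θ = params(model)
    for (x, y) in dataset
        loss, back = Flux.Zygote.pullback(() -> criterion(model(x), y), θ)
        grads = back(one(loss))  # seed with a one of the loss's element type
        Flux.Optimise.update!(optimizer, θ, grads)
    end
end

train_epoch!(model, dataset, optimizer)  # warm-up run to trigger compilation
@time train_epoch!(model, dataset, optimizer)
```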
- The above PyTorch code is much faster than the Flux code.
- The Flux code, after a few iterations, results in `NaN`s, where the PyTorch code does not. Possibly the same issue as #876 (Model optimization fails (NaNs) with Zygote.pullback but works with Tracker.forward); see the diagnostic sketch after this list.
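To pin down where the `NaN`s first appear, one can check the loss and gradients each iteration before applying the update. A hedged sketch (the helper `has_nan` is my own, not a Flux or Zygote API):

```julia
# Returns true if any parameter's gradient contains a NaN; parameters
# that received no gradient (`nothing`) are skipped.
has_nan(grads, θ) = any(p -> grads[p] !== nothing && any(isnan, grads[p]), θ)

for (i, (x, y)) in enumerate(dataset)
    θ = params(model)
    loss, back = Flux.Zygote.pullback(() -> criterion(model(x), y), θ)
    grads = back(one(loss))
    if isnan(loss) || has_nan(grads, θ)
        @warn "NaN detected before update" iteration = i loss
        break  # stop here so the offending batch can be inspected
    end
    Flux.Optimise.update!(optimizer, θ, grads)
end
```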