Poor performance relative to PyTorch #886

@jessebett

In addition to the numerical stability differences between Tracker and Zygote described in #876, Zygote performs considerably worse than the equivalent PyTorch code for that example.

Here is the PyTorch code:

import torch
from torch import nn
from torch.utils.data import DataLoader,TensorDataset

# dummy data
x = torch.rand(100000,113)
y = torch.sum(x**2,dim=1, keepdim=True)
dataset = TensorDataset(x,y)
dataloader = DataLoader(dataset, batch_size=256)

model = nn.Sequential(
    nn.Linear(113,1000),
    nn.Linear(1000,1)
    )

criterion = nn.BCEWithLogitsLoss()  # binary cross entropy
optimizer = torch.optim.Adam(
    model.parameters(), lr=1e-4, betas=(0.9, 0.99)
    )

model.train()  # set training mode (autograd is already enabled by default)
for (x,y) in dataloader:
  y_hat_logit = model(x)
  loss = criterion(y_hat_logit,y)
  print(loss.float())

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

Here is the Flux code:

using Flux
using Statistics: mean
using MLDataUtils

# dummy data
x = Float32.(rand(113,100000))
y = sum(x.^2,dims=1)

dataset = batchview((x,y),size=256)

model = Chain(
              Dense(113,1000),
              Dense(1000,1)
             )

criterion(logits,y) = mean(Flux.logitbinarycrossentropy.(logits,y))
optimizer = Flux.ADAM(1e-4,(0.9,0.99))

for (x,y) in dataset
  θ = params(model)
  loss,back = Flux.Zygote.pullback(()->criterion(model(x),y),θ)
  println(loss)
  grads = back(1.)
  Flux.Optimise.update!(optimizer,θ,grads)
end

  1. The PyTorch code above is much faster than the Flux code.
  2. After a few iterations the Flux code produces NaNs, whereas the PyTorch code does not. This is possibly the same issue as #876, "Model optimization fails (NaNs) with Zygote.pullback but works with Tracker.forward".
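
For reference on the NaN point, here is a minimal sketch of the numerically stable form of binary cross entropy on logits, written in plain Python so it runs without either framework. The function names `bce_with_logits` and `naive_bce` are hypothetical helpers for illustration. This is the textbook stable formulation, not a copy of PyTorch's or Flux's internals, and it does not diagnose the Zygote behaviour; it only shows what value the criterion is expected to return for large logits and for targets like the dummy `y` above.

```python
import math


def bce_with_logits(logit: float, target: float) -> float:
    """Stable binary cross entropy on a raw logit (illustrative helper).

    Algebraically equal to
        -(t * log(sigmoid(x)) + (1 - t) * log(1 - sigmoid(x)))
    but it never takes log(0) and never exponentiates a large positive number.
    """
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def naive_bce(logit: float, target: float) -> float:
    """Direct transcription of the definition; only safe for moderate logits."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


if __name__ == "__main__":
    # Moderate inputs: both forms agree up to rounding.
    print(bce_with_logits(2.0, 1.0), naive_bce(2.0, 1.0))

    # Large logits: the stable form stays finite in both directions.
    print(bce_with_logits(1000.0, 1.0), bce_with_logits(-1000.0, 1.0))

    # Targets above 1, as produced by sum(x**2) over 113 features:
    # the value is finite but negative and keeps decreasing as the logit grows.
    for logit in (10.0, 100.0, 1000.0):
        print(logit, bce_with_logits(logit, 37.0))
```

One thing the last loop makes visible: when the target is greater than 1, this loss is unbounded below, so an optimizer can keep pushing the logits upward. Whether that contributes to the NaNs seen in the Flux run is not established here.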
