Adversarial training with lightning #14782
-
Hello! I'm attempting to do some simple adversarial training with Lightning, but I'm running into some issues in the testing part. The adversarial examples are generated with:

with torch.enable_grad():
    adv_img = self.atk(imgs, labels)

This works fine during training (when doing trainer.fit(model)), but fails during testing (trainer.test(model)) with:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I checked similar problems, and the suggested solutions were to enable gradients (which I did) or to switch to manual optimization. I would like to retain the possibility of accumulating gradients, though, and since trainer.fit works fine I don't see why I would need manual optimization. I did some digging, and the enabling of gradients does seem to take effect (see the quick check after the script below). The same thing happens if I do trainer.validate(model), and also with attacks other than PGD. Any idea why that is and how I can fix it?

Reproducible in Colab, or with the full script below:

import pytorch_lightning as pl
from torch import nn
import torch
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torchattacks
class adv_model(pl.LightningModule):
    def __init__(self,
                 model,
                 attack=None,
                 loaders=None,
                 loss_fn=nn.CrossEntropyLoss(),
                 optim="AdamW",
                 clean=False,
                 lr=0.01
                 ):
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn
        self.loaders = loaders
        self.atk = attack
        self.clean = clean
        self.lr = lr
        if optim is None:
            self.optim = torch.optim.AdamW
        elif optim == "AdamW":
            self.optim = torch.optim.AdamW
        elif optim == "Adam":
            self.optim = torch.optim.Adam
        elif optim == "SGD":
            self.optim = torch.optim.SGD
        else:
            raise ValueError(f"Optim should be in '[AdamW, Adam, SGD]', not {optim}")
    def forward(self, x, clean=None):
        return self.model(x)
    def training_step(self, batch, batch_nb):
        imgs, labels = batch
        if not self.clean:
            imgs = self.atk(imgs, labels)
        logits = self.model(imgs)
        loss = self.loss_fn(logits, labels)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        acc = (logits.argmax(dim=1)).eq(labels).sum().item() / len(imgs)
        self.log("train_acc", acc, prog_bar=True, on_step=False, on_epoch=True)
        return {"loss": loss, "acc": acc}
    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        clean_logits = self.model(imgs)
        clean_loss = self.loss_fn(clean_logits, labels)
        clean_acc = (clean_logits.argmax(dim=1)).eq(labels).sum().item() / len(imgs)
        self.log("clean_val_loss", clean_loss, prog_bar=True)
        self.log('clean_val_acc', clean_acc, prog_bar=True)
        if self.clean:
            return clean_loss, clean_acc
        # computing adversarial accuracy and loss
        with torch.enable_grad():
            adv_img = self.atk(imgs, labels)
        adv_logits = self.model(adv_img)
        adv_loss = self.loss_fn(adv_logits, labels)
        self.log("adv_val_loss", adv_loss, prog_bar=True)
        adv_acc = (adv_logits.argmax(dim=1)).eq(labels).sum().item() / len(imgs)
        self.log('adv_val_acc', adv_acc, prog_bar=True)
        return clean_loss, clean_acc, adv_loss, adv_acc
    def test_step(self, batch, batch_idx):
        imgs, labels = batch
        clean_logits = self.model(imgs)
        clean_loss = self.loss_fn(clean_logits, labels)
        clean_acc = (clean_logits.argmax(dim=1)).eq(labels).sum().item() / len(imgs)
        self.log("clean_test_loss", clean_loss, prog_bar=True)
        self.log('clean_test_acc', clean_acc, prog_bar=True)
        if self.clean:
            return clean_loss, clean_acc
        # computing adversarial accuracy and loss
        with torch.enable_grad():
            adv_img = self.atk(imgs, labels)
        adv_logits = self.model(adv_img)
        adv_loss = self.loss_fn(adv_logits, labels)
        self.log("adv_test_loss", adv_loss, prog_bar=True)
        adv_acc = (adv_logits.argmax(dim=1)).eq(labels).sum().item() / len(imgs)
        self.log('adv_test_acc', adv_acc, prog_bar=True)
        return clean_loss, clean_acc, adv_loss, adv_acc
    def configure_optimizers(self):
        optim = self.optim
        if issubclass(optim, torch.optim.SGD):
            if self.lr is not None:
                return optim(self.model.parameters(), lr=self.lr, momentum=0.9, weight_decay=1e-4)
            else:
                return optim(self.model.parameters(), momentum=0.9, weight_decay=1e-4)
        elif issubclass(optim, (torch.optim.Adam, torch.optim.AdamW)):
            if self.lr is not None:
                return optim(self.model.parameters(), lr=self.lr, weight_decay=1e-4)
            else:
                return optim(self.model.parameters(), weight_decay=1e-4)
        else:
            return self.optim
    def train_dataloader(self):
        return self.loaders[0]
    def val_dataloader(self):
        return self.loaders[1]
    def test_dataloader(self):
        return self.loaders[2]
trainer = pl.Trainer(accelerator="gpu",
                     max_epochs=3,
                     val_check_interval=1.0,
                     )
base_model = torch.nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(),
                                 nn.Linear(256, 256), nn.ReLU(),
                                 nn.Linear(256, 10))
train_set = MNIST(root="./",
                  transform=T.ToTensor(),
                  download=True,
                  train=True
                  )
test_set = MNIST(root="./",
                 transform=T.ToTensor(),
                 download=True,
                 train=False
                 )
train_loader = DataLoader(train_set, batch_size=100, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=1000, shuffle=False, num_workers=2)
val_loader = DataLoader(test_set, batch_size=1000, shuffle=False, num_workers=2)
loaders = (train_loader, val_loader, test_loader)
atk = torchattacks.PGD(model=base_model.cuda(), steps=10)
model = adv_model(base_model,
                  loaders=loaders,
                  attack=atk,
                  clean=False,
                  optim="Adam")
trainer.fit(model)
trainer.test(model)
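For what it's worth, a quick check (illustrative, not part of the script above) confirms that grad mode really is enabled inside the context; it is the batch tensor itself that turns out to be special:

# inside test_step, before running the attack
with torch.enable_grad():
    print(torch.is_grad_enabled())    # True: grad mode is on
    print(imgs.is_inference())        # True: the batch was created under inference mode
    adv_img = self.atk(imgs, labels)  # still raises the RuntimeError above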
Replies: 2 comments 8 replies
-
Answering to myself:

After more digging, it seems that the use of torch.inference_mode is the cause of the issue. Using torch.no_grad is not enough to get out of inference mode. In fact, getting out of inference mode, e.g. with torch.inference_mode(mode=False) or the corresponding decorator, is not enough either: I then run into "Inference tensors cannot be saved for backward". To work around that, you can make a clone to get a normal tensor and use it in autograd.

For now the solution I have is to change the function

@contextmanager
def _evaluation_context(accelerator: Accelerator) -> Generator:
    # inference mode is not supported with gloo backend (#9431),
    # and HPU & TPU accelerators.
    context_manager_class = (
        torch.inference_mode
        if not (dist.is_initialized() and dist.get_backend() == "gloo")
        and not isinstance(accelerator, HPUAccelerator)
        and not isinstance(accelerator, TPUAccelerator)
        else torch.no_grad
    )
    with context_manager_class():
        yield

to always use torch.no_grad (l2794 in trainer.py). If there is a simple alternative to use in test_step, or some parameter to force the use of no_grad instead of inference_mode, I'm all ears.
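If patching the Trainer is not an option, the clone workaround can also live entirely in the step itself. A minimal sketch (not an official API, just the clone trick described above applied locally): re-enter a non-inference, grad-enabled context and clone the batch so autograd can record operations on it.

# inside validation_step / test_step, replacing the plain enable_grad block
with torch.inference_mode(mode=False), torch.enable_grad():
    # cloning an inference tensor outside inference mode yields a
    # normal tensor that can participate in autograd
    adv_img = self.atk(imgs.clone(), labels.clone())
adv_logits = self.model(adv_img)

Note that newer Lightning releases expose a Trainer-level switch for this, e.g. Trainer(inference_mode=False) to make evaluation run under torch.no_grad instead of torch.inference_mode; check whether your installed version supports it.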
-
Could you please let us know if there have been any updates on this issue? @sergedurand