diff --git a/.gitignore b/.gitignore index 68bc17f..7b43cca 100644 --- a/.gitignore +++ b/.gitignore @@ -1,160 +1,6 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ +.idea/ +*__pycache__* +.DS_Store *.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ -cover/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -.pybuilder/ -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# poetry -# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
-# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock - -# pdm -# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. -#pdm.lock -# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it -# in version control. -# https://pdm.fming.dev/#use-with-ide -.pdm.toml - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ - -# Cython debug symbols -cython_debug/ - -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +*.xml +venv/ \ No newline at end of file diff --git a/README.md b/README.md index 26bd0d5..d69553d 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,27 @@ -# RobustSNNConversion -Adversarially Robust Spiking Neural Networks Through Conversion +# Adversarially Robust Spiking Neural Networks Through Conversion -Repository will be updated soon.. +This is the code repository of the following [paper](https://arxiv.org/pdf/2311.09266.pdf) to perform adversarially robust ANN-to-SNN conversion. 
+ +"Adversarially Robust Spiking Neural Networks Through Conversion"\ +Ozan Özdenizci, Robert Legenstein\ +arXiv preprint arXiv:2311.09266 (2023). + +## Reference +If you use this code or models in your research and find it helpful, please cite the following paper: +``` +@article{ozdenizci2023adversarially, + title={Adversarially robust spiking neural networks through conversion}, + author={Ozan {\"O}zdenizci and Robert Legenstein}, + journal={arXiv preprint arXiv:2311.09266}, + year={2023} +} +``` + +## Acknowledgments + +Authors of this work are affiliated with Graz University of Technology, Institute of Theoretical Computer Science, and Silicon Austria Labs, TU Graz - SAL Dependable Embedded Systems Lab, Graz, Austria. This work has been supported by the "University SAL Labs" initiative of Silicon Austria Labs (SAL) and its Austrian partner universities for applied fundamental research for electronic based systems. + +Parts of this code repository are based on the following works: + +* https://github.com/nitin-rathi/hybrid-snn-conversion +* https://github.com/putshua/SNN-RAT diff --git a/attack/__init__.py b/attack/__init__.py new file mode 100644 index 0000000..6d7c31f --- /dev/null +++ b/attack/__init__.py @@ -0,0 +1,9 @@ +from attack.fgsm import FGSM +from attack.rfgsm import RFGSM +from attack.pgd import PGD +from attack.tpgd import TPGD +from attack.mart import MART +from attack.apgd import APGD +from attack.apgdt import APGDT +from attack.square import Square +from attack.ensemble import Ensemble diff --git a/attack/apgd.py b/attack/apgd.py new file mode 100644 index 0000000..02c94ab --- /dev/null +++ b/attack/apgd.py @@ -0,0 +1,258 @@ +import time +import numpy as np +import torch +import torch.nn as nn +from torchattacks.attack import Attack + + +class APGD(Attack): + def __init__(self, model, fwd_function=None, T=None, surrogate='PCW', gamma=1.0, norm='Linf', eps=8/255, steps=10, + n_restarts=1, seed=0, loss='ce', eot_iter=1, rho=.75, verbose=False): + 
super().__init__("APGD", model) + self.forward_function = fwd_function + self.surrogate = surrogate + self.gamma = gamma + self.T = T + self.eps = eps + self.steps = steps + self.norm = norm + self.n_restarts = n_restarts + self.seed = seed + self.loss = loss + self.eot_iter = eot_iter + self.thr_decr = rho + self.verbose = verbose + self.supported_mode = ['default'] + print('Auto-PGD attack with epsilon: ', eps, ' and loss: ', loss) + if T > 0: + print('Surrogate: ', surrogate, ' and gamma: ', gamma) + + def forward(self, images, labels): + images = images.clone().detach().to(self.device) + labels = labels.clone().detach().to(self.device) + _, adv_images = self.perturb(images, labels, cheap=True) + return adv_images + + def check_oscillation(self, x, j, k, y5, k3=0.75): + t = np.zeros(x.shape[1]) + for counter5 in range(k): + t += x[j - counter5] > x[j - counter5 - 1] + + return t <= k * k3 * np.ones(t.shape) + + def check_shape(self, x): + return x if len(x.shape) > 0 else np.expand_dims(x, 0) + + def dlr_loss(self, x, y): + x_sorted, ind_sorted = x.sort(dim=1) + ind = (ind_sorted[:, -1] == y).float() + + return -(x[np.arange(x.shape[0]), y] - x_sorted[:, -2] * ind - x_sorted[:, -1] * (1. 
- ind)) / ( + x_sorted[:, -1] - x_sorted[:, -3] + 1e-12) + + def attack_single_run(self, x_in, y_in): + x = x_in.clone() if len(x_in.shape) == 4 else x_in.clone().unsqueeze(0) + y = y_in.clone() if len(y_in.shape) == 1 else y_in.clone().unsqueeze(0) + + self.steps_2, self.steps_min, self.size_decr = max(int(0.22 * self.steps), 1), max(int(0.06 * self.steps), 1), max( + int(0.03 * self.steps), 1) + if self.verbose: + print('parameters: ', self.steps, self.steps_2, self.steps_min, self.size_decr) + + if self.norm == 'Linf': + t = 2 * torch.rand(x.shape).to(self.device).detach() - 1 + x_adv = x.detach() + self.eps * torch.ones([x.shape[0], 1, 1, 1]).to(self.device).detach() * t / ( + t.reshape([t.shape[0], -1]).abs().max(dim=1, keepdim=True)[0].reshape([-1, 1, 1, 1])) + elif self.norm == 'L2': + t = torch.randn(x.shape).to(self.device).detach() + x_adv = x.detach() + self.eps * torch.ones([x.shape[0], 1, 1, 1]).to(self.device).detach() * t / ( + (t ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt() + 1e-12) + x_adv = x_adv.clamp(0., 1.) 
+ x_best = x_adv.clone() + x_best_adv = x_adv.clone() + loss_steps = torch.zeros([self.steps, x.shape[0]]) + loss_best_steps = torch.zeros([self.steps + 1, x.shape[0]]) + acc_steps = torch.zeros_like(loss_best_steps) + + if self.loss == 'ce': + criterion_indiv = nn.CrossEntropyLoss(reduction='none') + elif self.loss == 'dlr': + criterion_indiv = self.dlr_loss + else: + raise ValueError('unknowkn loss') + + x_adv.requires_grad_() + grad = torch.zeros_like(x) + for _ in range(self.eot_iter): + with torch.enable_grad(): + if self.forward_function is not None: + logits = self.forward_function(self.model, x_adv, self.T, self.surrogate, self.gamma) # 1 forward pass (eot_iter = 1) + else: + logits = self.model(x_adv) # 1 forward pass (eot_iter = 1) + loss_indiv = criterion_indiv(logits, y) + loss = loss_indiv.sum() + + grad += torch.autograd.grad(loss, [x_adv])[0].detach() # 1 backward pass (eot_iter = 1) + + grad /= float(self.eot_iter) + grad_best = grad.clone() + + acc = logits.detach().max(1)[1] == y + acc_steps[0] = acc + 0 + loss_best = loss_indiv.detach().clone() + + step_size = self.eps * torch.ones([x.shape[0], 1, 1, 1]).to(self.device).detach() * torch.Tensor([2.0]).to( + self.device).detach().reshape([1, 1, 1, 1]) + x_adv_old = x_adv.clone() + counter = 0 + k = self.steps_2 + 0 + u = np.arange(x.shape[0]) + counter3 = 0 + + loss_best_last_check = loss_best.clone() + reduced_last_check = np.zeros(loss_best.shape) == np.zeros(loss_best.shape) + n_reduced = 0 + + for i in range(self.steps): + ### gradient step + with torch.no_grad(): + x_adv = x_adv.detach() + grad2 = x_adv - x_adv_old + x_adv_old = x_adv.clone() + + a = 0.75 if i > 0 else 1.0 + + if self.norm == 'Linf': + x_adv_1 = x_adv + step_size * torch.sign(grad) + x_adv_1 = torch.clamp(torch.min(torch.max(x_adv_1, x - self.eps), x + self.eps), 0.0, 1.0) + x_adv_1 = torch.clamp( + torch.min(torch.max(x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a), x - self.eps), x + self.eps), + 0.0, 1.0) + + elif self.norm 
== 'L2': + x_adv_1 = x_adv + step_size * grad / ((grad ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt() + 1e-12) + x_adv_1 = torch.clamp(x + (x_adv_1 - x) / ( + ((x_adv_1 - x) ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt() + 1e-12) * torch.min( + self.eps * torch.ones(x.shape).to(self.device).detach(), + ((x_adv_1 - x) ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt()), 0.0, 1.0) + x_adv_1 = x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a) + x_adv_1 = torch.clamp(x + (x_adv_1 - x) / ( + ((x_adv_1 - x) ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt() + 1e-12) * torch.min( + self.eps * torch.ones(x.shape).to(self.device).detach(), + ((x_adv_1 - x) ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt() + 1e-12), 0.0, 1.0) + + x_adv = x_adv_1 + 0. + + ### get gradient + x_adv.requires_grad_() + grad = torch.zeros_like(x) + for _ in range(self.eot_iter): + with torch.enable_grad(): + if self.forward_function is not None: + logits = self.forward_function(self.model, x_adv, self.T, self.surrogate, self.gamma) # 1 forward pass (eot_iter = 1) + else: + logits = self.model(x_adv) # 1 forward pass (eot_iter = 1) + loss_indiv = criterion_indiv(logits, y) + loss = loss_indiv.sum() + + grad += torch.autograd.grad(loss, [x_adv])[0].detach() # 1 backward pass (eot_iter = 1) + + grad /= float(self.eot_iter) + + pred = logits.detach().max(1)[1] == y + acc = torch.min(acc, pred) + acc_steps[i + 1] = acc + 0 + x_best_adv[(pred == 0).nonzero().squeeze()] = x_adv[(pred == 0).nonzero().squeeze()] + 0. 
+ if self.verbose: + print('iteration: {} - Best loss: {:.6f}'.format(i, loss_best.sum())) + + ### check step size + with torch.no_grad(): + y1 = loss_indiv.detach().clone() + loss_steps[i] = y1.cpu() + 0 + ind = (y1 > loss_best).nonzero().squeeze() + x_best[ind] = x_adv[ind].clone() + grad_best[ind] = grad[ind].clone() + loss_best[ind] = y1[ind] + 0 + loss_best_steps[i + 1] = loss_best + 0 + + counter3 += 1 + + if counter3 == k: + fl_oscillation = self.check_oscillation(loss_steps.detach().cpu().numpy(), i, k, + loss_best.detach().cpu().numpy(), k3=self.thr_decr) + fl_reduce_no_impr = (~reduced_last_check) * ( + loss_best_last_check.cpu().numpy() >= loss_best.cpu().numpy()) + fl_oscillation = ~(~fl_oscillation * ~fl_reduce_no_impr) + reduced_last_check = np.copy(fl_oscillation) + loss_best_last_check = loss_best.clone() + + if np.sum(fl_oscillation) > 0: + step_size[u[fl_oscillation]] /= 2.0 + n_reduced = fl_oscillation.astype(float).sum() + + fl_oscillation = np.where(fl_oscillation) + + x_adv[fl_oscillation] = x_best[fl_oscillation].clone() + grad[fl_oscillation] = grad_best[fl_oscillation].clone() + + counter3 = 0 + k = np.maximum(k - self.size_decr, self.steps_min) + + return x_best, acc, loss_best, x_best_adv + + def perturb(self, x_in, y_in, best_loss=False, cheap=True): + assert self.norm in ['Linf', 'L2'] + x = x_in.clone() if len(x_in.shape) == 4 else x_in.clone().unsqueeze(0) + y = y_in.clone() if len(y_in.shape) == 1 else y_in.clone().unsqueeze(0) + + adv = x.clone() + if self.forward_function is not None: + outs = self.forward_function(self.model, x, self.T, self.surrogate, self.gamma) + else: + outs = self.model(x) + acc = outs.max(1)[1] == y + loss = -1e10 * torch.ones_like(acc).float() + if self.verbose: + print('-------------------------- running {}-attack with epsilon {:.4f} --------------------------'.format( + self.norm, self.eps)) + print('initial accuracy: {:.2%}'.format(acc.float().mean())) + startt = time.time() + + if not best_loss: + 
torch.random.manual_seed(self.seed) + torch.cuda.random.manual_seed(self.seed) + + if not cheap: + raise ValueError('not implemented yet') + + else: + for counter in range(self.n_restarts): + ind_to_fool = acc.nonzero().squeeze() + if len(ind_to_fool.shape) == 0: ind_to_fool = ind_to_fool.unsqueeze(0) + if ind_to_fool.numel() != 0: + x_to_fool, y_to_fool = x[ind_to_fool].clone(), y[ind_to_fool].clone() + best_curr, acc_curr, loss_curr, adv_curr = self.attack_single_run(x_to_fool, y_to_fool) + ind_curr = (acc_curr == 0).nonzero().squeeze() + # + acc[ind_to_fool[ind_curr]] = 0 + adv[ind_to_fool[ind_curr]] = adv_curr[ind_curr].clone() + if self.verbose: + print('restart {} - robust accuracy: {:.2%} - cum. time: {:.1f} s'.format( + counter, acc.float().mean(), time.time() - startt)) + + return acc, adv + + else: + adv_best = x.detach().clone() + loss_best = torch.ones([x.shape[0]]).to(self.device) * (-float('inf')) + for counter in range(self.n_restarts): + best_curr, _, loss_curr, _ = self.attack_single_run(x, y) + ind_curr = (loss_curr > loss_best).nonzero().squeeze() + adv_best[ind_curr] = best_curr[ind_curr] + 0. + loss_best[ind_curr] = loss_curr[ind_curr] + 0. 
+ + if self.verbose: + print('restart {} - loss: {:.5f}'.format(counter, loss_best.sum())) + + return loss_best, adv_best diff --git a/attack/apgdt.py b/attack/apgdt.py new file mode 100644 index 0000000..1958e1c --- /dev/null +++ b/attack/apgdt.py @@ -0,0 +1,246 @@ +import time +import numpy as np +import torch +import torch.nn as nn +from torchattacks.attack import Attack + + +class APGDT(Attack): + def __init__(self, model, fwd_function=None, T=None, surrogate='PCW', gamma=1.0, norm='Linf', eps=8/255, steps=10, + n_restarts=1, seed=0, eot_iter=1, rho=.75, verbose=False, n_classes=10): + super().__init__("APGDT", model) + self.forward_function = fwd_function + self.surrogate = surrogate + self.gamma = gamma + self.T = T + self.eps = eps + self.steps = steps + self.norm = norm + self.n_restarts = n_restarts + self.seed = seed + self.eot_iter = eot_iter + self.thr_decr = rho + self.verbose = verbose + self.target_class = None + self.n_target_classes = n_classes - 1 + self.supported_mode = ['default'] + print('Auto-PGD-Targeted attack with epsilon: ', eps) + if T > 0: + print('Surrogate: ', surrogate, ' and gamma: ', gamma) + + def forward(self, images, labels): + images = images.clone().detach().to(self.device) + labels = labels.clone().detach().to(self.device) + _, adv_images = self.perturb(images, labels, cheap=True) + return adv_images + + def check_oscillation(self, x, j, k, y5, k3=0.5): + t = np.zeros(x.shape[1]) + for counter5 in range(k): + t += x[j - counter5] > x[j - counter5 - 1] + + return t <= k * k3 * np.ones(t.shape) + + def check_shape(self, x): + return x if len(x.shape) > 0 else np.expand_dims(x, 0) + + def dlr_loss_targeted(self, x, y, y_target): + x_sorted, ind_sorted = x.sort(dim=1) + + return -(x[np.arange(x.shape[0]), y] - x[np.arange(x.shape[0]), y_target]) / ( + x_sorted[:, -1] - .5 * x_sorted[:, -3] - .5 * x_sorted[:, -4] + 1e-12) + + def attack_single_run(self, x_in, y_in): + x = x_in.clone() if len(x_in.shape) == 4 else 
x_in.clone().unsqueeze(0) + y = y_in.clone() if len(y_in.shape) == 1 else y_in.clone().unsqueeze(0) + + self.steps_2, self.steps_min, self.size_decr = max(int(0.22 * self.steps), 1), max(int(0.06 * self.steps), + 1), max( + int(0.03 * self.steps), 1) + if self.verbose: + print('parameters: ', self.steps, self.steps_2, self.steps_min, self.size_decr) + + if self.norm == 'Linf': + t = 2 * torch.rand(x.shape).to(self.device).detach() - 1 + x_adv = x.detach() + self.eps * torch.ones([x.shape[0], 1, 1, 1]).to(self.device).detach() * t / ( + t.reshape([t.shape[0], -1]).abs().max(dim=1, keepdim=True)[0].reshape([-1, 1, 1, 1])) + elif self.norm == 'L2': + t = torch.randn(x.shape).to(self.device).detach() + x_adv = x.detach() + self.eps * torch.ones([x.shape[0], 1, 1, 1]).to(self.device).detach() * t / ( + (t ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt() + 1e-12) + x_adv = x_adv.clamp(0., 1.) + x_best = x_adv.clone() + x_best_adv = x_adv.clone() + loss_steps = torch.zeros([self.steps, x.shape[0]]) + loss_best_steps = torch.zeros([self.steps + 1, x.shape[0]]) + acc_steps = torch.zeros_like(loss_best_steps) + + if self.forward_function is not None: + output = self.forward_function(self.model, x, self.T, self.surrogate, self.gamma) + else: + output = self.model(x) + y_target = output.sort(dim=1)[1][:, -self.target_class] + + x_adv.requires_grad_() + grad = torch.zeros_like(x) + for _ in range(self.eot_iter): + with torch.enable_grad(): + if self.forward_function is not None: + logits = self.forward_function(self.model, x_adv, self.T, self.surrogate, self.gamma) # 1 forward pass (eot_iter = 1) + else: + logits = self.model(x_adv) # 1 forward pass (eot_iter = 1) + loss_indiv = self.dlr_loss_targeted(logits, y, y_target) + loss = loss_indiv.sum() + + grad += torch.autograd.grad(loss, [x_adv])[0].detach() # 1 backward pass (eot_iter = 1) + + grad /= float(self.eot_iter) + grad_best = grad.clone() + + acc = logits.detach().max(1)[1] == y + acc_steps[0] = acc + 0 + loss_best = 
loss_indiv.detach().clone() + + step_size = self.eps * torch.ones([x.shape[0], 1, 1, 1]).to(self.device).detach() * torch.Tensor([2.0]).to( + self.device).detach().reshape([1, 1, 1, 1]) + x_adv_old = x_adv.clone() + counter = 0 + k = self.steps_2 + 0 + u = np.arange(x.shape[0]) + counter3 = 0 + + loss_best_last_check = loss_best.clone() + reduced_last_check = np.zeros(loss_best.shape) == np.zeros(loss_best.shape) + n_reduced = 0 + + for i in range(self.steps): + ### gradient step + with torch.no_grad(): + x_adv = x_adv.detach() + grad2 = x_adv - x_adv_old + x_adv_old = x_adv.clone() + + a = 0.75 if i > 0 else 1.0 + + if self.norm == 'Linf': + x_adv_1 = x_adv + step_size * torch.sign(grad) + x_adv_1 = torch.clamp(torch.min(torch.max(x_adv_1, x - self.eps), x + self.eps), 0.0, 1.0) + x_adv_1 = torch.clamp( + torch.min(torch.max(x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a), x - self.eps), + x + self.eps), 0.0, 1.0) + + elif self.norm == 'L2': + x_adv_1 = x_adv + step_size[0] * grad / ( + (grad ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt() + 1e-12) + x_adv_1 = torch.clamp(x + (x_adv_1 - x) / ( + ((x_adv_1 - x) ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt() + 1e-12) * torch.min( + self.eps * torch.ones(x.shape).to(self.device).detach(), + ((x_adv_1 - x) ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt()), 0.0, 1.0) + x_adv_1 = x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a) + x_adv_1 = torch.clamp(x + (x_adv_1 - x) / ( + ((x_adv_1 - x) ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt() + 1e-12) * torch.min( + self.eps * torch.ones(x.shape).to(self.device).detach(), + ((x_adv_1 - x) ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt() + 1e-12), 0.0, 1.0) + + x_adv = x_adv_1 + 0. 
+ + ### get gradient + x_adv.requires_grad_() + grad = torch.zeros_like(x) + for _ in range(self.eot_iter): + with torch.enable_grad(): + if self.forward_function is not None: + logits = self.forward_function(self.model, x_adv, self.T, self.surrogate, self.gamma) # 1 forward pass (eot_iter = 1) + else: + logits = self.model(x_adv) # 1 forward pass (eot_iter = 1) + loss_indiv = self.dlr_loss_targeted(logits, y, y_target) + loss = loss_indiv.sum() + + grad += torch.autograd.grad(loss, [x_adv])[0].detach() # 1 backward pass (eot_iter = 1) + + grad /= float(self.eot_iter) + + pred = logits.detach().max(1)[1] == y + acc = torch.min(acc, pred) + acc_steps[i + 1] = acc + 0 + x_best_adv[(pred == 0).nonzero().squeeze()] = x_adv[(pred == 0).nonzero().squeeze()] + 0. + if self.verbose: + print('iteration: {} - Best loss: {:.6f}'.format(i, loss_best.sum())) + + ### check step size + with torch.no_grad(): + y1 = loss_indiv.detach().clone() + loss_steps[i] = y1.cpu() + 0 + ind = (y1 > loss_best).nonzero().squeeze() + x_best[ind] = x_adv[ind].clone() + grad_best[ind] = grad[ind].clone() + loss_best[ind] = y1[ind] + 0 + loss_best_steps[i + 1] = loss_best + 0 + + counter3 += 1 + + if counter3 == k: + fl_oscillation = self.check_oscillation(loss_steps.detach().cpu().numpy(), i, k, + loss_best.detach().cpu().numpy(), k3=self.thr_decr) + fl_reduce_no_impr = (~reduced_last_check) * ( + loss_best_last_check.cpu().numpy() >= loss_best.cpu().numpy()) + fl_oscillation = ~(~fl_oscillation * ~fl_reduce_no_impr) + reduced_last_check = np.copy(fl_oscillation) + loss_best_last_check = loss_best.clone() + + if np.sum(fl_oscillation) > 0: + step_size[u[fl_oscillation]] /= 2.0 + n_reduced = fl_oscillation.astype(float).sum() + + fl_oscillation = np.where(fl_oscillation) + + x_adv[fl_oscillation] = x_best[fl_oscillation].clone() + grad[fl_oscillation] = grad_best[fl_oscillation].clone() + + counter3 = 0 + k = np.maximum(k - self.size_decr, self.steps_min) + + return x_best, acc, loss_best, 
x_best_adv + + def perturb(self, x_in, y_in, best_loss=False, cheap=True): + assert self.norm in ['Linf', 'L2'] + x = x_in.clone() if len(x_in.shape) == 4 else x_in.clone().unsqueeze(0) + y = y_in.clone() if len(y_in.shape) == 1 else y_in.clone().unsqueeze(0) + + adv = x.clone() + if self.forward_function is not None: + outs = self.forward_function(self.model, x, self.T, self.surrogate, self.gamma) + else: + outs = self.model(x) + acc = outs.max(1)[1] == y + loss = -1e10 * torch.ones_like(acc).float() + if self.verbose: + print('-------------------------- running {}-attack with epsilon {:.4f} --------------------------'.format( + self.norm, self.eps)) + print('initial accuracy: {:.2%}'.format(acc.float().mean())) + startt = time.time() + + torch.random.manual_seed(self.seed) + torch.cuda.random.manual_seed(self.seed) + + if not cheap: + raise ValueError('not implemented yet') + + else: + for target_class in range(2, self.n_target_classes + 2): + self.target_class = target_class + for counter in range(self.n_restarts): + ind_to_fool = acc.nonzero().squeeze() + if len(ind_to_fool.shape) == 0: ind_to_fool = ind_to_fool.unsqueeze(0) + if ind_to_fool.numel() != 0: + x_to_fool, y_to_fool = x[ind_to_fool].clone(), y[ind_to_fool].clone() + best_curr, acc_curr, loss_curr, adv_curr = self.attack_single_run(x_to_fool, y_to_fool) + ind_curr = (acc_curr == 0).nonzero().squeeze() + # + acc[ind_to_fool[ind_curr]] = 0 + adv[ind_to_fool[ind_curr]] = adv_curr[ind_curr].clone() + if self.verbose: + print('restart {} - target_class {} - robust accuracy: {:.2%} at eps = {:.5f} - cum. 
time: {:.1f} s'.format( + counter, self.target_class, acc.float().mean(), self.eps, time.time() - startt)) + + return acc, adv \ No newline at end of file diff --git a/attack/ensemble.py b/attack/ensemble.py new file mode 100644 index 0000000..615f93b --- /dev/null +++ b/attack/ensemble.py @@ -0,0 +1,300 @@ +import torch +import torch.nn as nn +from torchattacks.attack import Attack +from attack import * + + +class Ensemble(Attack): + def __init__(self, model, fwd_functions=None, eps=8/255, alpha=2/255, steps=10, T=None, version='autoattack', + seed=0, verbose=False, n_classes=10): + super().__init__("Ensemble", model) + if len(fwd_functions) == 1: + self.ff = fwd_functions[0] + else: + self.ff_1, self.ff_2, self.ff_3 = fwd_functions + self.T = T + self.eps = eps + self.verbose = verbose + self._supported_mode = ['default'] + print('Ensemble attack with epsilon: ', eps) + + if version == 'autoattack': # ['apgd-ce', 'apgd-t', 'fab-t', 'square'] + self._multiattack = MultiAttack([ + APGD(model, fwd_function=self.ff, T=T, eps=eps, seed=seed, verbose=verbose, loss='ce', n_restarts=1), + APGDT(model, fwd_function=self.ff, T=T, eps=eps, seed=seed, verbose=verbose, n_classes=n_classes, n_restarts=1), + FAB(model, fwd_function=self.ff, T=T, eps=eps, seed=seed, verbose=verbose, multi_targeted=True, n_classes=n_classes, n_restarts=1), + Square(model, fwd_function=self.ff, T=T, eps=eps, seed=seed, verbose=verbose, n_queries=5000, n_restarts=1), + ], fwd_function=self.ff, T=T) + + elif version == 'apgd-dlr': + self._multiattack = MultiAttack([ + APGD(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=1.0, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='dlr', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=2.0, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='dlr', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=3.0, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='dlr', 
n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=0.5, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='dlr', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=0.25, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='dlr', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='EXP-D', gamma=0.5, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='dlr', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='EXP-D', gamma=1.0, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='dlr', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='EXP-D', gamma=2.0, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='dlr', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='EXP', gamma=0.5, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='dlr', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='EXP', gamma=1.0, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='dlr', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='EXP', gamma=2.0, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='dlr', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=0.25, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='dlr', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=0.5, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='dlr', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=1.0, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='dlr', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=2.0, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='dlr', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=4.0, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='dlr', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, 
surrogate='STE', eps=eps, steps=steps, seed=seed, verbose=verbose, loss='dlr', n_restarts=1), + APGD(model, fwd_function=self.ff_2, T=T, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='dlr', n_restarts=1), + APGD(model, fwd_function=self.ff_3, T=T, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='dlr', n_restarts=1), + ], fwd_function=self.ff_1, T=T) + + elif version == 'apgdt': + self._multiattack = MultiAttack([ + APGDT(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=1.0, eps=eps, steps=steps, seed=seed, verbose=verbose, n_classes=n_classes, n_restarts=1), + APGDT(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=2.0, eps=eps, steps=steps, seed=seed, verbose=verbose, n_classes=n_classes, n_restarts=1), + APGDT(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=3.0, eps=eps, steps=steps, seed=seed, verbose=verbose, n_classes=n_classes, n_restarts=1), + APGDT(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=0.5, eps=eps, steps=steps, seed=seed, verbose=verbose, n_classes=n_classes, n_restarts=1), + APGDT(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=0.25, eps=eps, steps=steps, seed=seed, verbose=verbose, n_classes=n_classes, n_restarts=1), + APGDT(model, fwd_function=self.ff_1, T=T, surrogate='EXP-D', gamma=0.5, eps=eps, steps=steps, seed=seed, verbose=verbose, n_classes=n_classes, n_restarts=1), + APGDT(model, fwd_function=self.ff_1, T=T, surrogate='EXP-D', gamma=1.0, eps=eps, steps=steps, seed=seed, verbose=verbose, n_classes=n_classes, n_restarts=1), + APGDT(model, fwd_function=self.ff_1, T=T, surrogate='EXP-D', gamma=2.0, eps=eps, steps=steps, seed=seed, verbose=verbose, n_classes=n_classes, n_restarts=1), + APGDT(model, fwd_function=self.ff_1, T=T, surrogate='EXP', gamma=0.5, eps=eps, steps=steps, seed=seed, verbose=verbose, n_classes=n_classes, n_restarts=1), + APGDT(model, fwd_function=self.ff_1, T=T, surrogate='EXP', gamma=1.0, eps=eps, steps=steps, seed=seed, verbose=verbose, 
n_classes=n_classes, n_restarts=1), + APGDT(model, fwd_function=self.ff_1, T=T, surrogate='EXP', gamma=2.0, eps=eps, steps=steps, seed=seed, verbose=verbose, n_classes=n_classes, n_restarts=1), + APGDT(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=0.25, eps=eps, steps=steps, seed=seed, verbose=verbose, n_classes=n_classes, n_restarts=1), + APGDT(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=0.5, eps=eps, steps=steps, seed=seed, verbose=verbose, n_classes=n_classes, n_restarts=1), + APGDT(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=1.0, eps=eps, steps=steps, seed=seed, verbose=verbose, n_classes=n_classes, n_restarts=1), + APGDT(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=2.0, eps=eps, steps=steps, seed=seed, verbose=verbose, n_classes=n_classes, n_restarts=1), + APGDT(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=4.0, eps=eps, steps=steps, seed=seed, verbose=verbose, n_classes=n_classes, n_restarts=1), + APGDT(model, fwd_function=self.ff_1, T=T, surrogate='STE', eps=eps, steps=steps, seed=seed, verbose=verbose, n_classes=n_classes, n_restarts=1), + APGDT(model, fwd_function=self.ff_2, T=T, eps=eps, steps=steps, seed=seed, verbose=verbose, n_classes=n_classes, n_restarts=1), + APGDT(model, fwd_function=self.ff_3, T=T, eps=eps, steps=steps, seed=seed, verbose=verbose, n_classes=n_classes, n_restarts=1), + ], fwd_function=self.ff_1, T=T) + + elif version == 'apgd': + self._multiattack = MultiAttack([ + APGD(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=1.0, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='ce', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=2.0, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='ce', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=3.0, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='ce', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='PCW', 
gamma=0.5, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='ce', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=0.25, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='ce', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='EXP-D', gamma=0.5, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='ce', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='EXP-D', gamma=1.0, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='ce', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='EXP-D', gamma=2.0, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='ce', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='EXP', gamma=0.5, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='ce', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='EXP', gamma=1.0, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='ce', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='EXP', gamma=2.0, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='ce', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=0.25, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='ce', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=0.5, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='ce', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=1.0, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='ce', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=2.0, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='ce', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=4.0, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='ce', n_restarts=1), + APGD(model, fwd_function=self.ff_1, T=T, surrogate='STE', eps=eps, steps=steps, seed=seed, verbose=verbose, loss='ce', n_restarts=1), 
+ APGD(model, fwd_function=self.ff_2, T=T, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='ce', n_restarts=1), + APGD(model, fwd_function=self.ff_3, T=T, eps=eps, steps=steps, seed=seed, verbose=verbose, loss='ce', n_restarts=1), + ], fwd_function=self.ff_1, T=T) + + elif version == 'pgd': + self._multiattack = MultiAttack([ + PGD(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=1.0, eps=eps, alpha=alpha, steps=steps), + PGD(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=2.0, eps=eps, alpha=alpha, steps=steps), + PGD(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=3.0, eps=eps, alpha=alpha, steps=steps), + PGD(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=0.5, eps=eps, alpha=alpha, steps=steps), + PGD(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=0.25, eps=eps, alpha=alpha, steps=steps), + PGD(model, fwd_function=self.ff_1, T=T, surrogate='EXP-D', gamma=0.5, eps=eps, alpha=alpha, steps=steps), + PGD(model, fwd_function=self.ff_1, T=T, surrogate='EXP-D', gamma=1.0, eps=eps, alpha=alpha, steps=steps), + PGD(model, fwd_function=self.ff_1, T=T, surrogate='EXP-D', gamma=2.0, eps=eps, alpha=alpha, steps=steps), + PGD(model, fwd_function=self.ff_1, T=T, surrogate='EXP', gamma=0.5, eps=eps, alpha=alpha, steps=steps), + PGD(model, fwd_function=self.ff_1, T=T, surrogate='EXP', gamma=1.0, eps=eps, alpha=alpha, steps=steps), + PGD(model, fwd_function=self.ff_1, T=T, surrogate='EXP', gamma=2.0, eps=eps, alpha=alpha, steps=steps), + PGD(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=0.25, eps=eps, alpha=alpha, steps=steps), + PGD(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=0.5, eps=eps, alpha=alpha, steps=steps), + PGD(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=1.0, eps=eps, alpha=alpha, steps=steps), + PGD(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=2.0, eps=eps, alpha=alpha, steps=steps), + PGD(model, fwd_function=self.ff_1, T=T, surrogate='RECT', 
gamma=4.0, eps=eps, alpha=alpha, steps=steps), + PGD(model, fwd_function=self.ff_1, T=T, surrogate='STE', eps=eps, alpha=alpha, steps=steps), + PGD(model, fwd_function=self.ff_2, T=T, eps=eps, alpha=alpha, steps=steps), + PGD(model, fwd_function=self.ff_3, T=T, eps=eps, alpha=alpha, steps=steps), + ], fwd_function=self.ff_1, T=T) + + elif version == 'rfgsm': + self._multiattack = MultiAttack([ + RFGSM(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=1.0, eps=eps, alpha=alpha, loss='ce'), + RFGSM(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=2.0, eps=eps, alpha=alpha, loss='ce'), + RFGSM(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=3.0, eps=eps, alpha=alpha, loss='ce'), + RFGSM(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=0.5, eps=eps, alpha=alpha, loss='ce'), + RFGSM(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=0.25, eps=eps, alpha=alpha, loss='ce'), + RFGSM(model, fwd_function=self.ff_1, T=T, surrogate='EXP-D', gamma=0.5, eps=eps, alpha=alpha, loss='ce'), + RFGSM(model, fwd_function=self.ff_1, T=T, surrogate='EXP-D', gamma=1.0, eps=eps, alpha=alpha, loss='ce'), + RFGSM(model, fwd_function=self.ff_1, T=T, surrogate='EXP-D', gamma=2.0, eps=eps, alpha=alpha, loss='ce'), + RFGSM(model, fwd_function=self.ff_1, T=T, surrogate='EXP', gamma=0.5, eps=eps, alpha=alpha, loss='ce'), + RFGSM(model, fwd_function=self.ff_1, T=T, surrogate='EXP', gamma=1.0, eps=eps, alpha=alpha, loss='ce'), + RFGSM(model, fwd_function=self.ff_1, T=T, surrogate='EXP', gamma=2.0, eps=eps, alpha=alpha, loss='ce'), + RFGSM(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=0.25, eps=eps, alpha=alpha, loss='ce'), + RFGSM(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=0.5, eps=eps, alpha=alpha, loss='ce'), + RFGSM(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=1.0, eps=eps, alpha=alpha, loss='ce'), + RFGSM(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=2.0, eps=eps, alpha=alpha, loss='ce'), 
+ RFGSM(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=4.0, eps=eps, alpha=alpha, loss='ce'), + RFGSM(model, fwd_function=self.ff_1, T=T, surrogate='STE', eps=eps, alpha=alpha, loss='ce'), + RFGSM(model, fwd_function=self.ff_2, T=T, eps=eps, alpha=alpha, loss='ce'), + RFGSM(model, fwd_function=self.ff_3, T=T, eps=eps, alpha=alpha, loss='ce'), + ], fwd_function=self.ff_1, T=T) + + elif version == 'fgsm': + self._multiattack = MultiAttack([ + FGSM(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=1.0, eps=eps), + FGSM(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=2.0, eps=eps), + FGSM(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=3.0, eps=eps), + FGSM(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=0.5, eps=eps), + FGSM(model, fwd_function=self.ff_1, T=T, surrogate='PCW', gamma=0.25, eps=eps), + FGSM(model, fwd_function=self.ff_1, T=T, surrogate='EXP-D', gamma=0.5, eps=eps), + FGSM(model, fwd_function=self.ff_1, T=T, surrogate='EXP-D', gamma=1.0, eps=eps), + FGSM(model, fwd_function=self.ff_1, T=T, surrogate='EXP-D', gamma=2.0, eps=eps), + FGSM(model, fwd_function=self.ff_1, T=T, surrogate='EXP', gamma=0.5, eps=eps), + FGSM(model, fwd_function=self.ff_1, T=T, surrogate='EXP', gamma=1.0, eps=eps), + FGSM(model, fwd_function=self.ff_1, T=T, surrogate='EXP', gamma=2.0, eps=eps), + FGSM(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=0.25, eps=eps), + FGSM(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=0.5, eps=eps), + FGSM(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=1.0, eps=eps), + FGSM(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=2.0, eps=eps), + FGSM(model, fwd_function=self.ff_1, T=T, surrogate='RECT', gamma=4.0, eps=eps), + FGSM(model, fwd_function=self.ff_1, T=T, surrogate='STE', eps=eps), + FGSM(model, fwd_function=self.ff_2, T=T, eps=eps), + FGSM(model, fwd_function=self.ff_3, T=T, eps=eps), + ], fwd_function=self.ff_1, T=T) + + else: + raise 
class MultiAttack(Attack):
    r"""Chain several attacks and keep, per image, the first one that succeeds.

    Every attack is run only on the images that all earlier attacks failed
    on.  After each attack the model is queried on the produced adversarial
    images; an image whose predicted class differs from its label counts as
    a success, is stored in the result, and is not attacked again.

    Arguments:
        attacks (list): attack objects that all reference the same model;
            at least two are required.
        fwd_function (callable, optional): when given, predictions are
            computed as ``fwd_function(model, images, T, surrogate, gamma)``
            instead of ``model(images)``.
        T, surrogate, gamma: passed through unchanged to ``fwd_function``.
        verbose (bool): print per-attack success rates after each batch.
    """

    def __init__(self, attacks, fwd_function=None, T=None, surrogate='PCW', gamma=1.0, verbose=False):
        super().__init__("MultiAttack", attacks[0].model)
        self.forward_function = fwd_function
        self.surrogate = surrogate
        self.gamma = gamma
        self.T = T
        self.attacks = attacks
        self.verbose = verbose
        self.supported_mode = ['default']

        self.check_validity()

        # Running totals filled in by forward() only while save() is active.
        self._accumulate_multi_atk_records = False
        self._multi_atk_records = [0.0]

    def check_validity(self):
        """Raise ``ValueError`` unless >= 2 attacks share one model object."""
        if len(self.attacks) < 2:
            # The check accepts exactly two attacks, so say so.
            raise ValueError("At least two attacks should be given.")

        model_ids = {id(attack.model) for attack in self.attacks}
        if len(model_ids) != 1:
            raise ValueError("At least one of attacks is referencing a different model.")

    def forward(self, images, labels):
        r"""
        Overridden.

        Returns a tensor shaped like ``images`` in which every image that some
        attack fooled the model on is replaced by its adversarial version;
        images no attack succeeded on are returned unchanged.
        """
        # Move BOTH inputs to the attack device: `fails` lives there and is
        # used to index `images`, which fails for a CPU tensor / CUDA index.
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        batch_size = images.shape[0]
        # Indices (into the batch) of images that are still correctly classified.
        fails = torch.arange(batch_size, device=self.device)
        final_images = images.clone()

        # records[0] = batch size, records[k] = images still unbroken after attack k.
        multi_atk_records = [batch_size]

        for attack in self.attacks:
            adv_images = attack(images[fails], labels[fails])

            if self.forward_function is not None:
                outputs = self.forward_function(self.model, adv_images, self.T, self.surrogate, self.gamma)
            else:
                outputs = self.model(adv_images)
            _, pre = torch.max(outputs.data, 1)

            corrects = (pre == labels[fails])
            wrongs = ~corrects

            # Boolean masks select the newly broken images on both sides.
            final_images[fails[wrongs]] = adv_images[wrongs]

            fails = fails[corrects]
            multi_atk_records.append(len(fails))

            if len(fails) == 0:
                break

        if self.verbose:
            print(self._return_sr_record(multi_atk_records))

        if self._accumulate_multi_atk_records:
            self._update_multi_atk_records(multi_atk_records)

        return final_images

    def _clear_multi_atk_records(self):
        """Reset the accumulated counters to their initial single-slot state."""
        self._multi_atk_records = [0.0]

    def _covert_to_success_rates(self, multi_atk_records):
        """Convert remaining-image counts into cumulative success rates in percent.

        ``multi_atk_records[0]`` is the total number of images; each later
        entry is the number still unbroken after that attack.  With a total
        of zero every rate is reported as 0.0 instead of dividing by zero.
        """
        total = multi_atk_records[0]
        if total == 0:
            return [0.0] * (len(multi_atk_records) - 1)
        return [(1 - remaining / total) * 100 for remaining in multi_atk_records[1:]]

    def _return_sr_record(self, multi_atk_records):
        """Format the success rates as a single ``' | '``-separated line."""
        sr = self._covert_to_success_rates(multi_atk_records)
        return "Attack success rate: "+" | ".join(["%2.2f %%"%item for item in sr])

    def _update_multi_atk_records(self, multi_atk_records):
        """Add one batch's counts onto the accumulated totals.

        A batch record may be shorter than the totals (forward() stops early
        once every image is broken); the untouched slots correspond to zero
        remaining images, so skipping them is equivalent to adding zero.
        """
        for i, item in enumerate(multi_atk_records):
            self._multi_atk_records[i] += item

    def save(self, data_loader, save_path=None, verbose=True, return_verbose=False,
             save_predictions=False, save_clean_images=False):
        r"""
        Overridden.

        Delegates to the parent ``save`` while accumulating per-attack counts
        across batches.  Returns ``(rob_acc, sr, l2, elapsed_time)`` when
        ``return_verbose`` is true, otherwise ``None``.
        """
        prev_verbose = self.verbose
        # One slot for the image total plus one per attack.
        self._multi_atk_records = [0.0] * (len(self.attacks) + 1)
        self.verbose = False  # per-batch printing is replaced by _save_print
        self._accumulate_multi_atk_records = True

        try:
            if return_verbose:
                rob_acc, l2, elapsed_time = super().save(data_loader, save_path,
                                                         verbose, return_verbose,
                                                         save_predictions,
                                                         save_clean_images)
                sr = self._covert_to_success_rates(self._multi_atk_records)
            else:
                super().save(data_loader, save_path, verbose or False,
                             False, save_predictions,
                             save_clean_images)
        finally:
            # Restore state even if the parent save() raises midway.
            self._clear_multi_atk_records()
            self._accumulate_multi_atk_records = False
            self.verbose = prev_verbose

        if return_verbose:
            return rob_acc, sr, l2, elapsed_time

    def _save_print(self, progress, rob_acc, l2, elapsed_time, end):
        r"""
        Overridden.

        Same progress line as the parent, with the accumulated per-attack
        success rates inserted.
        """
        print("- Save progress: %2.2f %% / Robust accuracy: %2.2f %%"%(progress, rob_acc)+\
              " / "+self._return_sr_record(self._multi_atk_records)+\
              ' / L2: %1.5f (%2.3f it/s) \t'%(l2, elapsed_time), end=end)
class MART(Attack):
    """Iterative L-inf sign-gradient attack on the cross-entropy loss.

    Starts from the clean images plus small Gaussian noise (scale 0.001),
    then takes ``steps`` signed-gradient steps of size ``alpha``, projecting
    back into the ``eps`` ball around the clean images and into [0, 1]
    after every step.

    Arguments:
        model: model under attack.
        fwd_function (callable, optional): when given, outputs are computed
            as ``fwd_function(model, images, T, surrogate, gamma)`` instead
            of ``model(images)``.
        eps (float): maximum L-inf perturbation.
        alpha (float): step size.
        steps (int): number of iterations.
        T, surrogate, gamma: passed through unchanged to ``fwd_function``.
    """

    def __init__(self, model, fwd_function=None, eps=8/255, alpha=2/255, steps=10, T=None, surrogate='PCW', gamma=1.0):
        super().__init__("MART", model)
        self.eps = eps
        self.alpha = alpha
        self.steps = steps
        self._supported_mode = ['default']
        self.forward_function = fwd_function
        self.surrogate = surrogate
        self.gamma = gamma
        self.T = T
        print('MART-PGD attack with epsilon: ', eps, ' and step size: ', alpha)
        # T defaults to None, and `None > 0` raises TypeError on Python 3.
        if T is not None and T > 0:
            print('Surrogate: ', surrogate, ' and gamma: ', gamma)

    def forward(self, images, labels):
        """Return adversarial images for ``images`` given true ``labels``."""
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        loss = nn.CrossEntropyLoss()

        # Random start: tiny Gaussian perturbation, kept inside the image range.
        adv_images = images + 0.001 * torch.randn_like(images)
        adv_images = torch.clamp(adv_images, min=0, max=1).detach()

        for _ in range(self.steps):
            adv_images.requires_grad = True
            if self.forward_function is not None:
                outputs = self.forward_function(self.model, adv_images, self.T, self.surrogate, self.gamma)
            else:
                outputs = self.model(adv_images)

            cost = loss(outputs, labels)
            grad = torch.autograd.grad(cost, adv_images, retain_graph=False, create_graph=False)[0]

            # Ascend the loss, then project onto the eps-ball and valid pixel range.
            adv_images = adv_images.detach() + self.alpha*grad.sign()
            delta = torch.clamp(adv_images - images, min=-self.eps, max=self.eps)
            adv_images = torch.clamp(images + delta, min=0, max=1).detach()

        return adv_images
class RFGSM(Attack):
    """Single-step sign-gradient attack with a random signed start.

    The clean images are first shifted by ``alpha * sign(noise)``; one
    signed-gradient step of size ``eps - alpha`` follows, and the result is
    projected into the ``eps`` ball around the clean images and into [0, 1].

    Arguments:
        model: model under attack.
        fwd_function (callable, optional): when given, outputs are computed
            as ``fwd_function(model, images, T, surrogate, gamma)`` instead
            of ``model(images)``.
        T, surrogate, gamma: passed through unchanged to ``fwd_function``.
        eps (float): maximum L-inf perturbation.
        alpha (float): size of the random start.
        loss (str): ``'kl'`` maximises the summed KL divergence between the
            softmax of the clean outputs and of the adversarial outputs
            (labels are not used); any other value uses cross-entropy
            against ``labels``.
    """

    def __init__(self, model, fwd_function=None, T=None, surrogate='PCW', gamma=1.0, eps=8/255, alpha=4/255, loss='kl'):
        super().__init__("RFGSM", model)
        self.forward_function = fwd_function
        self.surrogate = surrogate
        self.gamma = gamma
        self.T = T
        self.eps = eps
        self.alpha = alpha
        self.loss = loss
        self.supported_mode = ['default']

    def _model_outputs(self, inputs):
        """Run the model on ``inputs`` through ``forward_function`` when one is set."""
        if self.forward_function is not None:
            return self.forward_function(self.model, inputs, self.T, self.surrogate, self.gamma)
        return self.model(inputs)

    def forward(self, images, labels):
        """Return adversarial images for ``images``."""
        images = images.clone().detach().to(self.device)
        use_kl = self.loss == 'kl'

        if use_kl:
            # Reference distribution; computed before the random start is drawn.
            clean_outputs = self._model_outputs(images).detach()
        else:  # loss = 'ce'
            labels = labels.clone().detach().to(self.device)

        # Random signed start, kept inside the valid pixel range.
        adv_images = images + self.alpha * torch.randn_like(images).sign()
        adv_images = torch.clamp(adv_images, min=0, max=1).detach()
        adv_images.requires_grad = True
        outputs_adv = self._model_outputs(adv_images)

        if use_kl:
            # reduction='sum' is the non-deprecated spelling of size_average=False.
            criterion_kl = nn.KLDivLoss(reduction='sum')
            cost = criterion_kl(F.log_softmax(outputs_adv, dim=1), F.softmax(clean_outputs, dim=1))
        else:
            cost = nn.CrossEntropyLoss()(outputs_adv, labels)

        grad = torch.autograd.grad(cost, adv_images, retain_graph=False, create_graph=False)[0]
        adv_images = adv_images.detach() + (self.eps - self.alpha) * grad.sign()
        delta = torch.clamp(adv_images - images, min=-self.eps, max=self.eps)
        adv_images = torch.clamp(images + delta, min=0, max=1).detach()
        return adv_images
logits[u, y] = -float('inf') + y_others = logits.max(dim=-1)[0] + + if not self.targeted: + if self.loss == 'ce': + return y_corr - y_others, -1. * xent + elif self.loss == 'margin': + return y_corr - y_others, y_corr - y_others + else: + if self.loss == 'ce': + return y_others - y_corr, xent + elif self.loss == 'margin': + return y_others - y_corr, y_others - y_corr + + def init_hyperparam(self, x): + assert self.norm in ['Linf', 'L2'] + assert not self.eps is None + assert self.loss in ['ce', 'margin'] + + if self.device is None: + self.device = x.device + self.orig_dim = list(x.shape[1:]) + self.ndims = len(self.orig_dim) + if self.seed is None: + self.seed = time.time() + + def check_shape(self, x): + return x if len(x.shape) == (self.ndims + 1) else x.unsqueeze(0) + + def random_choice(self, shape): + t = 2 * torch.rand(shape).to(self.device) - 1 + return torch.sign(t) + + def random_int(self, low=0, high=1, shape=[1]): + t = low + (high - low) * torch.rand(shape).to(self.device) + return t.long() + + def normalize_delta(self, x): + if self.norm == 'Linf': + t = x.abs().view(x.shape[0], -1).max(1)[0] + return x / (t.view(-1, *([1] * self.ndims)) + 1e-12) + + elif self.norm == 'L2': + t = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt() + return x / (t.view(-1, *([1] * self.ndims)) + 1e-12) + + def lp_norm(self, x): + if self.norm == 'L2': + t = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt() + return t.view(-1, *([1] * self.ndims)) + + def eta_rectangles(self, x, y): + delta = torch.zeros([x, y]).to(self.device) + x_c, y_c = x // 2 + 1, y // 2 + 1 + + counter2 = [x_c - 1, y_c - 1] + for counter in range(0, max(x_c, y_c)): + delta[max(counter2[0], 0):min(counter2[0] + (2 * counter + 1), x), + max(0, counter2[1]):min(counter2[1] + (2 * counter + 1), y) + ] += 1.0 / (torch.Tensor([counter + 1]).view(1, 1).to( + self.device) ** 2) + counter2[0] -= 1 + counter2[1] -= 1 + + delta /= (delta ** 2).sum(dim=(0, 1), keepdim=True).sqrt() + + return delta + + def eta(self, s): + 
delta = torch.zeros([s, s]).to(self.device) + delta[:s // 2] = self.eta_rectangles(s // 2, s) + delta[s // 2:] = -1. * self.eta_rectangles(s - s // 2, s) + delta /= (delta ** 2).sum(dim=(0, 1), keepdim=True).sqrt() + if torch.rand([1]) > 0.5: + delta = delta.permute([1, 0]) + + return delta + + def p_selection(self, it): + """ schedule to decrease the parameter p """ + + if self.rescale_schedule: + it = int(it / self.n_queries * 10000) + + if 10 < it <= 50: + p = self.p_init / 2 + elif 50 < it <= 200: + p = self.p_init / 4 + elif 200 < it <= 500: + p = self.p_init / 8 + elif 500 < it <= 1000: + p = self.p_init / 16 + elif 1000 < it <= 2000: + p = self.p_init / 32 + elif 2000 < it <= 4000: + p = self.p_init / 64 + elif 4000 < it <= 6000: + p = self.p_init / 128 + elif 6000 < it <= 8000: + p = self.p_init / 256 + elif 8000 < it: + p = self.p_init / 512 + else: + p = self.p_init + + return p + + def attack_single_run(self, x, y): + with torch.no_grad(): + adv = x.clone() + c, h, w = x.shape[1:] + n_features = c * h * w + n_ex_total = x.shape[0] + + if self.norm == 'Linf': + x_best = torch.clamp(x + self.eps * self.random_choice( + [x.shape[0], c, 1, w]), 0., 1.) 
+ margin_min, loss_min = self.margin_and_loss(x_best, y) + n_queries = torch.ones(x.shape[0]).to(self.device) + s_init = int(math.sqrt(self.p_init * n_features / c)) + + for i_iter in range(self.n_queries): + idx_to_fool = (margin_min > 0.0).nonzero().flatten() + + if len(idx_to_fool) == 0: + break + + x_curr = self.check_shape(x[idx_to_fool]) + x_best_curr = self.check_shape(x_best[idx_to_fool]) + y_curr = y[idx_to_fool] + if len(y_curr.shape) == 0: + y_curr = y_curr.unsqueeze(0) + margin_min_curr = margin_min[idx_to_fool] + loss_min_curr = loss_min[idx_to_fool] + + p = self.p_selection(i_iter) + s = max(int(round(math.sqrt(p * n_features / c))), 1) + vh = self.random_int(0, h - s) + vw = self.random_int(0, w - s) + new_deltas = torch.zeros([c, h, w]).to(self.device) + new_deltas[:, vh:vh + s, vw:vw + s + ] = 2. * self.eps * self.random_choice([c, 1, 1]) + + x_new = x_best_curr + new_deltas + x_new = torch.min(torch.max(x_new, x_curr - self.eps), + x_curr + self.eps) + x_new = torch.clamp(x_new, 0., 1.) + x_new = self.check_shape(x_new) + + margin, loss = self.margin_and_loss(x_new, y_curr) + + # update loss if new loss is better + idx_improved = (loss < loss_min_curr).float() + + loss_min[idx_to_fool] = idx_improved * loss + ( + 1. - idx_improved) * loss_min_curr + + # update margin and x_best if new loss is better + # or misclassification + idx_miscl = (margin <= 0.).float() + idx_improved = torch.max(idx_improved, idx_miscl) + + margin_min[idx_to_fool] = idx_improved * margin + ( + 1. - idx_improved) * margin_min_curr + idx_improved = idx_improved.reshape([-1, + *[1] * len(x.shape[:-1])]) + x_best[idx_to_fool] = idx_improved * x_new + ( + 1. - idx_improved) * x_best_curr + n_queries[idx_to_fool] += 1. 
+ + ind_succ = (margin_min <= 0.).nonzero().squeeze() + if self.verbose and ind_succ.numel() != 0 and (i_iter % 500 == 0): + print('{}'.format(i_iter + 1), + '- success rate={}/{} ({:.2%})'.format( + ind_succ.numel(), n_ex_total, + float(ind_succ.numel()) / n_ex_total), + '- avg # queries={:.1f}'.format( + n_queries[ind_succ].mean().item()), + '- med # queries={:.1f}'.format( + n_queries[ind_succ].median().item()), + '- loss={:.3f}'.format(loss_min.mean())) + + if ind_succ.numel() == n_ex_total: + break + + elif self.norm == 'L2': + delta_init = torch.zeros_like(x) + s = h // 5 + sp_init = (h - s * 5) // 2 + vh = sp_init + 0 + for _ in range(h // s): + vw = sp_init + 0 + for _ in range(w // s): + delta_init[:, :, vh:vh + s, vw:vw + s] += self.eta( + s).view(1, 1, s, s) * self.random_choice( + [x.shape[0], c, 1, 1]) + vw += s + vh += s + + x_best = torch.clamp(x + self.normalize_delta(delta_init + ) * self.eps, 0., 1.) + margin_min, loss_min = self.margin_and_loss(x_best, y) + n_queries = torch.ones(x.shape[0]).to(self.device) + s_init = int(math.sqrt(self.p_init * n_features / c)) + + for i_iter in range(self.n_queries): + idx_to_fool = (margin_min > 0.0).nonzero().flatten() + + if len(idx_to_fool) == 0: + break + + x_curr = self.check_shape(x[idx_to_fool]) + x_best_curr = self.check_shape(x_best[idx_to_fool]) + y_curr = y[idx_to_fool] + if len(y_curr.shape) == 0: + y_curr = y_curr.unsqueeze(0) + margin_min_curr = margin_min[idx_to_fool] + loss_min_curr = loss_min[idx_to_fool] + + delta_curr = x_best_curr - x_curr + p = self.p_selection(i_iter) + s = max(int(round(math.sqrt(p * n_features / c))), 3) + if s % 2 == 0: + s += 1 + + vh = self.random_int(0, h - s) + vw = self.random_int(0, w - s) + new_deltas_mask = torch.zeros_like(x_curr) + new_deltas_mask[:, :, vh:vh + s, vw:vw + s] = 1.0 + norms_window_1 = (delta_curr[:, :, vh:vh + s, vw:vw + s + ] ** 2).sum(dim=(-2, -1), keepdim=True).sqrt() + + vh2 = self.random_int(0, h - s) + vw2 = self.random_int(0, w - s) + 
new_deltas_mask_2 = torch.zeros_like(x_curr) + new_deltas_mask_2[:, :, vh2:vh2 + s, vw2:vw2 + s] = 1. + + norms_image = self.lp_norm(x_best_curr - x_curr) + mask_image = torch.max(new_deltas_mask, new_deltas_mask_2) + norms_windows = self.lp_norm(delta_curr * mask_image) + + new_deltas = torch.ones([x_curr.shape[0], c, s, s] + ).to(self.device) + new_deltas *= (self.eta(s).view(1, 1, s, s) * + self.random_choice([x_curr.shape[0], c, 1, 1])) + old_deltas = delta_curr[:, :, vh:vh + s, vw:vw + s] / ( + 1e-12 + norms_window_1) + new_deltas += old_deltas + new_deltas = new_deltas / (1e-12 + (new_deltas ** 2).sum( + dim=(-2, -1), keepdim=True).sqrt()) * (torch.max( + (self.eps * torch.ones_like(new_deltas)) ** 2 - + norms_image ** 2, torch.zeros_like(new_deltas)) / + c + norms_windows ** 2).sqrt() + delta_curr[:, :, vh2:vh2 + s, vw2:vw2 + s] = 0. + delta_curr[:, :, vh:vh + s, vw:vw + s] = new_deltas + 0 + + x_new = torch.clamp(x_curr + self.normalize_delta(delta_curr + ) * self.eps, 0., 1.) + x_new = self.check_shape(x_new) + norms_image = self.lp_norm(x_new - x_curr) + + margin, loss = self.margin_and_loss(x_new, y_curr) + + # update loss if new loss is better + idx_improved = (loss < loss_min_curr).float() + + loss_min[idx_to_fool] = idx_improved * loss + ( + 1. - idx_improved) * loss_min_curr + + # update margin and x_best if new loss is better + # or misclassification + idx_miscl = (margin <= 0.).float() + idx_improved = torch.max(idx_improved, idx_miscl) + + margin_min[idx_to_fool] = idx_improved * margin + ( + 1. - idx_improved) * margin_min_curr + idx_improved = idx_improved.reshape([-1, + *[1] * len(x.shape[:-1])]) + x_best[idx_to_fool] = idx_improved * x_new + ( + 1. - idx_improved) * x_best_curr + n_queries[idx_to_fool] += 1. 
+ + ind_succ = (margin_min <= 0.).nonzero().squeeze() + if self.verbose and ind_succ.numel() != 0 and (i_iter % 500 == 0): + print('{}'.format(i_iter + 1), + '- success rate={}/{} ({:.2%})'.format( + ind_succ.numel(), n_ex_total, float( + ind_succ.numel()) / n_ex_total), + '- avg # queries={:.1f}'.format( + n_queries[ind_succ].mean().item()), + '- med # queries={:.1f}'.format( + n_queries[ind_succ].median().item()), + '- loss={:.3f}'.format(loss_min.mean())) + + assert (x_new != x_new).sum() == 0 + assert (x_best != x_best).sum() == 0 + + if ind_succ.numel() == n_ex_total: + break + + return n_queries, x_best + + def perturb(self, x, y=None): + """ + :param x: clean images + :param y: untargeted attack -> clean labels, + if None we use the predicted labels + targeted attack -> target labels, if None random classes, + different from the predicted ones, are sampled + """ + self.init_hyperparam(x) + + adv = x.clone() + if y is None: + if not self.targeted: + with torch.no_grad(): + if self.forward_function is not None: + output = self.forward_function(self.model, x, self.T, self.surrogate, self.gamma) + else: + output = self.model(x) + y_pred = output.max(1)[1] + y = y_pred.detach().clone().long().to(self.device) + else: + with torch.no_grad(): + y = self.get_target_label(x, None) + else: + if not self.targeted: + y = y.detach().clone().long().to(self.device) + else: + y = self.get_target_label(x, y) + + if not self.targeted: + if self.forward_function is not None: + output = self.forward_function(self.model, x, self.T, self.surrogate, self.gamma) + else: + output = self.model(x) + acc = output.max(1)[1] == y + else: + if self.forward_function is not None: + output = self.forward_function(self.model, x, self.T, self.surrogate, self.gamma) + else: + output = self.model(x) + acc = output.max(1)[1] != y + + startt = time.time() + + torch.random.manual_seed(self.seed) + torch.cuda.random.manual_seed(self.seed) + + for counter in range(self.n_restarts): + ind_to_fool = 
class TPGD(Attack):
    """Iterative L-inf attack that maximises KL divergence from the clean prediction.

    The softmax of the model output on the clean images is the reference
    distribution.  Starting from the clean images plus small Gaussian noise
    (scale 0.001), ``steps`` signed-gradient steps of size ``alpha`` increase
    the summed KL divergence between the adversarial and the reference
    distribution; after each step the result is projected into the ``eps``
    ball around the clean images and into [0, 1].  ``labels`` is not used.

    Arguments:
        model: model under attack.
        fwd_function (callable, optional): when given, outputs are computed
            as ``fwd_function(model, images, T, surrogate, gamma)`` instead
            of ``model(images)``.
        eps (float): maximum L-inf perturbation.
        alpha (float): step size.
        steps (int): number of iterations.
        T, surrogate, gamma: passed through unchanged to ``fwd_function``.
    """

    def __init__(self, model, fwd_function=None, eps=8/255, alpha=2/255, steps=10, T=None, surrogate='PCW', gamma=1.0):
        super().__init__("TPGD", model)
        self.eps = eps
        self.alpha = alpha
        self.steps = steps
        self._supported_mode = ['default']
        self.forward_function = fwd_function
        self.surrogate = surrogate
        self.gamma = gamma
        self.T = T
        print('TRADES-PGD attack with epsilon: ', eps, ' and step size: ', alpha)
        # T defaults to None, and `None > 0` raises TypeError on Python 3.
        if T is not None and T > 0:
            print('Surrogate: ', surrogate, ' and gamma: ', gamma)

    def _model_outputs(self, inputs):
        """Run the model on ``inputs`` through ``forward_function`` when one is set."""
        if self.forward_function is not None:
            return self.forward_function(self.model, inputs, self.T, self.surrogate, self.gamma)
        return self.model(inputs)

    def forward(self, images, labels):
        """Return adversarial images for ``images`` (``labels`` is ignored)."""
        images = images.clone().detach().to(self.device)
        clean_outputs = self._model_outputs(images).detach()
        # Loop-invariant target distribution, computed once.
        clean_probs = F.softmax(clean_outputs, dim=1)

        # Random start: tiny Gaussian perturbation, kept inside the image range.
        adv_images = images + 0.001 * torch.randn_like(images)
        adv_images = torch.clamp(adv_images, min=0, max=1).detach()

        # reduction='sum' is the non-deprecated spelling of size_average=False.
        criterion_kl = nn.KLDivLoss(reduction='sum')

        for _ in range(self.steps):
            adv_images.requires_grad = True
            outputs_adv = self._model_outputs(adv_images)

            cost = criterion_kl(F.log_softmax(outputs_adv, dim=1), clean_probs)
            grad = torch.autograd.grad(cost, adv_images, retain_graph=False, create_graph=False)[0]

            # Ascend the divergence, then project onto the eps-ball and pixel range.
            adv_images = adv_images.detach() + self.alpha*grad.sign()
            delta = torch.clamp(adv_images - images, min=-self.eps, max=self.eps)
            adv_images = torch.clamp(images + delta, min=0, max=1).detach()

        return adv_images
class Cutout(object):
    """Cutout augmentation: overwrite random square patches with per-channel fill values.

    Args:
        n_holes: number of patches to cut out of each image.
        length: side length of each square patch, in pixels (patches that
            reach past the border are clipped).
        norm_mean: sequence of per-channel fill values written into the
            cut-out region.
    """

    def __init__(self, n_holes, length, norm_mean):
        self.n_holes = n_holes
        self.length = length
        self.mean = norm_mean

    def __call__(self, img):
        """Return a copy of ``img`` with ``n_holes`` patches replaced by the fill values.

        Args:
            img: tensor indexed as (channel, height, width).

        Returns:
            A new tensor of the same shape. Channel ``c`` of the cut-out
            region is set to ``norm_mean[c]``; channels beyond
            ``len(norm_mean)`` are set to zero there.
        """
        height = img.size(1)
        width = img.size(2)
        half = self.length // 2

        keep = np.ones((height, width), np.float32)
        for _ in range(self.n_holes):
            # Draw the centre row first, then the column (keeps the RNG stream order).
            cy = np.random.randint(height)
            cx = np.random.randint(width)

            top, bottom = np.clip([cy - half, cy + half], 0, height)
            left, right = np.clip([cx - half, cx + half], 0, width)

            keep[top:bottom, left:right] = 0.

        # Build the mask on the image's device so non-CPU tensors work too.
        mask = torch.from_numpy(keep).to(img.device).expand_as(img)
        out = img * mask

        # Fill however many channels the image and the fill values share,
        # instead of assuming exactly three.
        for channel, fill in zip(range(out.size(0)), self.mean):
            out[channel] = out[channel] + ((1 - mask[channel]) * fill)

        return out
def main():
    """Evaluate a saved checkpoint on the test split, optionally under an adversarial attack.

    All settings come from the module-level ``args``. The function builds the
    test loader and the model, optionally loads a second "black-box" model
    whose gradients drive the attack (``--bbmodel``), constructs the attack
    selected by ``--attack`` and logs the resulting test accuracy.

    Raises:
        ValueError: if ``--identifier`` was not given.
        NotImplementedError: if ``--dataset`` is not one of the supported names.
    """
    # The identifier names both the log file and the checkpoint to load;
    # without this check a missing value fails later with an opaque
    # ``None + str`` TypeError.
    if args.identifier is None:
        raise ValueError('--identifier is required: it names the <identifier>.pth checkpoint to evaluate')

    model_dir = '%s-checkpoints/%s' % (args.dataset, args.model)
    log_dir = '%s-results/%s' % (args.dataset, args.model)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(log_dir, exist_ok=True)
    print(model_dir)

    logger = get_logger(os.path.join(log_dir, '%s.log' % (args.identifier + args.suffix)))
    logger.info('start testing!')

    # Seed every RNG in use and force deterministic cuDNN kernels.
    seed = args.seed
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    dataset_builders = {'cifar10': cifar10, 'cifar100': cifar100, 'svhn': svhn, 'tinyimagenet': tinyimagenet}
    try:
        build_dataset = dataset_builders[args.dataset.lower()]
    except KeyError:
        raise NotImplementedError('unsupported dataset: %s' % args.dataset) from None
    _, val_dataset, znorm, num_classes = build_dataset(args)

    test_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                                              num_workers=args.workers, pin_memory=True)

    # Create your model
    model = models.__dict__[args.model.lower()](args.time, num_classes, znorm, args.learn_vth, args.use_bias,
                                                args.soft_reset, args.surrogate, args.gamma)
    model.set_simulation_time(args.time)
    model.to(device)

    # Optional black-box model: the attack is crafted on it and evaluated on `model`.
    if len(args.bbmodel) > 0:
        bbmodel = copy.deepcopy(model)
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        bbstate_dict = torch.load(os.path.join(model_dir, args.bbmodel + '.pth'), map_location=torch.device('cpu'))
        if (args.time != 0) and ('_T0_' in args.bbmodel):
            # A '_T0_' checkpoint is run with simulation time 0 and loaded non-strictly.
            print('Loaded black-box ANN transfer attack model:')
            bbmodel.set_simulation_time(0)
            bbmodel.load_state_dict(bbstate_dict, strict=False)
            bbmodel.set_simulation_time(0)
            acc = val(bbmodel, test_loader, device, 0)
        else:
            print('Loaded black-box SNN transfer attack model:')
            bbmodel.load_state_dict(bbstate_dict, strict=True)
            acc = val(bbmodel, test_loader, device, args.time)
        logger.info('Black-box model accuracy: ={:.3f}'.format(acc))
        print(args.bbmodel)
        print('Evaluating as a black-box transfer attack...')
        atkmodel = bbmodel
    else:
        bbmodel = None
        atkmodel = model

    # Forward function used by the attack to obtain gradients; any value
    # other than 'bptt', 'bptr' or 'none' falls back to Act_attack.
    if args.attack_mode == 'bptt':
        ff = BPTT_attack
    elif args.attack_mode == 'bptr':
        ff = BPTR_attack
    elif args.attack_mode == 'none':
        ff = None
    else:
        ff = Act_attack

    # eps/alpha are given in units of 1/255; alpha == 0 selects 2.5 * eps / steps.
    step_size = 2.5 * args.eps / args.steps if args.alpha == 0 else args.alpha

    attack_name = args.attack.lower()
    if attack_name == 'fgsm':
        atk = attack.FGSM(atkmodel, fwd_function=ff, eps=args.eps/255, T=args.time, surrogate=args.surrogate, gamma=args.gamma)
    elif attack_name == 'rfgsm':
        atk = attack.RFGSM(atkmodel, fwd_function=ff, eps=args.eps/255, alpha=step_size/255, T=args.time, surrogate=args.surrogate, gamma=args.gamma)
    elif attack_name == 'pgd':
        atk = attack.PGD(atkmodel, fwd_function=ff, eps=args.eps/255, alpha=step_size/255, steps=args.steps, T=args.time, surrogate=args.surrogate, gamma=args.gamma)
    elif attack_name == 'apgd':
        atk = attack.APGD(atkmodel, fwd_function=ff, eps=args.eps/255, T=args.time, surrogate=args.surrogate, gamma=args.gamma)
    elif attack_name == 'square':
        atk = attack.Square(atkmodel, fwd_function=ff, eps=args.eps/255, T=args.time, n_queries=args.n_queries)
    elif attack_name == 'ensemble':
        if args.ens_version == 'autoattack':
            atk = attack.Ensemble(atkmodel, fwd_functions=[ff], eps=args.eps/255, T=args.time, n_classes=num_classes)
        else:
            atk = attack.Ensemble(atkmodel, fwd_functions=[BPTT_attack, BPTR_attack, Act_attack], T=args.time, eps=args.eps/255, alpha=step_size/255, steps=args.steps, version=args.ens_version)
    else:
        # No (or unknown) attack name: evaluate without an attack object.
        atk = None

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    state_dict = torch.load(os.path.join(model_dir, args.identifier + '.pth'), map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.to(device)
    acc = val(model, test_loader, device, args.time, atk)
    logger.info('Attack Test acc={:.3f}'.format(acc))
torch.optim +import attack +from data_loaders import * +import models +from models import * +from utils import * +from models.WideResNet import BasicBlock +from models.ResNet import BasicResNetBlock +from models.layers import LIFSpike + +parser = argparse.ArgumentParser(description='Supplementary Code for Adversarially Robust ANN-to-SNN Conversion') +parser.add_argument('--data_dir', default='/DATA_DIR/', type=str, help='dataset directory') +parser.add_argument('--workers', '-j', default=4, type=int, metavar='N', help='number of data loading workers') +parser.add_argument('--gpu', default='0', type=str, help='device') +parser.add_argument('--seed', default=42, type=int, help='seed for initializing training. ') +parser.add_argument('--suffix', default='', type=str, help='suffix') +parser.add_argument('--load_weights', type=str, help='ann statedict name to load weights') +parser.add_argument('--scaling_factor', default=0.3, type=float, help='scaling factor for v_th at reduced timesteps') +parser.add_argument('--soft_reset', action='store_true', help='use soft reset after firing') +parser.add_argument('--use_bias', action='store_true', help='use bias terms in linear layers') +parser.add_argument('--learn_vth', action='store_true', help='perform v_th optimization') +parser.add_argument('--surrogate', default='PCW', type=str, help='surrogate gradient') +parser.add_argument('--gamma', default=1.0, type=float, help='surrogate gradient gamma') +parser.add_argument('--batch_size', '-b', default=64, type=int, metavar='N', help='mini-batch size') +parser.add_argument('--optim', default='sgd', type=str, help='adam or sgd') +parser.add_argument('--cutout', action='store_true', help='cutout data augmentation') +parser.add_argument('--dataset', default='cifar10', type=str, help='dataset') +parser.add_argument('--model', default='vgg11_bn', type=str, help='model') +parser.add_argument('--time', '-T', default=8, type=int, metavar='N', help='snn simulation time') 
def find_threshold(model, loader, timesteps, device):
    """Estimate one firing threshold per spiking activation, layer by layer.

    For every relevant layer the model is run in ``find_max_mem`` mode on the
    first mini-batches of ``loader`` (batches 0..10, and batches 11..20 for
    the second activation of a residual block). The largest value returned is
    appended to the threshold list, and the model is updated with all
    thresholds found so far before the next layer is examined.

    Which layers are visited depends on substrings of the module-level
    ``args.model`` ('vgg', 'wrn', 'resnet').

    Args:
        model: network exposing ``set_simulation_time``, ``threshold_update``,
            ``features`` (and ``classifier`` for VGG models).
        loader: iterable of ``(data, target)`` mini-batches; targets are ignored.
        timesteps: simulation time set on the model before the search. It is
            not restored afterwards.
        device: device the input batches are moved to.

    Returns:
        list: the thresholds in the order they were found.
    """
    model.set_simulation_time(T=timesteps)
    thresholds = []

    def record(value):
        """Store a finished threshold and push all thresholds so far into the model."""
        thresholds.append(value)
        print(' {}'.format(thresholds))
        # Pass a copy: threshold_update consumes the list it is given.
        model.threshold_update(scaling_factor=1.0, thresholds=thresholds[:])

    def wrn_find(layer):
        """Threshold search for layers whose probe returns a list of one or two values."""
        max_act = 0
        first_done = False
        print('\n Finding threshold for layer {}'.format(layer))
        model.eval()
        # Labels are not needed; only the inputs are moved to the device.
        for batch_idx, (data, _) in enumerate(loader):
            data = data.to(device)
            with torch.no_grad():
                output = model(data, find_max_mem=True, max_mem_layer=layer)
            if len(output) > 1:  # Customized loop for the BasicBlocks where there are two activations
                if first_done:
                    if output[1] > max_act:
                        max_act = output[1]
                    if batch_idx == 20:  # use 10 more mini-batches per layer to estimate best thresholds
                        record(max_act)
                        break
                # Continue setting the first neuron threshold in the BasicBlock first..
                # NOTE(review): this also runs once first_done is True, so output[0]
                # can raise max_act during batches 11-19 -- confirm this is intended.
                if output[0] > max_act:
                    max_act = output[0]
                if batch_idx == 10:
                    record(max_act)
                    first_done = True
                    max_act = 0
            else:
                value = output[0]
                if value > max_act:
                    max_act = value
                if batch_idx == 10:
                    record(max_act)
                    break

    def vgg_find(layer):
        """Threshold search for layers whose probe returns a single scalar."""
        max_act = 0
        print('\n Finding threshold for layer {}'.format(layer))
        model.eval()
        for batch_idx, (data, _) in enumerate(loader):
            data = data.to(device)
            with torch.no_grad():
                output = model(data, find_max_mem=True, max_mem_layer=layer)
            if output > max_act:
                max_act = output
            if batch_idx == 10:
                record(max_act)
                break

    model_name = args.model.lower()

    if 'vgg' in model_name:
        for name, module in model.features.named_children():
            if isinstance(module, LIFSpike):
                vgg_find(int(name))

        # Classifier layers are addressed by their offset past the feature stack.
        for name, module in model.classifier.named_children():
            if isinstance(module, LIFSpike):
                vgg_find(len(model.features) + int(name))

    if 'wrn' in model_name:
        # Blocks only record their intermediate values while find_max_mem is set.
        for _, module in model.features.named_children():
            if isinstance(module, BasicBlock):
                module.find_max_mem = True

        for name, module in model.features.named_children():
            if int(name) > 1:
                if isinstance(module, (BasicBlock, nn.BatchNorm2d, LIFSpike)):
                    wrn_find(int(name))

        for _, module in model.features.named_children():
            if isinstance(module, BasicBlock):
                module.find_max_mem = False

    if 'resnet' in model_name:
        for _, module in model.features.named_children():
            if isinstance(module, BasicResNetBlock):
                module.find_max_mem = True

        for name, module in model.features.named_children():
            if isinstance(module, LIFSpike):
                wrn_find(int(name))
            else:
                if isinstance(module, (BasicResNetBlock, nn.AdaptiveAvgPool2d)) and (int(name) > 3):
                    wrn_find(int(name))

        for _, module in model.features.named_children():
            if isinstance(module, BasicResNetBlock):
                module.find_max_mem = False

    print('\n ANN thresholds: {}'.format(thresholds))
    return thresholds
model.load_state_dict(load_dict, strict=False) + print('\n Missing keys : {}, Unexpected Keys: {}'.format(missing_keys, unexpected_keys)) + + for m in model.modules(): + if isinstance(m, nn.BatchNorm2d): + m.running_mean = None + m.running_var = None + m.num_batches_tracked = None + + model.to(device) + thresholds = find_threshold(model, loader=train_loader, timesteps=100, device=device) + model.threshold_update(scaling_factor=args.scaling_factor, thresholds=thresholds[:]) + model.set_simulation_time(args.time) + + for m in model.modules(): + if isinstance(m, nn.BatchNorm2d): + m.running_mean = torch.zeros(m.num_features, device=device) + m.running_var = torch.ones(m.num_features, device=device) + m.num_batches_tracked = torch.tensor(0, dtype=torch.long, device=device) + m.reset_running_stats() + + else: + model.set_simulation_time(args.time) + model.to(device) + + if args.attack_mode == 'bptt': + ff = BPTT_attack + elif args.attack_mode == 'bptr': + ff = BPTR_attack + else: + ff = None + + step_size = 2.5 * args.eps / args.steps if args.alpha == 0 else args.alpha + + if args.attack.lower() == 'rfgsm': + adv = attack.RFGSM(model, fwd_function=ff, eps=args.eps / 255, alpha=step_size / 255, loss='kl', T=args.time) + elif args.attack.lower() == 'pgd': + adv = attack.PGD(model, fwd_function=ff, eps=args.eps / 255, alpha=step_size / 255, steps=args.steps, T=args.time) + elif args.attack.lower() == 'tpgd': + adv = attack.TPGD(model, fwd_function=ff, eps=args.eps / 255, alpha=step_size / 255, steps=args.steps, T=args.time) + elif args.attack.lower() == 'mart': + adv = attack.MART(model, fwd_function=ff, eps=args.eps / 255, alpha=step_size / 255, steps=args.steps, T=args.time) + else: + adv = None + assert args.trades_beta == 0. + assert args.mart_beta == 0. 
+ + criterion = nn.CrossEntropyLoss().to(device) + optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.beta) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) + best_acc = 0 + + if adv is not None: + identifier = '%s[%.3f][%s]' % (adv.__class__.__name__, adv.eps, args.attack_mode) + else: + identifier = 'clean' + + identifier += '_%s[%.4f]_lr[%.4f]_T%d' % ('wd', args.beta, args.lr, args.time) + identifier += args.suffix + + logger = get_logger(os.path.join(log_dir, '%s.log' % (identifier))) + logger.info('start training!') + + if args.load_weights: + logger.info('\n ANN thresholds: {} \n'.format(thresholds)) + pre_calib_acc = val(model, test_loader, device, args.time) + logger.info('Pre-calibration Test acc={:.3f}\n'.format(pre_calib_acc)) + + for epoch in range(args.epochs): + loss, acc = train(model, device, train_loader, criterion, optimizer, args.time, adv_train=adv, + trades_beta=args.trades_beta, mart_beta=args.mart_beta) + logger.info('Epoch:[{}/{}]\t loss={:.5f}\t acc={:.3f}'.format(epoch, args.epochs, loss, acc)) + scheduler.step() + tmp = val(model, test_loader, device, args.time) + logger.info('Epoch:[{}/{}]\t Test acc={:.3f}\n'.format(epoch, args.epochs, tmp)) + + if best_acc < tmp: + best_acc = tmp + torch.save(model.state_dict(), os.path.join(log_dir, '%s.pth' % (identifier))) + + logger.info('Best Test acc={:.3f}'.format(best_acc)) + + +if __name__ == "__main__": + main() diff --git a/models/ResNet.py b/models/ResNet.py new file mode 100644 index 0000000..71fc6f5 --- /dev/null +++ b/models/ResNet.py @@ -0,0 +1,173 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import models +from models.layers import * + +__all__ = ['ResNet', 'resnet20', 'resnet32'] + + +class BasicResNetBlock(nn.Module): + expansion = 1 + + def __init__(self, T, in_planes, out_planes, stride=1, default_leak=1.0, default_vth=1.0, + learn_vth=False, 
class ResNet(nn.Module):
    """Residual network with LIFSpike activations.

    The feature stack is a 3x3 conv / batch-norm / LIFSpike stem with 16
    channels, three stages of residual blocks with 16, 32 and 64 planes
    (strides 1, 2, 2) and an adaptive average pool; the classifier is a
    flatten followed by one linear layer.

    Args:
        block: residual block class; must expose ``expansion``.
        num_blocks: number of blocks in each of the three stages.
        T: simulation time; when ``T > 0`` the input is expanded along a
            temporal dimension in ``forward``.
        num_classes: size of the output layer.
        norm: arguments unpacked into ``TensorNormalization``.
        init_channels: number of input channels.
        learn_vth, soft_reset, default_leak, default_threshold, surrogate,
            gamma: forwarded to every ``LIFSpike`` / block.
        use_bias: whether the final linear layer has a bias.
    """

    def __init__(self, block, num_blocks, T, num_classes, norm, init_channels=3, learn_vth=False, use_bias=False,
                 soft_reset=False, default_leak=1.0, default_threshold=1.0, surrogate='PCW', gamma=1.0):
        super(ResNet, self).__init__()
        self.T = T
        self.init_channels = init_channels
        self.default_leak = default_leak
        self.default_vth = default_threshold
        self.learn_vth = learn_vth
        self.use_bias = use_bias
        self.soft = soft_reset
        self.surrogate = surrogate
        self.gamma = gamma

        self.norm = TensorNormalization(*norm)
        self.merge = MergeTemporalDim(T)
        self.expand = ExpandTemporalDim(T)

        self.in_planes = 16
        self.features = self._make_layers(block, num_blocks)
        self.classifier = self._make_classifier(64 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, val=1)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the remaining ones stride 1."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.T, self.in_planes, planes, block_stride, default_leak=self.default_leak,
                                default_vth=self.default_vth, learn_vth=self.learn_vth, soft_reset=self.soft,
                                surrogate=self.surrogate, gamma=self.gamma))
            # The next block consumes what this one produced.
            self.in_planes = planes * block.expansion
        return layers

    def _make_layers(self, block, num_blocks):
        """Build the flat feature stack (stem, three stages, global pooling)."""
        layers = [nn.Conv2d(self.init_channels, 16, kernel_size=3, stride=1, padding=1, bias=False),
                  nn.BatchNorm2d(16),
                  LIFSpike(self.T, leak=self.default_leak, v_th=self.default_vth, soft_reset=self.soft,
                           learn_vth=self.learn_vth, surrogate=self.surrogate, gamma=self.gamma)]
        layers.extend(self._make_layer(block, 16, num_blocks[0], stride=1))
        layers.extend(self._make_layer(block, 32, num_blocks[1], stride=2))
        layers.extend(self._make_layer(block, 64, num_blocks[2], stride=2))
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        return nn.Sequential(*layers)

    def _make_classifier(self, dim_in, dim_out):
        """Build the flatten + linear classification head."""
        layer = [nn.Flatten(),
                 nn.Linear(dim_in, dim_out, bias=self.use_bias)]
        return nn.Sequential(*layer)

    def set_surrogate_gradient(self, surrogate, gamma, mode='bptt'):
        """Set ``mode``, ``surrogate`` and ``gamma`` on every LIFSpike module."""
        for module in self.modules():
            if isinstance(module, LIFSpike):
                module.mode = mode
                module.surrogate = surrogate
                module.gamma = gamma

    def set_simulation_time(self, T, mode='bptt'):
        """Set the simulation time on the model and its LIFSpike / ExpandTemporalDim modules."""
        self.T = T
        for module in self.modules():
            if isinstance(module, (LIFSpike, ExpandTemporalDim)):
                module.T = T
                if isinstance(module, LIFSpike):
                    module.mode = mode

    def threshold_update(self, scaling_factor=1.0, thresholds=None):
        """Assign scaled thresholds to the spiking activations in feature order.

        Thresholds are consumed front to back: one per stand-alone LIFSpike
        and two per residual block (``act1`` then ``act2``). Activations for
        which no threshold is left keep their current value.

        Args:
            scaling_factor: multiplier applied to every threshold; also
                stored on ``self.scaling_factor``.
            thresholds: sequence of threshold values. The sequence itself is
                not modified.
        """
        self.scaling_factor = scaling_factor
        # Work on a private copy so neither a default nor the caller's list is mutated.
        pending = list(thresholds) if thresholds else []

        def assign(spike):
            if pending:
                spike.v_th = nn.Parameter(torch.tensor(pending.pop(0) * self.scaling_factor))

        for layer in self.features:
            if isinstance(layer, LIFSpike):
                assign(layer)
            if isinstance(layer, BasicResNetBlock):
                assign(layer.act1)
                assign(layer.act2)

    def percentile(self, t, q=99.7):
        """Return the ``q``-th percentile of ``t`` as a float (rank-based, no interpolation)."""
        k = 1 + round(.01 * float(q) * (t.numel() - 1))
        result = t.view(-1).kthvalue(k).values.item()
        return result

    def forward(self, input, find_max_mem=False, max_mem_layer=0, percentile=True):
        """Run the network, or probe the values feeding one layer.

        Args:
            input: input batch.
            find_max_mem: when True, stop at feature index ``max_mem_layer``
                and return statistics instead of the network output.
            max_mem_layer: index into ``self.features`` to probe.
            percentile: probe with ``self.percentile`` (True) or the maximum (False).

        Returns:
            In probe mode a list of floats: one value (the input of a
            LIFSpike layer) or two values (the ``max_mems`` recorded by the
            preceding feature module when the probed layer is a residual
            block or the pooling layer). Otherwise the classifier output,
            passed through ``self.expand`` when ``T > 0``.
        """
        out = self.norm(input)
        if self.T > 0:
            out = add_dimension(out, self.T)
            out = self.merge(out)
        if find_max_mem:
            for idx, layer in enumerate(self.features):
                if idx == max_mem_layer:
                    if isinstance(layer, LIFSpike):
                        return [self.percentile(out.view(-1))] if percentile else [out.max().item()]
                    if isinstance(layer, (BasicResNetBlock, nn.AdaptiveAvgPool2d)):
                        # The values of interest were recorded by the previous module.
                        out1, out2 = self.features[idx - 1].max_mems
                        if percentile:
                            return [self.percentile(out1.view(-1)), self.percentile(out2.view(-1))]
                        else:
                            return [out1.max().item(), out2.max().item()]
                out = layer(out)
        else:
            out = self.features(out)
        out = self.classifier(out)
        if self.T > 0:
            out = self.expand(out)
        return out
class VGG(nn.Module):
    """VGG-style network with batch-norm and LIFSpike activations.

    The feature stack follows ``cfg_conv[vgg_name]``: every integer adds a
    3x3 conv / batch-norm / LIFSpike triple with that many output channels
    and every ``'A'`` adds a 2x2 average pool. The classifier is
    flatten -> linear(4096) -> LIFSpike -> dropout -> linear(4096) ->
    LIFSpike -> dropout -> linear(num_class).

    Args:
        vgg_name: key into ``cfg_conv``.
        T: simulation time; when ``T > 0`` the input is expanded along a
            temporal dimension in ``forward``.
        num_class: size of the output layer.
        norm: arguments unpacked into ``TensorNormalization``.
        init_channels: number of input channels.
        use_bias: whether conv and linear layers have a bias.
        dropout: dropout probability in the classifier.
        default_leak, default_threshold, learn_vth, soft_reset, surrogate,
            gamma: forwarded to every ``LIFSpike``.
    """

    def __init__(self, vgg_name, T, num_class, norm, init_channels=3, use_bias=False, dropout=0.2, default_leak=1.0,
                 default_threshold=1.0, learn_vth=False, soft_reset=False, surrogate='PCW', gamma=1.0):
        super().__init__()
        self.vgg_name = vgg_name
        self.T = T
        self.init_channels = init_channels
        # 512 * W is the flattened feature size fed to the first linear layer.
        self.W = 16 if vgg_name == 'vgg11' else 1
        self.dropout = dropout
        self.default_leak = default_leak
        self.default_vth = default_threshold
        self.learn_vth = learn_vth
        self.use_bias = use_bias
        self.soft = soft_reset
        self.surrogate = surrogate
        self.gamma = gamma

        self.norm = TensorNormalization(*norm)
        self.merge = MergeTemporalDim(T)
        self.expand = ExpandTemporalDim(T)

        self.features = self._make_layers(cfg_conv[self.vgg_name])
        self.classifier = self._make_classifier(num_class)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, val=1)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _new_spike(self):
        """Create a LIFSpike activation configured with this model's settings."""
        return LIFSpike(self.T, leak=self.default_leak, v_th=self.default_vth, soft_reset=self.soft,
                        learn_vth=self.learn_vth, surrogate=self.surrogate, gamma=self.gamma)

    def _make_layers(self, cfg):
        """Build the convolutional feature stack described by ``cfg``."""
        layers = []
        for x in cfg:
            if x == 'A':
                layers.append(nn.AvgPool2d(2))
            else:
                layers.append(nn.Conv2d(self.init_channels, x, kernel_size=3, padding=1, bias=self.use_bias))
                layers.append(nn.BatchNorm2d(x))
                layers.append(self._new_spike())
                # NOTE: init_channels doubles as the running input-channel count,
                # so after construction it holds the width of the last conv layer.
                self.init_channels = x
        return nn.Sequential(*layers)

    def _make_classifier(self, num_class):
        """Build the fully connected classification head."""
        layer = [nn.Flatten(),
                 nn.Linear(512 * self.W, 4096, bias=self.use_bias),
                 self._new_spike(),
                 nn.Dropout(self.dropout),
                 nn.Linear(4096, 4096, bias=self.use_bias),
                 self._new_spike(),
                 nn.Dropout(self.dropout),
                 nn.Linear(4096, num_class, bias=self.use_bias)]
        return nn.Sequential(*layer)

    def set_surrogate_gradient(self, surrogate, gamma, mode='bptt'):
        """Set ``mode``, ``surrogate`` and ``gamma`` on every LIFSpike module."""
        for module in self.modules():
            if isinstance(module, LIFSpike):
                module.mode = mode
                module.surrogate = surrogate
                module.gamma = gamma

    def set_simulation_time(self, T, mode='bptt'):
        """Set the simulation time on the model and its LIFSpike / ExpandTemporalDim modules."""
        self.T = T
        for module in self.modules():
            if isinstance(module, (LIFSpike, ExpandTemporalDim)):
                module.T = T
                if isinstance(module, LIFSpike):
                    module.mode = mode

    def threshold_update(self, scaling_factor=1.0, thresholds=None):
        """Assign scaled thresholds to the LIFSpike layers, features first, then classifier.

        Thresholds are consumed front to back, one per LIFSpike layer;
        layers for which no threshold is left keep their current value.

        Args:
            scaling_factor: multiplier applied to every threshold; also
                stored on ``self.scaling_factor``.
            thresholds: sequence of threshold values. The sequence itself is
                not modified.
        """
        self.scaling_factor = scaling_factor
        # Work on a private copy so neither a default nor the caller's list is mutated.
        pending = list(thresholds) if thresholds else []

        for stack in (self.features, self.classifier):
            for layer in stack:
                if isinstance(layer, LIFSpike) and pending:
                    layer.v_th = nn.Parameter(torch.tensor(pending.pop(0) * self.scaling_factor))

    def percentile(self, t, q=99.7):
        """Return the ``q``-th percentile of ``t`` as a float (rank-based, no interpolation)."""
        k = 1 + round(.01 * float(q) * (t.numel() - 1))
        result = t.view(-1).kthvalue(k).values.item()
        return result

    def forward(self, input, find_max_mem=False, max_mem_layer=0, percentile=True):
        """Run the network, or probe the input of one LIFSpike layer.

        Args:
            input: input batch.
            find_max_mem: when True, stop at the LIFSpike layer addressed by
                ``max_mem_layer`` and return a statistic of its input.
            max_mem_layer: index into ``self.features``; classifier layers
                are addressed as ``len(self.features) + index``.
            percentile: probe with ``self.percentile`` (True) or the maximum (False).

        Returns:
            In probe mode a single float. Otherwise (or when no LIFSpike
            layer matches ``max_mem_layer``) the classifier output, passed
            through ``self.expand`` when ``T > 0``.
        """
        out = self.norm(input)
        if self.T > 0:
            out = add_dimension(out, self.T)
            out = self.merge(out)
        if find_max_mem:
            for idx, layer in enumerate(self.features):
                if isinstance(layer, LIFSpike) and idx == max_mem_layer:
                    return self.percentile(out.view(-1)) if percentile else out.max().item()
                out = layer(out)
            offset = len(self.features)
            for idx, layer in enumerate(self.classifier):
                if isinstance(layer, LIFSpike) and ((offset + idx) == max_mem_layer):
                    return self.percentile(out.view(-1)) if percentile else out.max().item()
                out = layer(out)
        else:
            out = self.features(out)
            out = self.classifier(out)
        if self.T > 0:
            out = self.expand(out)
        return out
mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, val=1) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + if m.bias is not None: + nn.init.zeros_(m.bias) + + def _make_layers(self, cfg): + layers = [] + for x in cfg: + if x == 'A': + layers.append(nn.AvgPool2d(2)) + else: + layers.append(nn.Conv2d(self.init_channels, x, kernel_size=3, padding=1, bias=self.use_bias)) + layers.append(nn.BatchNorm2d(x)) + layers.append(LIFSpike(self.T, leak=self.default_leak, v_th=self.default_vth, soft_reset=self.soft, + learn_vth=self.learn_vth, surrogate=self.surrogate, gamma=self.gamma)) + self.init_channels = x + return nn.Sequential(*layers) + + def _make_classifier(self, num_class): + layer = [nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten(), + nn.Linear(512 * self.W, 4096, bias=self.use_bias), + LIFSpike(self.T, leak=self.default_leak, v_th=self.default_vth, soft_reset=self.soft, + learn_vth=self.learn_vth, surrogate=self.surrogate, gamma=self.gamma), + nn.Linear(4096, num_class, bias=self.use_bias)] + return nn.Sequential(*layer) + + def set_surrogate_gradient(self, surrogate, gamma, mode='bptt'): + for module in self.modules(): + if isinstance(module, LIFSpike): + module.mode = mode + module.surrogate = surrogate + module.gamma = gamma + + def set_simulation_time(self, T, mode='bptt'): + self.T = T + for module in self.modules(): + if isinstance(module, (LIFSpike, ExpandTemporalDim)): + module.T = T + if isinstance(module, LIFSpike): + module.mode = mode + + def threshold_update(self, scaling_factor=1.0, thresholds=[]): + self.scaling_factor = scaling_factor + + for pos in range(len(self.features)): + if isinstance(self.features[pos], LIFSpike): + if thresholds: + self.features[pos].v_th = nn.Parameter(torch.tensor(thresholds.pop(0) * self.scaling_factor)) + + for pos in range(len(self.classifier)): + if isinstance(self.classifier[pos], LIFSpike): + if thresholds: 
def vgg11_tin(timesteps, num_classes, norm, learn_vth=False, use_bias=False, soft_reset=False,
              surrogate='PCW', gamma=1.0, **kwargs):
    """Construct a ``VGG_TIN`` network labelled ``'vgg11_tin'``.

    Args:
        timesteps: forwarded as ``T``.
        num_classes: forwarded as ``num_class``.
        norm: forwarded unchanged as ``norm``.
        learn_vth, use_bias, soft_reset, surrogate, gamma: forwarded unchanged.
        **kwargs: any further ``VGG_TIN`` constructor arguments.

    Returns:
        The new ``VGG_TIN`` instance.
    """
    options = dict(
        vgg_name='vgg11_tin',
        T=timesteps,
        num_class=num_classes,
        norm=norm,
        learn_vth=learn_vth,
        use_bias=use_bias,
        soft_reset=soft_reset,
        surrogate=surrogate,
        gamma=gamma,
    )
    return VGG_TIN(**options, **kwargs)
class BasicBlock(nn.Module):
    """Residual block made of two BatchNorm -> LIFSpike -> 3x3 convolution stages.

    The block output is ``ConvexCombination(2)`` applied to the shortcut and
    the residual branch. When the channel count changes, the shortcut is a
    1x1 convolution applied to the first activation; otherwise it is the
    block input itself.

    Args:
        T: time steps handed to both ``LIFSpike`` layers.
        in_planes: input channel count.
        out_planes: output channel count.
        stride: stride of the first convolution (and of the 1x1 shortcut).
        dropout: dropout probability applied before the second convolution;
            values ``<= 0`` disable dropout.
        default_leak, default_vth, learn_vth, soft_reset, surrogate, gamma:
            forwarded to both ``LIFSpike`` layers.
    """

    def __init__(self, T, in_planes, out_planes, stride, dropout=0.3, default_leak=1.0, default_vth=1.0,
                 learn_vth=False, soft_reset=False, surrogate='PCW', gamma=1.0):
        super().__init__()
        self.T = T
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.act1 = LIFSpike(self.T, leak=default_leak, v_th=default_vth, soft_reset=soft_reset,
                             learn_vth=learn_vth, surrogate=surrogate, gamma=gamma)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.act2 = LIFSpike(self.T, leak=default_leak, v_th=default_vth, soft_reset=soft_reset,
                             learn_vth=learn_vth, surrogate=surrogate, gamma=gamma)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.dropout = dropout
        self.equalInOut = (in_planes == out_planes)
        # A 1x1 projection is only needed when the channel count changes.
        # Written as a conditional expression instead of `cond and a or b`,
        # which depends on the truthiness of the module object.
        self.convShortcut = None if self.equalInOut else nn.Conv2d(
            in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
        self.convex = ConvexCombination(2)
        # When find_max_mem is set, forward() keeps both BatchNorm outputs in max_mems.
        self.find_max_mem = False
        self.max_mems = []

    def forward(self, x):
        """Run the block on ``x`` and return the combined shortcut/residual output."""
        h1 = self.bn1(x)
        spikes = self.act1(h1)
        h2 = self.bn2(self.conv1(spikes))
        self.max_mems = [h1, h2] if self.find_max_mem else []
        out = self.act2(h2)
        if self.dropout > 0:
            out = F.dropout(out, p=self.dropout, training=self.training)
        out = self.conv2(out)
        # identity shortcut uses the raw input; the projection uses the activated input
        shortcut = x if self.equalInOut else self.convShortcut(spikes)
        return self.convex(shortcut, out)
__init__(self, depth, widen_factor, T, num_classes, norm, init_channels=3, dropout=0.0, learn_vth=False, + use_bias=False, soft_reset=False, default_leak=1.0, default_threshold=1.0, surrogate='PCW', gamma=1.0): + super(WideResNet, self).__init__() + assert((depth - 4) % 6 == 0) + n = (depth - 4) / 6 + nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] + block = BasicBlock + self.T = T + self.init_channels = init_channels + self.dropout = dropout + self.default_leak = default_leak + self.default_vth = default_threshold + self.learn_vth = learn_vth + self.use_bias = use_bias + self.soft = soft_reset + self.surrogate = surrogate + self.gamma = gamma + + self.norm = TensorNormalization(*norm) + self.merge = MergeTemporalDim(T) + self.expand = ExpandTemporalDim(T) + + self.features = self._make_layers(block, n, nChannels) + self.classifier = self._make_classifier(nChannels[3], num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, val=1) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + if m.bias is not None: + nn.init.zeros_(m.bias) + + def _make_block(self, block, in_planes, out_planes, nb_layers, stride, dropout): + layers = [] + for i in range(int(nb_layers)): + layers.append(block(self.T, i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropout, + default_leak=self.default_leak, default_vth=self.default_vth, learn_vth=self.learn_vth, + soft_reset=self.soft, surrogate=self.surrogate, gamma=self.gamma)) + return layers + + def _make_layers(self, block, n, nChannels): + layers = [nn.Conv2d(self.init_channels, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False)] + layers.extend(self._make_block(block, nChannels[0], nChannels[1], n, 1, self.dropout)) + layers.extend(self._make_block(block, 
nChannels[1], nChannels[2], n, 2, self.dropout)) + layers.extend(self._make_block(block, nChannels[2], nChannels[3], n, 2, self.dropout)) + layers.append(nn.BatchNorm2d(nChannels[3])) + layers.append(LIFSpike(self.T, leak=self.default_leak, v_th=self.default_vth, soft_reset=self.soft, + learn_vth=self.learn_vth, surrogate=self.surrogate, gamma=self.gamma)) + layers.append(nn.AvgPool2d(8)) + return nn.Sequential(*layers) + + def _make_classifier(self, dim_in, dim_out): + layer = [nn.Flatten(), + nn.Linear(dim_in, dim_out, bias=self.use_bias)] + return nn.Sequential(*layer) + + def set_surrogate_gradient(self, surrogate, gamma, mode='bptt'): + for module in self.modules(): + if isinstance(module, LIFSpike): + module.mode = mode + module.surrogate = surrogate + module.gamma = gamma + + def set_simulation_time(self, T, mode='bptt'): + self.T = T + for module in self.modules(): + if isinstance(module, (LIFSpike, ExpandTemporalDim)): + module.T = T + if isinstance(module, LIFSpike): + module.mode = mode + + def threshold_update(self, scaling_factor=1.0, thresholds=[]): + self.scaling_factor = scaling_factor + for pos in range(len(self.features)): + if isinstance(self.features[pos], BasicBlock): + if thresholds: + self.features[pos].act1.v_th = nn.Parameter(torch.tensor(thresholds.pop(0) * self.scaling_factor)) + if thresholds: + self.features[pos].act2.v_th = nn.Parameter(torch.tensor(thresholds.pop(0) * self.scaling_factor)) + if isinstance(self.features[pos], LIFSpike): + if thresholds: + self.features[pos].v_th = nn.Parameter(torch.tensor(thresholds.pop(0) * self.scaling_factor)) + + def percentile(self, t, q=99.7): + k = 1 + round(.01 * float(q) * (t.numel() - 1)) + result = t.view(-1).kthvalue(k).values.item() + return result + + def forward(self, input, find_max_mem=False, max_mem_layer=0, percentile=True): + out = self.norm(input) + if self.T > 0: + out = add_dimension(out, self.T) + out = self.merge(out) + if find_max_mem: + for l in range(len(self.features)): + 
def wrn16_4(timesteps, num_classes, norm, learn_vth=False, use_bias=False, soft_reset=False,
            surrogate='PCW', gamma=1.0, **kwargs):
    """Construct a ``WideResNet`` with depth 16 and widening factor 4.

    Args:
        timesteps: forwarded as ``T``.
        num_classes: forwarded as ``num_classes``.
        norm: forwarded unchanged as ``norm``.
        learn_vth, use_bias, soft_reset, surrogate, gamma: forwarded unchanged.
        **kwargs: any further ``WideResNet`` constructor arguments.

    Returns:
        The new ``WideResNet`` instance.
    """
    options = dict(
        depth=16,
        widen_factor=4,
        T=timesteps,
        num_classes=num_classes,
        norm=norm,
        learn_vth=learn_vth,
        use_bias=use_bias,
        soft_reset=soft_reset,
        surrogate=surrogate,
        gamma=gamma,
    )
    return WideResNet(**options, **kwargs)
def add_dimension(x, T):
    """Tile a batch ``T`` times along the leading axis.

    A singleton axis is inserted at position 1 and the result is repeated
    ``T`` times along axis 0, so an input of shape ``(B, ...)`` produces
    ``(T * B, 1, ...)``. The input tensor itself is not modified.

    Args:
        x: input tensor of any rank >= 1.
        T: number of repetitions along axis 0.

    Returns:
        A new tensor holding ``T`` stacked copies of ``x``.
    """
    # Out-of-place unsqueeze: the in-place variant reshaped the caller's tensor.
    expanded = x.unsqueeze(1)
    # One repeat factor per axis, so inputs that are not 4-D work as well.
    return expanded.repeat(T, *([1] * (expanded.dim() - 1)))
class ExpSpike(torch.autograd.Function):
    """Step activation with an exponential surrogate gradient.

    Forward emits 1 where ``input >= 0`` and 0 elsewhere. Backward scales the
    incoming gradient by ``alpha * exp(-beta * |input|)``, where
    ``(alpha, beta)`` is the ``gamma`` pair given to ``forward``.
    """

    @staticmethod
    def forward(ctx, input, gamma):
        """Return the spike tensor for ``input``; ``gamma`` is ``(alpha, beta)``."""
        alpha = torch.tensor([gamma[0]])
        beta = torch.tensor([gamma[1]])
        ctx.save_for_backward(input, alpha, beta)
        # The comparison result already lives on input's device; the former
        # hard-coded .cuda() crashed on hosts without a GPU.
        return (input >= 0).to(input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the surrogate gradient w.r.t. ``input`` and ``None`` for ``gamma``."""
        (input, alpha, beta) = ctx.saved_tensors
        alpha = alpha[0].item()
        beta = beta[0].item()
        grad_input = grad_output.clone()
        grad = alpha * torch.exp(-beta * torch.abs(input))
        return grad * grad_input, None
class LIFSpike(nn.Module):
    """Leaky integrate-and-fire activation with a selectable surrogate gradient.

    With ``T > 0`` the input is split into ``T`` time steps, a membrane
    potential is accumulated per step and a spike is emitted wherever it
    reaches ``v_th``. With ``T <= 0`` the layer is a plain in-place ReLU.

    Args:
        T: number of time steps.
        leak: factor applied to the membrane potential at every step.
        v_th: firing threshold; an ``nn.Parameter`` when ``learn_vth`` is set.
        soft_reset: subtract ``v_th`` after a spike instead of zeroing the potential.
        learn_vth: make the threshold trainable.
        surrogate: one of ``'PCW'``, ``'EXP'``, ``'EXP-D'``, ``'RECT'``, ``'STE'``.
        gamma: parameter passed to the surrogate function.
    """

    def __init__(self, T, leak=1.0, v_th=1.0, soft_reset=False, learn_vth=False, surrogate='PCW', gamma=1.0):
        super(LIFSpike, self).__init__()
        self.T = T
        self.expand = ExpandTemporalDim(T)
        self.merge = MergeTemporalDim(T)
        self.mode = 'bptt'
        self.surrogate = surrogate
        self.gamma = gamma
        # autograd functions, one per supported surrogate
        self.act_pcw = PCW.apply
        self.act_exp = ExpSpike.apply
        self.act_rect = Rectangular.apply
        self.act_ste = STE.apply
        self.ratebp = RateBp.apply
        self.relu = nn.ReLU(inplace=True)
        self.soft_reset = soft_reset
        self.leak_mem = leak
        self.learn_vth = learn_vth
        if learn_vth:
            self.v_th = nn.Parameter(torch.tensor(v_th))
        else:
            self.v_th = v_th

    def _fire(self, overshoot):
        """Apply the configured surrogate to ``membrane - threshold``."""
        kind = self.surrogate
        if kind == 'PCW':
            return self.act_pcw(overshoot, self.gamma)
        if kind == 'EXP':
            return self.act_exp(overshoot, (1.0, self.gamma))
        if kind == 'EXP-D':
            return self.act_exp(overshoot, (0.3, self.gamma))
        if kind == 'RECT':
            return self.act_rect(overshoot, self.gamma)
        if kind == 'STE':
            return self.act_ste(overshoot, self.gamma)
        raise NotImplementedError

    def _simulate(self, x):
        """Step through the leading (time) axis of ``x`` and stack the spikes."""
        v_mem = 0
        spike_train = []
        for t in range(self.T):
            v_mem = v_mem * self.leak_mem + x[t, ...]
            spike = self._fire(v_mem - self.v_th)
            if self.soft_reset:
                v_mem = v_mem - spike * self.v_th
            else:
                v_mem = (1 - spike) * v_mem
            spike_train.append(spike)
        return torch.stack(spike_train, dim=0)

    def forward(self, x):
        """Return spikes for ``x`` (``T > 0``) or ``relu(x)`` (``T <= 0``)."""
        if self.learn_vth:
            self.v_th.data.clamp_(min=0.03)  # set minimum of v_th=0.03 just in case
        if not self.T > 0:
            return self.relu(x)
        x = self.expand(x)
        if self.mode == 'bptr':
            # rate-based backward pass: the whole simulation runs inside RateBp
            x = self.ratebp(x, (self.leak_mem, self.v_th, self.soft_reset))
        else:
            x = self._simulate(x)
        return self.merge(x)
def BPTR_attack(model, image, T, surrogate='PCW', gamma=1.0):
    """Run ``model`` on ``image`` with its spiking layers in ``'bptr'`` mode.

    The model is switched to ``'bptr'`` mode for the forward pass and switched
    back to the default mode afterwards, also when the forward pass raises.

    Args:
        model: callable exposing ``set_simulation_time(T, mode=...)``.
        image: input batch handed to ``model``.
        T: number of simulation time steps.
        surrogate: unused here.
        gamma: unused here.

    Returns:
        The model output averaged over its first axis.
    """
    model.set_simulation_time(T, mode='bptr')
    try:
        return model(image).mean(0)
    finally:
        # Never leave the model stuck in 'bptr' mode after a failed forward pass.
        model.set_simulation_time(T)
def val(model, test_loader, device, T, adv_train=None):
    """Return the top-1 accuracy of ``model`` on ``test_loader`` in percent.

    Args:
        model: network to evaluate; put into ``eval()`` mode first.
        test_loader: iterable of ``(inputs, targets)`` batches.
        device: device the inputs are moved to.
        T: number of time steps; when ``T > 0`` the model output is averaged
            over its first axis before taking the arg-max.
        adv_train: optional attack object. When given, every batch is replaced
            by ``adv_train(inputs, targets)`` and the model's simulation time
            is reset to ``T`` afterwards.

    Returns:
        Accuracy as a float in ``[0, 100]``; ``0.0`` when the loader yields no
        samples.
    """
    correct = 0
    total = 0
    model.eval()
    for inputs, targets in test_loader:
        inputs = inputs.to(device)
        if adv_train is not None:
            adv_train.set_model_training_mode(model_training=False,
                                              batchnorm_training=False,
                                              dropout_training=False)
            inputs = adv_train(inputs, targets.to(device))
            # reset the simulation time to T once the attack has run
            model.set_simulation_time(T)

        with torch.no_grad():
            outputs = model(inputs).mean(0) if T > 0 else model(inputs)

        _, predicted = outputs.cpu().max(1)
        total += targets.size(0)
        # compare on the CPU, matching train()
        correct += predicted.eq(targets.cpu()).sum().item()
    if total == 0:
        # empty loader: report 0% instead of dividing by zero
        return 0.0
    return 100.0 * correct / total
def get_logger(filename, verbosity=1, name=None):
    """Return a logger that writes to ``filename`` and to the console.

    Args:
        filename: path of the log file; opened in ``"w"`` mode, so existing
            content is truncated.
        verbosity: 0 for DEBUG, 1 for INFO, 2 for WARNING.
        name: logger name passed to ``logging.getLogger`` (``None`` selects
            the root logger).

    Returns:
        The configured ``logging.Logger``.

    Raises:
        KeyError: if ``verbosity`` is not 0, 1 or 2.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter("[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s")
    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])

    # logging.getLogger returns the same object for the same name, so a second
    # call used to stack another pair of handlers and duplicate every record.
    # Drop only the handlers this function installed earlier.
    for handler in list(logger.handlers):
        if getattr(handler, "_installed_by_get_logger", False):
            logger.removeHandler(handler)
            handler.close()

    for handler in (logging.FileHandler(filename, "w"), logging.StreamHandler()):
        handler.setFormatter(formatter)
        handler._installed_by_get_logger = True
        logger.addHandler(handler)
    return logger