diff --git a/README.md b/README.md
index 15997f1..2d1376a 100644
--- a/README.md
+++ b/README.md
@@ -1,265 +1,32 @@
-# AutoAttack
+# AutoAttackFix
+Towards Understanding the Robustness of Diffusion-Based Purification: A Stochastic Perspective
-"Reliable evaluation of adversarial robustness with an ensemble of diverse parameter-free attacks"\
-*Francesco Croce*, *Matthias Hein*\
-ICML 2020\
-[https://arxiv.org/abs/2003.01690](https://arxiv.org/abs/2003.01690)
+Yiming Liu, Kezhao Liu, Yao Xiao, Ziyi Dong, Xiaogang Xu, Pengxu Wei, Liang Lin
+ICLR 2025
-We propose to use an ensemble of four diverse attacks to reliably evaluate robustness:
-+ **APGD-CE**, our new step size-free version of PGD on the cross-entropy,
-+ **APGD-DLR**, our new step size-free version of PGD on the new DLR loss,
-+ **FAB**, which minimizes the norm of the adversarial perturbations [(Croce & Hein, 2019)](https://arxiv.org/abs/1907.02044),
-+ **Square Attack**, a query-efficient black-box attack [(Andriushchenko et al, 2019)](https://arxiv.org/abs/1912.00049).
+modified from [https://github.com/fra31/auto-attack](https://github.com/fra31/auto-attack)
-**Note**: we fix all the hyperparameters of the attacks, so no tuning is required to test every new classifier.
+---
-## News
-+ [Sep 2021]
-  + We add [automatic checks](https://github.com/fra31/auto-attack/blob/master/flags_doc.md) for potential cases where the standard version of AA might be non suitable or sufficient for robustness evaluation.
-  + The evaluations of models on CIFAR-10 and CIFAR-100 are no longer maintained. Up-to-date leaderboards are available in [RobustBench](https://robustbench.github.io/).
-+ [Mar 2021] A version of AutoAttack wrt L1, which includes the extensions of APGD and Square Attack [(Croce & Hein, 2021)](https://arxiv.org/abs/2103.01208), is available!
-+ [Oct 2020] AutoAttack is used as standard evaluation in the new benchmark [RobustBench](https://robustbench.github.io/), which includes a [Model Zoo](https://github.com/RobustBench/robustbench) of the most robust classifiers! Note that this page and RobustBench's leaderboards are maintained simultaneously.
-+ [Aug 2020]
-  + **Updated version**: in order to *i)* scale AutoAttack (AA) to datasets with many classes and *ii)* have a faster and more accurate evaluation, we use APGD-DLR and FAB with their *targeted* versions.
-  + We add the evaluation of models on CIFAR-100 wrt Linf and CIFAR-10 wrt L2.
-+ [Jul 2020] A short version of the paper is accepted at [ICML'20 UDL workshop](https://sites.google.com/view/udlworkshop2020/) for a spotlight presentation!
-+ [Jun 2020] The paper is accepted at ICML 2020!
+AutoAttack underperforms on models with stochastic outputs because its default sample-selection strategy fails to account for randomness. As an ensemble of different attack methods, AutoAttack decides whether a sample is adversarial based on a single evaluation, ignoring output variability. To address this, we change the selection to an evaluation over 20 iterations and select the adversarial example that results in the lowest accuracy. This modification improves the success rate by up to 10 percentage points when evaluating diffusion-based purification models.
-# Adversarial Defenses Evaluation -We here list adversarial defenses, for many threat models, recently proposed and evaluated with the standard version of -**AutoAttack (AA)**, including -+ *untargeted APGD-CE* (no restarts), -+ *targeted APGD-DLR* (9 target classes), -+ *targeted FAB* (9 target classes), -+ *Square Attack* (5000 queries). +--- -See below for the more expensive AutoAttack+ (AA+) and more options. +**Citations:** -We report the source of the model, i.e. if it is publicly *available*, if we received it from the *authors* or if we *retrained* it, the architecture, the clean accuracy and the reported robust accuracy (note that might be calculated on a subset of the test set or on different models trained with the same defense). The robust accuracy for AA is on the full test set. - -We plan to add new models as they appear and are made available. Feel free to suggest new defenses to test! - -**To have a model added**: please check [here](https://github.com/fra31/auto-attack/issues/new/choose). - -**Checkpoints**: many of the evaluated models are available and easily accessible at this [Model Zoo](https://github.com/RobustBench/robustbench). - -## CIFAR-10 - Linf -The robust accuracy is evaluated at `eps = 8/255`, except for those marked with * for which `eps = 0.031`, where `eps` is the maximal Linf-norm allowed for the adversarial perturbations. The `eps` used is the same set in the original papers.\ -**Note**: ‡ indicates models which exploit additional data for training (e.g. unlabeled data, pre-training). - -**Update**: this is no longer maintained, but an up-to-date leaderboard is available in [RobustBench](https://robustbench.github.io/). - -|# |paper |model |architecture |clean |report. |AA | -|:---:|---|:---:|:---:|---:|---:|---:| -|**1**| [(Gowal et al., 2020)](https://arxiv.org/abs/2010.03593)‡| *available*| WRN-70-16| 91.10| 65.87| 65.88| -|**2**| [(Gowal et al., 2020)](https://arxiv.org/abs/2010.03593)‡| *available*| WRN-28-10| 89.48| 62.76| 62.80| -|**3**| [(Wu et al., 2020a)](https://arxiv.org/abs/2010.01279)‡| *available*| WRN-34-15| 87.67| 60.65| 60.65| -|**4**| [(Wu et al., 2020b)](https://arxiv.org/abs/2004.05884)‡| *available*| WRN-28-10| 88.25| 60.04| 60.04| -|**5**| [(Carmon et al., 2019)](https://arxiv.org/abs/1905.13736)‡| *available*| WRN-28-10| 89.69| 62.5| 59.53| -|**6**| [(Gowal et al., 2020)](https://arxiv.org/abs/2010.03593)| *available*| WRN-70-16| 85.29| 57.14| 57.20| -|**7**| [(Sehwag et al., 2020)](https://github.com/fra31/auto-attack/issues/7)‡| *available*| WRN-28-10| 88.98| -| 57.14| -|**8**| [(Gowal et al., 2020)](https://arxiv.org/abs/2010.03593)| *available*| WRN-34-20| 85.64| 56.82| 56.86| -|**9**| [(Wang et al., 2020)](https://openreview.net/forum?id=rklOg6EFwS)‡| *available*| WRN-28-10| 87.50| 65.04| 56.29| -|**10**| [(Wu et al., 2020b)](https://arxiv.org/abs/2004.05884)| *available*| WRN-34-10| 85.36| 56.17| 56.17| -|**11**| [(Alayrac et al., 2019)](https://arxiv.org/abs/1905.13725)‡| *available*| WRN-106-8| 86.46| 56.30| 56.03| -|**12**| [(Hendrycks et al., 2019)](https://arxiv.org/abs/1901.09960)‡| *available*| WRN-28-10| 87.11| 57.4| 54.92| -|**13**| [(Pang et al., 2020c)](https://arxiv.org/abs/2010.00467)| *available*| WRN-34-20| 86.43| 54.39| 54.39| -|**14**| [(Pang et al., 2020b)](https://arxiv.org/abs/2002.08619)| *available*| WRN-34-20| 85.14| -| 53.74| -|**15**| [(Cui et al., 2020)](https://arxiv.org/abs/2011.11164)\*| *available*| WRN-34-20| 88.70| 53.57| 53.57| -|**16**| [(Zhang et al., 
2020b)](https://arxiv.org/abs/2002.11242)| *available*| WRN-34-10| 84.52| 54.36| 53.51| -|**17**| [(Rice et al., 2020)](https://arxiv.org/abs/2002.11569)| *available*| WRN-34-20| 85.34| 58| 53.42| -|**18**| [(Huang et al., 2020)](https://arxiv.org/abs/2002.10319)\*| *available*| WRN-34-10| 83.48| 58.03| 53.34| -|**19**| [(Zhang et al., 2019b)](https://arxiv.org/abs/1901.08573)\*| *available*| WRN-34-10| 84.92| 56.43| 53.08| -|**20**| [(Cui et al., 2020)](https://arxiv.org/abs/2011.11164)\*| *available*| WRN-34-10| 88.22| 52.86| 52.86| -|**21**| [(Qin et al., 2019)](https://arxiv.org/abs/1907.02610v2)| *available*| WRN-40-8| 86.28| 52.81| 52.84| -|**22**| [(Chen et al., 2020a)](https://arxiv.org/abs/2003.12862)| *available*| RN-50 (x3)| 86.04| 54.64| 51.56| -|**23**| [(Chen et al., 2020b)](https://github.com/fra31/auto-attack/issues/26)| *available*| WRN-34-10| 85.32| 51.13| 51.12| -|**24**| [(Sitawarin et al., 2020)](https://github.com/fra31/auto-attack/issues/23)| *available*| WRN-34-10| 86.84| 50.72| 50.72| -|**25**| [(Engstrom et al., 2019)](https://github.com/MadryLab/robustness)| *available*| RN-50| 87.03| 53.29| 49.25| -|**26**| [(Kumari et al., 2019)](https://arxiv.org/abs/1905.05186)| *available*| WRN-34-10| 87.80| 53.04| 49.12| -|**27**| [(Mao et al., 2019)](http://papers.nips.cc/paper/8339-metric-learning-for-adversarial-robustness)| *available*| WRN-34-10| 86.21| 50.03| 47.41| -|**28**| [(Zhang et al., 2019a)](https://arxiv.org/abs/1905.00877)| *retrained*| WRN-34-10| 87.20| 47.98| 44.83| -|**29**| [(Madry et al., 2018)](https://arxiv.org/abs/1706.06083)| *available*| WRN-34-10| 87.14| 47.04| 44.04| -|**30**| [(Pang et al., 2020a)](https://arxiv.org/abs/1905.10626)| *available*| RN-32| 80.89| 55.0| 43.48| -|**31**| [(Wong et al., 2020)](https://arxiv.org/abs/2001.03994)| *available*| RN-18| 83.34| 46.06| 43.21| -|**32**| [(Shafahi et al., 2019)](https://arxiv.org/abs/1904.12843)| *available*| WRN-34-10| 86.11| 46.19| 41.47| -|**33**| [(Ding et al., 2020)](https://openreview.net/forum?id=HkeryxBtPB)| *available*| WRN-28-4| 84.36| 47.18| 41.44| -|**34**| [(Atzmon et al., 2019)](https://arxiv.org/abs/1905.11911)\*| *available*| RN-18| 81.30| 43.17| 40.22| -|**35**| [(Moosavi-Dezfooli et al., 2019)](http://openaccess.thecvf.com/content_CVPR_2019/html/Moosavi-Dezfooli_Robustness_via_Curvature_Regularization_and_Vice_Versa_CVPR_2019_paper)| *authors*| WRN-28-10| 83.11| 41.4| 38.50| -|**36**| [(Zhang & Wang, 2019)](http://papers.nips.cc/paper/8459-defense-against-adversarial-attacks-using-feature-scattering-based-adversarial-training)| *available*| WRN-28-10| 89.98| 60.6| 36.64| -|**37**| [(Zhang & Xu, 2020)](https://openreview.net/forum?id=Syejj0NYvr¬eId=Syejj0NYvr)| *available*| WRN-28-10| 90.25| 68.7| 36.45| -|**38**| [(Jang et al., 2019)](http://openaccess.thecvf.com/content_ICCV_2019/html/Jang_Adversarial_Defense_via_Learning_to_Generate_Diverse_Attacks_ICCV_2019_paper.html)| *available*| RN-20| 78.91| 37.40| 34.95| -|**39**| [(Kim & Wang, 2020)](https://openreview.net/forum?id=rJlf_RVKwr)| *available*| WRN-34-10| 91.51| 57.23| 34.22| -|**40**| [(Wang & Zhang, 2019)](http://openaccess.thecvf.com/content_ICCV_2019/html/Wang_Bilateral_Adversarial_Training_Towards_Fast_Training_of_More_Robust_Models_ICCV_2019_paper.html)| *available*| WRN-28-10| 92.80| 58.6| 29.35| -|**41**| [(Xiao et al., 2020)](https://arxiv.org/abs/1905.10510)\*| *available*| DenseNet-121| 79.28| 52.4| 18.50| -|**42**| [(Jin & Rinard, 2020)](https://arxiv.org/abs/2003.04286v1) | 
[*available*](https://github.com/charlesjin/adversarial_regularization/blob/6a3704757dcc7c707ff38f8b9de6f2e9e27e0a89/pretrained/pretrained88.pth) | RN-18| 90.84| 71.22| 1.35| -|**43**| [(Mustafa et al., 2019)](https://arxiv.org/abs/1904.00887)| *available*| RN-110| 89.16| 32.32| 0.28| -|**44**| [(Chan et al., 2020)](https://arxiv.org/abs/1912.10185)| *retrained*| WRN-34-10| 93.79| 15.5| 0.26| - -## CIFAR-100 - Linf -The robust accuracy is computed at `eps = 8/255` in the Linf-norm, except for the models marked with * for which `eps = 0.031` is used. \ -**Note**: ‡ indicates models which exploit additional data for training (e.g. unlabeled data, pre-training).\ -\ -**Update**: this is no longer maintained, but an up-to-date leaderboard is available in [RobustBench](https://robustbench.github.io/). - -|# |paper |model |architecture |clean |report. |AA | -|:---:|---|:---:|:---:|---:|---:|---:| -|**1**| [(Gowal et al. 2020)](https://arxiv.org/abs/2010.03593)‡| *available*| WRN-70-16| 69.15| 37.70| 36.88| -|**2**| [(Cui et al., 2020)](https://arxiv.org/abs/2011.11164)\*| *available*| WRN-34-20| 62.55| 30.20| 30.20| -|**3**| [(Gowal et al. 2020)](https://arxiv.org/abs/2010.03593)| *available*| WRN-70-16| 60.86| 30.67| 30.03| -|**4**| [(Cui et al., 2020)](https://arxiv.org/abs/2011.11164)\*| *available*| WRN-34-10| 60.64| 29.33| 29.33| -|**5**| [(Wu et al., 2020b)](https://arxiv.org/abs/2004.05884)| *available*| WRN-34-10| 60.38| 28.86| 28.86| -|**6**| [(Hendrycks et al., 2019)](https://arxiv.org/abs/1901.09960)‡| *available*| WRN-28-10| 59.23| 33.5| 28.42| -|**7**| [(Cui et al., 2020)](https://arxiv.org/abs/2011.11164)\*| *available*| WRN-34-10| 70.25| 27.16| 27.16| -|**8**| [(Chen et al., 2020b)](https://github.com/fra31/auto-attack/issues/26)| *available*| WRN-34-10| 62.15| -| 26.94| -|**9**| [(Sitawarin et al., 2020)](https://github.com/fra31/auto-attack/issues/22)| *available*| WRN-34-10| 62.82| 24.57| 24.57| -|**10**| [(Rice et al., 2020)](https://arxiv.org/abs/2002.11569)| *available*| RN-18| 53.83| 28.1| 18.95| - -## MNIST - Linf -The robust accuracy is computed at `eps = 0.3` in the Linf-norm. - -|# |paper |model |clean |report. 
|AA | -|:---:|---|:---:|---:|---:|---:| -|**1**| [(Gowal et al., 2020)](https://arxiv.org/abs/2010.03593)| *available*| 99.26| 96.38| 96.34| -|**2**| [(Zhang et al., 2020a)](https://arxiv.org/abs/1906.06316)| *available*| 98.38| 96.38| 93.96| -|**3**| [(Gowal et al., 2019)](https://arxiv.org/abs/1810.12715)| *available*| 98.34| 93.78| 92.83| -|**4**| [(Zhang et al., 2019b)](https://arxiv.org/abs/1901.08573)| *available*| 99.48| 95.60| 92.81| -|**5**| [(Ding et al., 2020)](https://openreview.net/forum?id=HkeryxBtPB)| *available*| 98.95| 92.59| 91.40| -|**6**| [(Atzmon et al., 2019)](https://arxiv.org/abs/1905.11911)| *available*| 99.35| 97.35| 90.85| -|**7**| [(Madry et al., 2018)](https://arxiv.org/abs/1706.06083)| *available*| 98.53| 89.62| 88.50| -|**8**| [(Jang et al., 2019)](http://openaccess.thecvf.com/content_ICCV_2019/html/Jang_Adversarial_Defense_via_Learning_to_Generate_Diverse_Attacks_ICCV_2019_paper.html)| *available*| 98.47| 94.61| 87.99| -|**9**| [(Wong et al., 2020)](https://arxiv.org/abs/2001.03994)| *available*| 98.50| 88.77| 82.93| -|**10**| [(Taghanaki et al., 2019)](http://openaccess.thecvf.com/content_CVPR_2019/html/Taghanaki_A_Kernelized_Manifold_Mapping_to_Diminish_the_Effect_of_Adversarial_CVPR_2019_paper.html)| *retrained*| 98.86| 64.25| 0.00| - -## CIFAR-10 - L2 -The robust accuracy is computed at `eps = 0.5` in the L2-norm.\ -**Note**: ‡ indicates models which exploit additional data for training (e.g. unlabeled data, pre-training). - -**Update**: this is no longer maintained, but an up-to-date leaderboard is available in [RobustBench](https://robustbench.github.io/). - -|# |paper |model |architecture |clean |report. |AA | -|:---:|---|:---:|:---:|---:|---:|---:| -|**1**| [(Gowal et al., 2020)](https://arxiv.org/abs/2010.03593)‡| *available*| WRN-70-16| 94.74| -| 80.53| -|**2**| [(Gowal et al., 2020)](https://arxiv.org/abs/2010.03593)| *available*| WRN-70-16| 90.90| -| 74.50| -|**3**| [(Wu et al., 2020b)](https://arxiv.org/abs/2004.05884)| *available*| WRN-34-10| 88.51| 73.66| 73.66| -|**4**| [(Augustin et al., 2020)](https://arxiv.org/abs/2003.09461)‡| *authors*| RN-50| 91.08| 73.27| 72.91| -|**5**| [(Engstrom et al., 2019)](https://github.com/MadryLab/robustness)| *available*| RN-50| 90.83| 70.11| 69.24| -|**6**| [(Rice et al., 2020)](https://arxiv.org/abs/2002.11569)| *available*| RN-18| 88.67| 71.6| 67.68| -|**7**| [(Rony et al., 2019)](https://arxiv.org/abs/1811.09600)| *available*| WRN-28-10| 89.05| 67.6| 66.44| -|**8**| [(Ding et al., 2020)](https://openreview.net/forum?id=HkeryxBtPB)| *available*| WRN-28-4| 88.02| 66.18| 66.09| - -# How to use AutoAttack - -### Installation - -``` -pip install git+https://github.com/fra31/auto-attack -``` - -### PyTorch models -Import and initialize AutoAttack with - -```python -from autoattack import AutoAttack -adversary = AutoAttack(forward_pass, norm='Linf', eps=epsilon, version='standard') -``` - -where: -+ `forward_pass` returns the logits and takes input with components in [0, 1] (NCHW format expected), -+ `norm = ['Linf' | 'L2' | 'L1']` is the norm of the threat model, -+ `eps` is the bound on the norm of the adversarial perturbations, -+ `version = 'standard'` uses the standard version of AA. 
-
-To apply the standard evaluation, where the attacks are run sequentially on batches of size `bs` of `images`, use
-
-```python
-x_adv = adversary.run_standard_evaluation(images, labels, bs=batch_size)
-```
-
-To run the attacks individually, use
-
-```python
-dict_adv = adversary.run_standard_evaluation_individual(images, labels, bs=batch_size)
+AutoAttackFix
 ```
-
-which returns a dictionary with the adversarial examples found by each attack.
-
-To specify a subset of attacks add e.g. `adversary.attacks_to_run = ['apgd-ce']`.
-
-### TensorFlow models
-To evaluate models implemented in TensorFlow 1.X, use
-
-```python
-from autoattack import utils_tf
-model_adapted = utils_tf.ModelAdapter(logits, x_input, y_input, sess)
-
-from autoattack import AutoAttack
-adversary = AutoAttack(model_adapted, norm='Linf', eps=epsilon, version='standard', is_tf_model=True)
-```
-
-where:
-+ `logits` is the tensor with the logits given by the model,
-+ `x_input` is a placeholder for the input for the classifier (NHWC format expected),
-+ `y_input` is a placeholder for the correct labels,
-+ `sess` is a TF session.
-
-If TensorFlow's version is 2.X, use
-
-```python
-from autoattack import utils_tf2
-model_adapted = utils_tf2.ModelAdapter(tf_model)
-
-from autoattack import AutoAttack
-adversary = AutoAttack(model_adapted, norm='Linf', eps=epsilon, version='standard', is_tf_model=True)
-```
-
-where:
-+ `tf_model` is tf.keras model without activation function 'softmax'
-
-The evaluation can be run in the same way as done with PT models.
-
-### Examples
-Examples of how to use AutoAttack can be found in `examples/`. To run the standard evaluation on a pretrained
-PyTorch model on CIFAR-10 use
-```
-python eval.py [--individual] --version=['standard' | 'plus']
+@inproceedings{liu2025towards,
+  title={Towards Understanding the Robustness of Diffusion-Based Purification: A Stochastic Perspective},
+  author={Yiming Liu and Kezhao Liu and Yao Xiao and Ziyi Dong and Xiaogang Xu and Pengxu Wei and Liang Lin},
+  booktitle={The Thirteenth International Conference on Learning Representations},
+  year={2025},
+  url={https://openreview.net/forum?id=shqjOIK3SA}
+}
 ```
-where the optional flags activate respectively the *individual* evaluations (all the attacks are run on the full test set) and the *version* of AA to use (see below).
-
-## Other versions
-### AutoAttack+
-A more expensive evaluation can be used specifying `version='plus'` when initializing AutoAttack. This includes
-+ *untargeted APGD-CE* (5 restarts),
-+ *untargeted APGD-DLR* (5 restarts),
-+ *untargeted FAB* (5 restarts),
-+ *Square Attack* (5000 queries),
-+ *targeted APGD-DLR* (9 target classes),
-+ *targeted FAB* (9 target classes).
-### Randomized defenses
-In case of classifiers with stochastic components one can combine AA with Expectation over Transformation (EoT) as in [(Athalye et al., 2018)](https://arxiv.org/abs/1802.00420) specifying `version='rand'` when initializing AutoAttack.
-This runs
-+ *untargeted APGD-CE* (no restarts, 20 iterations for EoT),
-+ *untargeted APGD-DLR* (no restarts, 20 iterations for EoT).
-
-### Custom version
-It is possible to customize the attacks to run specifying `version='custom'` when initializing the attack and then, for example,
-```python
-if args.version == 'custom':
-    adversary.attacks_to_run = ['apgd-ce', 'fab']
-    adversary.apgd.n_restarts = 2
-    adversary.fab.n_restarts = 2
-```
-
-## Other options
-### Random seed
-It is possible to fix the random seed used for the attacks with, e.g., `adversary.seed = 0`. In this case the same seed is used for all the attacks used, otherwise a different random seed is picked for each attack.
-
-### Log results
-To log the intermediate results of the evaluation specify `log_path=/path/to/logfile.txt` when initializing the attack.
-
-## Citation
+**Original AutoAttack:**
 ```
 @inproceedings{croce2020reliable,
     title = {Reliable evaluation of adversarial robustness with an ensemble of diverse parameter-free attacks},
@@ -269,11 +36,6 @@ To log the intermediate results of the evaluation specify `log_path=/path/to/log
 }
 ```
-```
-@inproceedings{croce2021mind,
-  title={Mind the box: $l_1$-APGD for sparse adversarial attacks on image classifiers},
-  author={Francesco Croce and Matthias Hein},
-  booktitle={ICML},
-  year={2021}
-}
-```
+---
+
+**Check Our Adversarial Denoising Diffusion Training (ADDT) in [https://github.com/LYMDLUT/ADDT](https://github.com/LYMDLUT/ADDT)**
diff --git a/autoattack/autoattack.py b/autoattack/autoattack.py
index e633308..c8fdd4b 100644
--- a/autoattack/autoattack.py
+++ b/autoattack/autoattack.py
@@ -11,7 +11,7 @@
 class AutoAttack():
     def __init__(self, model, norm='Linf', eps=.3, seed=None, verbose=True,
-                 attacks_to_run=[], version='standard', is_tf_model=False,
+                 attacks_to_run=[], version='standard', eval_iter = None, is_tf_model=False,
                  device='cuda', log_path=None):
         self.model = model
         self.norm = norm
@@ -21,6 +21,7 @@ def __init__(self, model, norm='Linf', eps=.3, seed=None, verbose=True,
         self.verbose = verbose
         self.attacks_to_run = attacks_to_run
         self.version = version
+        self.eval_iter = eval_iter
         self.is_tf_model = is_tf_model
         self.device = device
         self.logger = Logger(log_path)
@@ -108,9 +109,19 @@ def run_standard_evaluation(self,
            self.logger.log('{} was/were already run.'.format(', '.join(state.run_attacks)))
        # checks on type of defense
+        is_randomized_defense = True
        if self.version != 'rand':
-            checks.check_randomized(self.get_logits, x_orig[:bs].to(self.device),
+            is_randomized_defense = checks.is_randomized(self.get_logits, x_orig[:bs].to(self.device),
                y_orig[:bs].to(self.device), bs=bs, logger=self.logger)
+
+        if self.eval_iter is None:
+            if is_randomized_defense:
+                self.logger.log("random defense, using default eval_iter 20")
+                self.eval_iter = 20
+            else:
+                self.logger.log("non-random defense, using default eval_iter 1")
+                self.eval_iter = 1
+
        n_cls = checks.check_range_output(self.get_logits, x_orig[:bs].to(self.device),
            logger=self.logger)
        checks.check_dynamic(self.model, x_orig[:bs].to(self.device), self.is_tf_model,
@@ -122,19 +133,23 @@ def run_standard_evaluation(self,
            # calculate accuracy
            n_batches = int(np.ceil(x_orig.shape[0] / bs))
            if state.robust_flags is None:
-                robust_flags = torch.zeros(x_orig.shape[0], dtype=torch.bool, device=x_orig.device)
+                #robust_flags = torch.zeros(x_orig.shape[0], dtype=torch.bool, device=x_orig.device)
+                robust_flags = torch.zeros(x_orig.shape[0], device=x_orig.device)
                y_adv = torch.empty_like(y_orig)
                for batch_idx in range(n_batches):
                    start_idx = batch_idx * bs
-                    end_idx = min( (batch_idx + 1) * bs, x_orig.shape[0])
+                    end_idx = min((batch_idx + 1) * bs, x_orig.shape[0])
                    x = x_orig[start_idx:end_idx, :].clone().to(self.device)
                    y = y_orig[start_idx:end_idx].clone().to(self.device)
-                    output = self.get_logits(x).max(dim=1)[1]
-                    y_adv[start_idx: end_idx] = output
-                    correct_batch = y.eq(output)
-                    robust_flags[start_idx:end_idx] = correct_batch.detach().to(robust_flags.device)
+                    for _ in range(self.eval_iter):
+                        output = self.get_logits(x).max(dim=1)[1]
+                        y_adv[start_idx: end_idx] = output
+                        correct_batch = y.eq(output)
+                        robust_flags[start_idx:end_idx] += correct_batch.detach().to(robust_flags.device)
+
+                robust_flags /= self.eval_iter
                state.robust_flags = robust_flags
                robust_accuracy = torch.sum(robust_flags).item() / x_orig.shape[0]
                robust_accuracy_dict = {'clean': robust_accuracy}
@@ -154,7 +169,8 @@ def run_standard_evaluation(self,
            startt = time.time()
            for attack in attacks_to_run:
                # item() is super important as pytorch int division uses floor rounding
-                num_robust = torch.sum(robust_flags).item()
+                #num_robust = torch.sum(robust_flags).item()
+                num_robust = torch.sum(robust_flags != 0).item()
                if num_robust == 0:
                    break
@@ -218,17 +234,31 @@ def run_standard_evaluation(self,
                    else:
                        raise ValueError('Attack not supported')
-                    output = self.get_logits(adv_curr).max(dim=1)[1]
-                    false_batch = ~y.eq(output).to(robust_flags.device)
-                    non_robust_lin_idcs = batch_datapoint_idcs[false_batch]
-                    robust_flags[non_robust_lin_idcs] = False
-                    state.robust_flags = robust_flags
+                    # output = self.get_logits(adv_curr).max(dim=1)[1]
+                    # false_batch = ~y.eq(output).to(robust_flags.device)
+                    # non_robust_lin_idcs = batch_datapoint_idcs[false_batch]
+                    # robust_flags[non_robust_lin_idcs] = False
+                    # state.robust_flags = robust_flags
-                    x_adv[non_robust_lin_idcs] = adv_curr[false_batch].detach().to(x_adv.device)
-                    y_adv[non_robust_lin_idcs] = output[false_batch].detach().to(x_adv.device)
+                    # x_adv[non_robust_lin_idcs] = adv_curr[false_batch].detach().to(x_adv.device)
+                    # y_adv[non_robust_lin_idcs] = output[false_batch].detach().to(x_adv.device)
+
+                    correct_batch = torch.zeros_like(y)
+                    for _ in range(self.eval_iter):
+                        output = self.get_logits(adv_curr).max(dim=1)[1]
+                        correct_batch += y.eq(output).to(robust_flags.device)
+
+                    correct_batch = correct_batch / self.eval_iter
+
+                    smaller_indices = correct_batch < robust_flags[batch_datapoint_idcs]
+                    robust_flags[batch_datapoint_idcs[smaller_indices]] = correct_batch[smaller_indices]
+                    x_adv[batch_datapoint_idcs[smaller_indices]] = adv_curr[smaller_indices].detach().to(x_adv.device)
+                    y_adv[batch_datapoint_idcs[smaller_indices]] = output[smaller_indices].detach().to(x_adv.device)
+
                    if self.verbose:
-                        num_non_robust_batch = torch.sum(false_batch)
+                        #num_non_robust_batch = torch.sum(false_batch)
+                        num_non_robust_batch = torch.sum(1 - correct_batch)
                        self.logger.log('{} - {}/{} - {} out of {} successfully perturbed'.format(
                            attack, batch_idx + 1, n_batches, num_non_robust_batch, x.shape[0]))
diff --git a/autoattack/checks.py b/autoattack/checks.py
index 964a479..2cb9927 100644
--- a/autoattack/checks.py
+++ b/autoattack/checks.py
@@ -15,7 +15,7 @@ checks_doc_path = 'flags_doc.md'
-def check_randomized(model, x, y, bs=250, n=5, alpha=1e-4, logger=None):
+def is_randomized(model, x, y, bs=250, n=5, alpha=1e-4, logger=None):
    acc = []
    corrcl = []
    outputs = []
@@ -39,6 +39,8 @@ def check_randomized(model, x, y, bs=250, n=5, alpha=1e-4, logger=None):
            warnings.warn(Warning(msg))
        else:
            logger.log(f'Warning: {msg}')
+        return True
+    return False
 def check_range_output(model, x, alpha=1e-5, logger=None):
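Below is a minimal usage sketch of the patched evaluation, assuming the modified package is importable as `autoattack`. The toy linear classifier, the random tensors, the CPU device, and `eps=8/255` are illustrative stand-ins only; `version='rand'`, the new `eval_iter` argument, and `run_standard_evaluation` come from the interface changed above. In a real evaluation, `model` would be a stochastic diffusion-purification pipeline followed by a classifier, returning logits for inputs in [0, 1] (NCHW).

```python
import torch
import torch.nn as nn

from autoattack import AutoAttack  # the patched package from this repository

# Illustrative stand-ins only: any classifier returning logits for inputs in [0, 1] works.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).eval()
images = torch.rand(16, 3, 32, 32)
labels = torch.randint(0, 10, (16,))

# eval_iter is the number of forward passes used to score each (clean or adversarial) sample;
# if left as None, the patch defaults to 20 for randomized defenses and 1 for deterministic ones.
adversary = AutoAttack(model, norm='Linf', eps=8/255, version='rand',
                       eval_iter=20, device='cpu')
x_adv = adversary.run_standard_evaluation(images, labels, bs=16)
```

With `eval_iter > 1`, the robustness flags become fractional per-sample accuracies averaged over the repeated evaluations, and an attack's output replaces the stored adversarial example only when it lowers that per-sample accuracy, which is the selection rule described in the README paragraph above.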