Skip to content

Commit 3f4b1a9

Browse files
committed
publish
0 parents  commit 3f4b1a9

File tree

856 files changed

+373776
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

856 files changed

+373776
-0
lines changed

LICENSE

+279
Large diffs are not rendered by default.

README.md

+454
Large diffs are not rendered by default.

analysis/Demos/Demo_ACT.ipynb

+322
Large diffs are not rendered by default.

analysis/Demos/Demo_ActDist.ipynb

+506
Large diffs are not rendered by default.

analysis/Demos/Demo_CM.ipynb

+320
Large diffs are not rendered by default.

analysis/Demos/Demo_FM.ipynb

+258,389
Large diffs are not rendered by default.

analysis/Demos/Demo_FV.ipynb

+287
Large diffs are not rendered by default.

analysis/Demos/Demo_Frequency.ipynb

+387
Large diffs are not rendered by default.

analysis/Demos/Demo_GradCam.ipynb

+403
Large diffs are not rendered by default.

analysis/Demos/Demo_Hessian.ipynb

+372
Large diffs are not rendered by default.

analysis/Demos/Demo_Landscape.ipynb

+544
Large diffs are not rendered by default.

analysis/Demos/Demo_Lips.ipynb

+478
Large diffs are not rendered by default.

analysis/Demos/Demo_Neuron_Activation.ipynb

+404
Large diffs are not rendered by default.

analysis/Demos/Demo_Quality.ipynb

+308
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "ebd66700",
6+
"metadata": {},
7+
"source": [
8+
"## Demo_Quality\n",
9+
"This is a demo for visualizing the image quality of backdoored samples\n",
10+
"\n",
11+
"To run this demo from scratch, you first need to generate a BadNet attack result by running the following cell"
12+
]
13+
},
14+
{
15+
"cell_type": "code",
16+
"execution_count": null,
17+
"id": "b950f4fc",
18+
"metadata": {},
19+
"outputs": [],
20+
"source": [
21+
"! python ../../attack/badnet.py --save_folder_name badnet_demo"
22+
]
23+
},
24+
{
25+
"cell_type": "markdown",
26+
"id": "8f81f973",
27+
"metadata": {},
28+
"source": [
29+
"or run the following command in your terminal\n",
30+
"\n",
31+
"```python attack/badnet.py --save_folder_name badnet_demo```"
32+
]
33+
},
34+
{
35+
"cell_type": "markdown",
36+
"id": "87bd9f5a",
37+
"metadata": {},
38+
"source": [
39+
"### Step 1: Import modules and set arguments"
40+
]
41+
},
42+
{
43+
"cell_type": "code",
44+
"execution_count": 1,
45+
"id": "71b7087b",
46+
"metadata": {},
47+
"outputs": [],
48+
"source": [
49+
# Standard library
import os
import sys

# Third-party
import numpy as np
import shap
import torch
import torchvision.transforms as transforms
import yaml

# Make the project packages importable when running from analysis/Demos/.
# These appends MUST precede the project imports below.
sys.path.append("../")
sys.path.append("../../")
sys.path.append(os.getcwd())

# Project modules
from visual_utils import *
from utils.aggregate_block.dataset_and_transform_generate import (
    get_transform,
    get_dataset_denormalization,
)
from utils.aggregate_block.fix_random import fix_random
from utils.aggregate_block.model_trainer_generate import generate_cls_model
from utils.save_load_attack import load_attack_result
from utils.defense_utils.dbd.model.utils import (
    get_network_dbd,
    load_state,
    get_criterion,
    get_optimizer,
    get_scheduler,
)
from utils.defense_utils.dbd.model.model import SelfModel, LinearModel
75+
]
76+
},
77+
{
78+
"cell_type": "code",
79+
"execution_count": 2,
80+
"id": "2fb719c7",
81+
"metadata": {},
82+
"outputs": [],
83+
"source": [
84+
# Basic settings: build the argument namespace for the visualization run.
args = get_args(True)

# --- Demo-only overrides (paths are relative to analysis/Demos/) ---
args.yaml_path = "../../" + args.yaml_path
args.result_file_attack = "badnet_demo"
# --- End demo-only overrides ---

# Merge YAML defaults with explicitly-set arguments: any argument left as
# None falls back to the value from the YAML config.
with open(args.yaml_path, "r") as stream:
    config = yaml.safe_load(stream)
explicit_args = {key: value for key, value in args.__dict__.items() if value is not None}
config.update(explicit_args)
args.__dict__ = config
args = preprocess_args(args)

# Seed all RNGs so the class subsampling below is reproducible.
fix_random(int(args.random_seed))

save_path_attack = "../..//record/" + args.result_file_attack
100+
]
101+
},
102+
{
103+
"cell_type": "markdown",
104+
"id": "f959b510",
105+
"metadata": {},
106+
"source": [
107+
"### Step 2: Load data"
108+
]
109+
},
110+
{
111+
"cell_type": "code",
112+
"execution_count": 3,
113+
"id": "b8b67ac9",
114+
"metadata": {},
115+
"outputs": [
116+
{
117+
"name": "stderr",
118+
"output_type": "stream",
119+
"text": [
120+
"WARNING:root:save_path MUST have 'record' in its abspath, and data_path in attack result MUST have 'data' in its path\n"
121+
]
122+
},
123+
{
124+
"name": "stdout",
125+
"output_type": "stream",
126+
"text": [
127+
"Files already downloaded and verified\n",
128+
"Files already downloaded and verified\n",
129+
"loading...\n",
130+
"max_num_samples is given, use sample number limit now.\n",
131+
"subset bd dataset with length: 5000\n",
132+
"Create visualization dataset with \n",
133+
" \t Dataset: bd_train \n",
134+
" \t Number of samples: 5000 \n",
135+
" \t Selected classes: [0 1 2 3 4 5 6 7 8 9]\n"
136+
]
137+
}
138+
],
139+
"source": [
140+
# Load the saved attack result from record/<result_file_attack>.
result_attack = load_attack_result(save_path_attack + "/attack_result.pt")
selected_classes = np.arange(args.num_classes)

# If there are more classes than we want to display, subsample c_sub of
# them, always keeping the attack's target class in the selection.
if args.num_classes > args.c_sub:
    non_target_classes = np.delete(selected_classes, args.target_class)
    sampled = np.random.choice(non_target_classes, args.c_sub - 1, replace=False)
    selected_classes = np.append(sampled, args.target_class)

# Reuse the test-time transforms for the train datasets so that train and
# test images are preprocessed identically for visualization.
result_attack["clean_train"].wrap_img_transform = result_attack["clean_test"].wrap_img_transform
result_attack["bd_train"].wrap_img_transform = result_attack["bd_test"].wrap_img_transform

# Build the dataset chosen for visualization.
args.visual_dataset = 'bd_train'
if args.visual_dataset == 'mixed':
    visual_dataset = generate_mix_dataset(
        result_attack["bd_test"], args.target_class, args.pratio,
        selected_classes, max_num_samples=args.n_sub,
    )
elif args.visual_dataset == 'clean_train':
    visual_dataset = generate_clean_dataset(
        result_attack["clean_train"], selected_classes, max_num_samples=args.n_sub,
    )
elif args.visual_dataset == 'clean_test':
    visual_dataset = generate_clean_dataset(
        result_attack["clean_test"], selected_classes, max_num_samples=args.n_sub,
    )
elif args.visual_dataset == 'bd_train':
    visual_dataset = generate_bd_dataset(
        result_attack["bd_train"], args.target_class,
        selected_classes, max_num_samples=args.n_sub,
    )
elif args.visual_dataset == 'bd_test':
    visual_dataset = generate_bd_dataset(
        result_attack["bd_test"], args.target_class,
        selected_classes, max_num_samples=args.n_sub,
    )
else:
    assert False, "Illegal vis_class"

print(f'Create visualization dataset with \n \t Dataset: {args.visual_dataset} \n \t Number of samples: {len(visual_dataset)} \n \t Selected classes: {selected_classes}')

# Wrap the dataset in a loader (no shuffling: indices must stay stable for
# the per-sample comparisons in later cells).
data_loader = torch.utils.data.DataLoader(
    visual_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False
)

# Build the denormalization function from the Normalize step of the
# dataset's transform pipeline.
for transform_step in data_loader.dataset.wrap_img_transform.transforms:
    if isinstance(transform_step, transforms.Normalize):
        denormalizer = get_dataset_denormalization(transform_step)
186+
]
187+
},
188+
{
189+
"cell_type": "markdown",
190+
"id": "67cbfec4",
191+
"metadata": {},
192+
"source": [
193+
"### Step 3: SSIM"
194+
]
195+
},
196+
{
197+
"cell_type": "code",
198+
"execution_count": 4,
199+
"id": "39104beb",
200+
"metadata": {},
201+
"outputs": [
202+
{
203+
"name": "stdout",
204+
"output_type": "stream",
205+
"text": [
206+
"Number Poisoned samples: 489\n",
207+
"Average SSIM: 0.9929845929145813\n"
208+
]
209+
}
210+
],
211+
"source": [
212+
# Average SSIM between each poisoned sample and its clean counterpart.
# A value close to 1 means the backdoor trigger is visually subtle.
visual_poison_indicator = np.array(get_poison_indicator_from_bd_dataset(visual_dataset))
bd_idx = np.where(visual_poison_indicator == 1)[0]

from torchmetrics import StructuralSimilarityIndexMeasure
ssim = StructuralSimilarityIndexMeasure()
ssim_list = []
if visual_poison_indicator.sum() > 0:
    print(f'Number Poisoned samples: {visual_poison_indicator.sum()}')
    # BUG FIX: the original loop used `for i in range(bd_idx.shape[0])` and
    # indexed `visual_dataset[i]`, i.e. the first len(bd_idx) samples of the
    # dataset regardless of poisoning — clean/clean pairs (SSIM == 1) then
    # inflate the average. Iterate over the poisoned indices themselves.
    for idx in bd_idx:
        bd_sample = denormalizer(visual_dataset[idx][0]).unsqueeze(0)
        # temporary_all_clean yields the same sample with the trigger disabled,
        # so the pair differs only by the trigger.
        with temporary_all_clean(visual_dataset):
            clean_sample = denormalizer(visual_dataset[idx][0]).unsqueeze(0)
        # .item() converts the 0-dim tensor so np.mean works on plain floats.
        ssim_list.append(ssim(bd_sample, clean_sample).item())

# Guard the empty case: np.mean([]) would warn and print nan.
if ssim_list:
    print(f'Average SSIM: {np.mean(ssim_list)}')
else:
    print('No poisoned samples found; SSIM not computed.')
228+
]
229+
},
230+
{
231+
"cell_type": "markdown",
232+
"id": "2c2b0104",
233+
"metadata": {},
234+
"source": [
235+
"### Step 4: FID"
236+
]
237+
},
238+
{
239+
"cell_type": "code",
240+
"execution_count": 5,
241+
"id": "57497927",
242+
"metadata": {},
243+
"outputs": [
244+
{
245+
"name": "stdout",
246+
"output_type": "stream",
247+
"text": [
248+
"Number Poisoned samples: 489\n",
249+
"FID: 0.00030133521067909896\n"
250+
]
251+
}
252+
],
253+
"source": [
254+
# FID between the clean and poisoned versions of the poisoned samples.
# Lower FID means the trigger perturbs the image distribution less.
visual_poison_indicator = np.array(get_poison_indicator_from_bd_dataset(visual_dataset))
bd_idx = np.where(visual_poison_indicator == 1)[0]

from torchmetrics.image.fid import FrechetInceptionDistance
fid = FrechetInceptionDistance(feature=64, normalize=True)
if visual_poison_indicator.sum() > 0:
    print(f'Number Poisoned samples: {visual_poison_indicator.sum()}')
    # BUG FIX: the original loop indexed `visual_dataset[i]` for
    # i in range(len(bd_idx)), i.e. the first len(bd_idx) samples regardless
    # of poisoning. Iterate over the poisoned indices themselves.
    for idx in bd_idx:
        bd_sample = denormalizer(visual_dataset[idx][0]).unsqueeze(0)
        # temporary_all_clean yields the same sample with the trigger disabled.
        with temporary_all_clean(visual_dataset):
            clean_sample = denormalizer(visual_dataset[idx][0]).unsqueeze(0)
        fid.update(clean_sample, real=True)
        fid.update(bd_sample, real=False)
    fid_value = fid.compute().numpy()
    # BUG FIX: the original printed fid_value outside this branch, raising
    # NameError whenever no poisoned samples were present.
    print(f'FID: {fid_value}')
else:
    print('No poisoned samples found; FID not computed.')
271+
]
272+
},
273+
{
274+
"cell_type": "code",
275+
"execution_count": null,
276+
"id": "870cf186",
277+
"metadata": {},
278+
"outputs": [],
279+
"source": []
280+
}
281+
],
282+
"metadata": {
283+
"kernelspec": {
284+
"display_name": "Python 3 (ipykernel)",
285+
"language": "python",
286+
"name": "python3"
287+
},
288+
"language_info": {
289+
"codemirror_mode": {
290+
"name": "ipython",
291+
"version": 3
292+
},
293+
"file_extension": ".py",
294+
"mimetype": "text/x-python",
295+
"name": "python",
296+
"nbconvert_exporter": "python",
297+
"pygments_lexer": "ipython3",
298+
"version": "3.9.12"
299+
},
300+
"vscode": {
301+
"interpreter": {
302+
"hash": "6869619afde5ccaa692f7f4d174735a0f86b1f7ceee086952855511b0b6edec0"
303+
}
304+
}
305+
},
306+
"nbformat": 4,
307+
"nbformat_minor": 5
308+
}

0 commit comments

Comments
 (0)