Commit 8dcf4fb

Initial commit
1 parent 0cd0420 commit 8dcf4fb

157 files changed: +93180 -44 lines changed

.gitignore

Lines changed: 10 additions & 41 deletions
@@ -1,41 +1,10 @@
-# Compiled Lua sources
-luac.out
-
-# luarocks build files
-*.src.rock
-*.zip
-*.tar.gz
-
-# Object files
-*.o
-*.os
-*.ko
-*.obj
-*.elf
-
-# Precompiled Headers
-*.gch
-*.pch
-
-# Libraries
-*.lib
-*.a
-*.la
-*.lo
-*.def
-*.exp
-
-# Shared objects (inc. Windows DLLs)
-*.dll
-*.so
-*.so.*
-*.dylib
-
-# Executables
-*.exe
-*.out
-*.app
-*.i*86
-*.x86_64
-*.hex
-
+mnist_lr/mnist/logs/*.log
+mnist_minibatch/mnist/logs/*.log
+mnist_lr/dqn/mnist.t7/
+mnist_minibatch/dqn/mnist.t7/
+tmp/
+mnist_minibatch/dqn/logs/*.log
+mnist_minibatch/dqn/logs/*.pdf
+mnist_lr/mnist/mnist.t7/
+mnist_minibatch/mnist/mnist.t7/
+mnist_minibatch/dqn/logs/paint.py

LICENSE

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 The MIT License (MIT)
 
-Copyright (c) 2016 bigAIdream projects
+Copyright (c) 2016 Jie Fu, Zichuan Lin, Miao Liu, Nicholas Leonard, Jiashi Feng, Tat-Seng Chua projects
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal

README.md

Lines changed: 57 additions & 2 deletions
@@ -1,2 +1,57 @@
-# qan
-Deep Q-Networks for Accelerating the Training of Deep Neural Networks
+# Deep Q-Networks for Accelerating the Training of Deep Neural Networks
+
+Source code to the paper [Deep Q-Networks for Accelerating the Training of Deep Neural Networks](https://arxiv.org/abs/1606.01467)
+
+## Reproduce our results on MNIST
+
+### Dependencies
+We are using Torch. The DQN component is mostly modified from [DeepMind Atari DQN](https://github.com/kuz/DeepMind-Atari-Deep-Q-Learner).
+
+You might need to run `install_dependencies.sh` first.
+
+### Tuning learning rates on MNIST
+```bash
+cd mnist_lr/;
+cd mnist;
+th train-on-mnist.lua; #get regression filter, save in ../save/
+./run_gpu; #Start tune learning rate using dqn
+#To get the test curve, run following command
+cd mnist_lr/dqn/logs;
+python paint_lr_episode.py;
+python paint_lr_vs.py;
+```
+
+### Tuning mini-batch selection on MNIST
+```bash
+cd mnist_minibatch;
+cd mnist;
+th train-on-mnist.lua; #get regression filter, save in ../save/
+./run_gpu; #Start select mini-batch using dqn
+#To get the test curve, run following command
+cd mnist_minibatch/dqn/logs;
+python paint_mini_episode.py;
+python paint_mini_vs.py;
+```
+
+### Different Settings
+1. GPU device can be set in `run_gpu` where `gpu=0`
+2. Learning rate can be set in `/ataricifar/dqn/cnnGameEnv.lua`, in the `step` function.
+3. When to stop doing regression is in `/ataricifar/dqn/cnnGameEnv.lua`, in line 250
+
+## TODO
+1. Experiments on CIFAR
+2. Transfer learning
+
+## Citation
+```
+@article{dqn-accelerate-dnn,
+title={Deep Q-Networks for Accelerating the Training of Deep Neural Networks},
+author={Fu, Jie and Lin, Zichuan and Liu, Miao and Leonard, Nicholas and Feng, Jiashi and Chua, Tat-Seng},
+journal={arXiv preprint arXiv:1606.01467},
+year={2016}
+}
+```
+
+## Contact
+
+If you have any problems or suggestions, please contact me: jie.fu A~_~T u.nus.edu~~cation~~

cifar_lr/cifar.torch/LICENSE

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+The MIT License (MIT)
+
+Copyright (c) 2015 Sergey Zagoruyko
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+

cifar_lr/cifar.torch/README.md

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+# cifar.torch
+
+The code achieves 92.45% accuracy on CIFAR-10 just with horizontal reflections.
+
+Corresponding blog post: http://torch.ch/blog/2015/07/30/cifar.html
+
+Accuracies:
+
+ | No flips | Flips
+--- | --- | ---
+VGG+BN+Dropout | 91.3% | 92.45%
+NIN+BN+Dropout | 90.4% | 91.9%
+
+Would be nice to add other architectures, PRs are welcome!
+
+Data preprocessing:
+
+```bash
+OMP_NUM_THREADS=2 th -i provider.lua
+```
+
+```lua
+provider = Provider()
+provider:normalize()
+torch.save('provider.t7',provider)
+```
+Takes about 30 seconds and saves 1400 Mb file.
+
+Training:
+
+```bash
+CUDA_VISIBLE_DEVICES=0 th train.lua --model vgg_bn_drop -s logs/vgg
+```
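
Reviewer note: the README above attributes the gain in its accuracy table (91.3% to 92.45% for VGG, 90.4% to 91.9% for NIN) to horizontal reflections. The following is a minimal sketch of that augmentation in plain Python rather than Torch. The names `hflip` and `maybe_flip_batch` are hypothetical, and the 0.5 flip probability is an assumption that the README does not state.

```python
import random


def hflip(image):
    """Mirror one image left to right.

    `image` is a nested list shaped [channels][rows][cols].
    """
    return [[row[::-1] for row in channel] for channel in image]


def maybe_flip_batch(batch, rng, p=0.5):
    """Return a new batch in which each image is mirrored with probability p.

    p=0.5 is an assumed default, not a value taken from the repository.
    """
    return [hflip(img) if rng.random() < p else img for img in batch]


# A single-channel 2x3 "image" used only for illustration.
img = [[[1, 2, 3],
        [4, 5, 6]]]

print(hflip(img))
print(maybe_flip_batch([img, img], random.Random(0)))
```

Flipping twice returns the original image, which makes the transform easy to sanity-check before it is wired into a training loop.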
Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+require 'image'
+require 'cudnn'
+require 'cunn'
+local tablex = require 'pl.tablex'
+
+if #arg < 2 then
+  io.stderr:write('Usage: th example_classify.lua [MODEL] [FILE]...\n')
+  os.exit(1)
+end
+for _, f in ipairs(arg) do
+  if not paths.filep(f) then
+    io.stderr:write('file not found: ' .. f .. '\n')
+    os.exit(1)
+  end
+end
+
+local model_path = arg[1]
+local image_paths = tablex.sub(arg, 2, -1)
+
+-- loads the normalization parameters
+require 'provider'
+local provider = torch.load 'provider.t7'
+
+local function normalize(imgRGB)
+
+  -- preprocess trainSet
+  local normalization = nn.SpatialContrastiveNormalization(1, image.gaussian1D(7)):float()
+
+  -- rgb -> yuv
+  local yuv = image.rgb2yuv(imgRGB)
+  -- normalize y locally:
+  yuv[1] = normalization(yuv[{{1}}])
+
+  -- normalize u globally:
+  local mean_u = provider.trainData.mean_u
+  local std_u = provider.trainData.std_u
+  yuv:select(1,2):add(-mean_u)
+  yuv:select(1,2):div(std_u)
+  -- normalize v globally:
+  local mean_v = provider.trainData.mean_v
+  local std_v = provider.trainData.std_v
+  yuv:select(1,3):add(-mean_v)
+  yuv:select(1,3):div(std_v)
+
+  return yuv
+end
+
+local model = torch.load(model_path)
+model:add(nn.SoftMax():cuda())
+model:evaluate()
+
+-- model definition should set numInputDims
+-- hacking around it for the moment
+local view = model:findModules('nn.View')
+if #view > 0 then
+  view[1].numInputDims = 3
+end
+
+local cls = {'airplane', 'automobile', 'bird', 'cat',
+             'deer', 'dog', 'frog', 'horse', 'ship', 'truck'}
+
+for _, img_path in ipairs(image_paths) do
+  -- load image
+  local img = image.load(img_path, 3, 'float'):mul(255)
+
+  -- resize it to 32x32
+  img = image.scale(img, 32, 32)
+  -- normalize
+  img = normalize(img)
+  -- make it batch mode (for BatchNormalization)
+  img = img:view(1, 3, 32, 32)
+
+  -- get probabilities
+  local output = model:forward(img:cuda()):squeeze()
+
+  -- display
+  print('Probabilities for '..img_path)
+  for cl_id, cl in ipairs(cls) do
+    print(string.format('%-10s: %-05.2f%%', cl, output[cl_id] * 100))
+  end
+end
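
Reviewer note: two steps of the script above are plain arithmetic and can be checked without Torch or a GPU. These are the global normalization of the U and V planes (subtract the training mean, divide by the training standard deviation) and the final softmax plus per-class printout. The sketch below reproduces both in standard-library Python. The helper names `standardize`, `softmax` and `report` are hypothetical, and the mean, std and score values are made-up inputs, not statistics from `provider.t7`.

```python
import math

CLASSES = ['airplane', 'automobile', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


def standardize(channel, mean, std):
    """Global normalization of one image plane: (x - mean) / std per pixel."""
    return [[(x - mean) / std for x in row] for row in channel]


def softmax(scores):
    """Numerically stable softmax over a flat list of class scores."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def report(scores):
    """Build one 'class: probability%' line per class, similar to the Lua loop."""
    probs = softmax(scores)
    return [f"{name:<10}: {p * 100:.2f}%" for name, p in zip(CLASSES, probs)]


# Illustrative inputs only: a 1x2 plane and ten equal class scores.
print(standardize([[2.0, 4.0]], mean=3.0, std=2.0))
for line in report([0.0] * 10):
    print(line)
```

Subtracting the maximum score before exponentiating leaves the probabilities unchanged and avoids overflow for large scores.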
