Created and maintained by Joshua Ning
Download the data from ArtBench
Download the pre-trained model from Google Drive and move it into /models
Use Conda with Python 3.9.7 and the following packages:
- matplotlib
- torch
- numpy
- pandas
- torchinfo
- torchmetrics
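A quick way to confirm that the packages above can be imported in the active Conda environment is a small check like the one below. The script and its `missing_packages` helper are illustrative and not part of the repository.

```python
import importlib.util

# Packages listed above; for each of these, the import name matches the package name.
REQUIRED = ["matplotlib", "torch", "numpy", "pandas", "torchinfo", "torchmetrics"]


def missing_packages(names=REQUIRED):
    """Return the packages from `names` that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All dependencies are installed.")
```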
Only data_256 is required to start training
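Before training, it can help to check that data_256 downloaded and extracted completely. The sketch below assumes, without confirmation from this README, that data_256 holds one subdirectory per class with the images inside (the layout torchvision's `ImageFolder` expects); if it is instead split into train and test folders, run it on each split. The `count_images_per_class` helper and the accepted file extensions are hypothetical.

```python
from pathlib import Path

# Assumption: one subdirectory per class, images inside (ImageFolder-style layout).
# The accepted extensions are a guess; adjust them to match the downloaded files.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images_per_class(root):
    """Map each class subdirectory of `root` to the number of image files it contains."""
    root = Path(root)
    if not root.is_dir():
        raise FileNotFoundError(f"dataset directory not found: {root}")
    counts = {}
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        # rglob also picks up images in nested folders under each class directory.
        counts[class_dir.name] = sum(
            1
            for f in class_dir.rglob("*")
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts
```

For example, `count_images_per_class("data_256")` run from the directory containing data_256 returns a dictionary of class names and image counts, which makes an incomplete download easy to spot.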
model_train_val.ipynb can both train and validate, and is optimized for GPUs with CUDA
351.ipynb is essentially the same as model_train_val.ipynb, but with more documentation
train.py is a Python script that only trains (intended for running on a server)
The stats and models are saved in their respective directories
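One simple way to keep per-epoch stats in a CSV file is to append a row after every epoch, so a run that is interrupted (for example on a server) still leaves usable stats behind. The sketch below uses only the standard library; `append_epoch_stats`, the file path, and the column names are hypothetical and may not match what the notebooks or train.py actually write.

```python
import csv
from pathlib import Path


def append_epoch_stats(csv_path, stats):
    """Append one row of per-epoch stats, writing the header if the file is new or empty."""
    path = Path(csv_path)
    path.parent.mkdir(parents=True, exist_ok=True)  # create the stats directory if needed
    write_header = not path.exists() or path.stat().st_size == 0
    with path.open("a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(stats))
        if write_header:
            writer.writeheader()
        writer.writerow(stats)


# Hypothetical call at the end of each epoch inside a training loop:
# append_epoch_stats("stats/stats.csv", {"epoch": epoch, "train_loss": train_loss, "val_acc": val_acc})
```

Every call should pass the same keys in the same order, since the header is written only once.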

- Make sure all the dependencies are installed
- Open model_train_val.ipynb and run code blocks 1-4
- For training only:
  - In code block 5, if training for the first time, change the save_point variable to 0
  - In code block 5, if you want to continue training from a previous model, enter the last epoch number
  - Skip code block 6 if training for the first time
  - Change the hyper-parameter variables at the top of code block 7 if needed
  - Change the output file directories on lines 34 and 38 of code block 7 if needed
- For evaluation only:
  - Code block 8 evaluates multiple epochs of a trained model
    - Make sure the file path on line 13 matches the saved model path
  - Code block 9 evaluates a single model
    - Make sure the file path on line 6 matches the saved model path
  - Code block 10 generates a decision matrix and must be run after code block 9
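Assuming the decision matrix from code block 10 is a confusion matrix (rows for true classes, columns for predicted classes), the computation can be sketched with NumPy as below. The function names are hypothetical, and the notebook's own implementation may differ.

```python
import numpy as np


def decision_matrix(y_true, y_pred, num_classes):
    """Count predictions: entry [i, j] is the number of class-i samples predicted as class j."""
    matrix = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when the same (true, pred) pair repeats.
    np.add.at(matrix, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return matrix


def accuracy_from_matrix(matrix):
    """Overall accuracy: correctly classified samples (the diagonal) over all samples."""
    return np.trace(matrix) / matrix.sum()
```

The diagonal-over-total ratio is one way to arrive at overall accuracy figures like those listed for the saved models below. torchmetrics, already a dependency, also provides `MulticlassConfusionMatrix` in `torchmetrics.classification`, which works directly on tensors.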
- test_v1.pth: a small model from a tutorial, used to test loading and saving data. Trained for 1 epoch, accuracy of 10%
- test_v2.pth: mini-VGG BN. Trained for 3 epochs, batch size = 32, accuracy of 40%
- epoch_19_bs16: VGG-16 BN. Trained for 19 epochs, batch size = 16, accuracy of 53.9%
- stats.csv: mini-VGG BN stats
- /stats_saves_vgg_16: VGG 16, batch size = 16, training stats
- /stats_saves_vgg16_32: VGG 16, batch size = 32, training stats
- test_eval_stats_32.csv: VGG 16, batch size = 32, evaluation stats
- test_eval_stats.csv: VGG 16, batch size = 16, evaluation stats
- stats_plot.m: plots the data with MATLAB
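As a Python alternative to stats_plot.m for quickly comparing runs, the standard-library sketch below reports the epoch with the highest value of a chosen metric in a stats CSV. The `best_epoch` helper and the column names `epoch` and `val_acc` are assumptions; pass the actual header names used in the files above.

```python
import csv


def best_epoch(csv_path, metric="val_acc", epoch_col="epoch"):
    """Return (epoch, value) for the row with the highest `metric` in a stats CSV."""
    with open(csv_path, newline="") as f:
        rows = list(csv.DictReader(f))
    if not rows:
        raise ValueError(f"no rows in {csv_path}")
    best = max(rows, key=lambda row: float(row[metric]))
    # int(float(...)) tolerates epochs stored as "19" or "19.0".
    return int(float(best[epoch_col])), float(best[metric])
```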
Plots and graphs generated throughout the project