This is a PyTorch implementation of the CLSP (Contrastive Learning with Synthetic Positives) paper (ECCV'24).
- The linear evaluation accuracy obtained during training (online evaluation) is similar to the offline setting.
- For offline kNN evaluation, we do a grid search over k and report the best top-1 accuracy (see the sketch after this list).
- Accuracy may differ slightly from the numbers reported in the paper because of random initialization.
- Most offline linear evaluation results for the baselines are copied directly from solo-learn. For offline kNN, we run their pre-trained checkpoints through our kNN evaluation code.
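The kNN evaluation itself is a standard nearest-neighbor classifier on frozen features. Below is a minimal sketch of the grid search, assuming pre-extracted feature tensors; the function is illustrative and not this repo's exact evaluation code:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def knn_top1(train_feats, train_labels, test_feats, test_labels, k):
    """Top-1 accuracy of a cosine-similarity kNN classifier with majority voting."""
    train_feats = F.normalize(train_feats, dim=1)
    test_feats = F.normalize(test_feats, dim=1)
    sims = test_feats @ train_feats.t()            # (num_test, num_train) cosine similarities
    _, idx = sims.topk(k, dim=1)                   # indices of the k nearest training samples
    preds = train_labels[idx].mode(dim=1).values   # majority vote among the k neighbors
    return (preds == test_labels).float().mean().item()

# Grid search over k, report the best top-1 accuracy:
# best_acc = max(knn_top1(tr_f, tr_y, te_f, te_y, k) for k in (5, 10, 20, 50, 100, 200))
```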
Results on Cifar10:

| Method | Backbone | Epochs | Offline Linear Acc@1 | Offline kNN Acc@1 | Tensorboard | Checkpoint |
|---|---|---|---|---|---|---|
| CLSP-SimCLR | ResNet18 | 1000 | 94.45 | 94.10 | 🏃 | 🔗 |
| CLSP-MoCo | ResNet18 | 1000 | 94.62 | 94.01 | 🏃 | 🔗 |
| CLSP-SimCLR | ResNet50 | 1000 | 95.57 | 95.14 | 🏃 | 🔗 |
| CLSP-MoCo | ResNet50 | 1000 | 96.04 | 95.33 | 🏃 | 🔗 |
| SimCLR | ResNet18 | 1000 | 91.45 | 91.05 | 🏃 | 🔗 |
| MoCo | ResNet18 | 1000 | 92.29 | 91.74 | 🏃 | 🔗 |
| SimCLR | ResNet50 | 1000 | 93.03 (94.00) | 92.17 | 🏃 | 🔗 |
| MoCo | ResNet50 | 1000 | 94.62 | 94.33 | 🏃 | 🔗 |

Results on Cifar100:

| Method | Backbone | Epochs | Offline Linear Acc@1 | Offline kNN Acc@1 | Tensorboard | Checkpoint |
|---|---|---|---|---|---|---|
| CLSP-SimCLR | ResNet18 | 1000 | 72.41 | 70.88 | 🏃 | 🔗 |
| CLSP-MoCo | ResNet18 | 1000 | 72.33 | 70.13 | 🏃 | 🔗 |
| CLSP-SimCLR | ResNet50 | 1000 | 74.65 | 72.81 | 🏃 | 🔗 |
| CLSP-MoCo | ResNet50 | 1000 | 76.43 | 73.76 | 🏃 | 🔗 |
| SimCLR | ResNet18 | 1000 | 68.10 | 66.24 | 🏃 | 🔗 |
| MoCo | ResNet18 | 1000 | 68.91 | 68.11 | 🏃 | 🔗 |
| SimCLR | ResNet50 | 1000 | 70.16 | 69.21 | 🏃 | 🔗 |
| MoCo | ResNet50 | 1000 | 72.68 | 72.25 | 🏃 | 🔗 |

Results on STL10:

| Method | Backbone | Epochs | Offline Linear Acc@1 | Tensorboard | Checkpoint |
|---|---|---|---|---|---|
| CLSP-SimCLR | ResNet18 | 1000 | 94.74 | 🏃 | 🔗 |
| CLSP-MoCo | ResNet18 | 1000 | 94.31 | 🏃 | 🔗 |
| SimCLR | ResNet18 | 1000 | 91.31 | 🏃 | 🔗 |
| MoCo | ResNet18 | 1000 | 92.61 | 🏃 | 🔗 |
Our DDPM training and sampling code is adapted from this repo; the DDIM sampling code is adapted from this repo.
To train diffusion models on the Cifar10 and Cifar100 datasets:
CUDA_VISIBLE_DEVICES=0,1,2,3 python train_diffusion.py --config configs/diffusion_cifar10.yaml
CUDA_VISIBLE_DEVICES=0,1,2,3 python train_diffusion.py --config configs/diffusion_cifar100.yaml
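For background, the diffusion model is trained with the standard DDPM epsilon-prediction objective: noise an image to a random timestep and regress the injected noise. A minimal sketch under a linear beta schedule; the model, schedule values, and shapes are illustrative and not necessarily what the configs above specify:

```python
import torch
import torch.nn.functional as F

T = 1000
betas = torch.linspace(1e-4, 0.02, T)               # linear noise schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # cumulative \bar{alpha}_t

def ddpm_loss(model, x0):
    """Epsilon-prediction objective: the UNet must recover the noise added at a random timestep."""
    t = torch.randint(0, T, (x0.size(0),), device=x0.device)
    noise = torch.randn_like(x0)
    a_bar = alphas_cumprod.to(x0.device)[t].view(-1, 1, 1, 1)
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise  # forward process q(x_t | x_0)
    return F.mse_loss(model(x_t, t), noise)
```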
To compute the FID score of a pre-trained diffusion model (say, on Cifar10), first use the pre-trained diffusion model to randomly sample 50,000 images with generate_synthetic_dataset.py, then download the numpy version of Cifar10 from the link and run compute_fid_score.py to compute the FID score.
Generate random samples:
CUDA_VISIBLE_DEVICES=0,1,2,3 python generate_synthetic_dataset.py \
--ckpt_path ./pretrained_ckpts/uncondition_diffusion_cifar10.pt \
--config configs/diffusion_cifar10.yaml \
--save_dir ./synthetic_datasets/cifar10_tmp/ --num_candidates 1 --batch_size 4096 \
--sample_method ddpm
Compute FID score:
CUDA_VISIBLE_DEVICES=0 python compute_fid_score.py \
--real_samples ./data_numpy/cifar10.npy \
--fake_samples ./synthetic_datasets/cifar10_tmp/generated_ddpm_0.01_000.npy
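As a reference for what this comparison involves, one way to compute FID between the two .npy files is shown below with torchmetrics. The (N, H, W, C) uint8 layout is an assumption, and this is not necessarily what compute_fid_score.py does internally:

```python
import numpy as np
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

real = np.load("./data_numpy/cifar10.npy")
fake = np.load("./synthetic_datasets/cifar10_tmp/generated_ddpm_0.01_000.npy")

def to_nchw_uint8(x):
    # Assumed layout: (N, H, W, C) uint8; FrechetInceptionDistance expects NCHW uint8 by default.
    return torch.from_numpy(x).permute(0, 3, 1, 2).contiguous()

fid = FrechetInceptionDistance(feature=2048)
for batch in to_nchw_uint8(real).split(256):
    fid.update(batch, real=True)
for batch in to_nchw_uint8(fake).split(256):
    fid.update(batch, real=False)
print("FID:", fid.compute().item())
```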
| Dataset | Epochs | FID score | Checkpoint |
|---|---|---|---|
| Cifar10 | 2000 | 3.58 | 🔗 |
| Cifar100 | 2000 | 6.06 | 🔗 |
We use DDIM to speed up the sampling process. To generate 8 synthetic positives for each data point in the dataset:
On Cifar10 or Cifar100:
CUDA_VISIBLE_DEVICES=0,1,2,3 python generate_synthetic_dataset.py \
--ckpt_path ./pretrained_ckpts/uncondition_diffusion_cifar10.pt \
--config configs/diffusion_cifar10.yaml \
--save_dir ./synthetic_datasets/cifar10/ --num_candidates 8 --batch_size 4096 \
--interpolation_weight 0.1 --sample_method ddim_interpolation \
--ddim_sampling_timesteps 100 --ddim_eta 0.1 --num_interpolation_layers 1
On STL10:
CUDA_VISIBLE_DEVICES=0,1,2,3 python generate_synthetic_dataset.py \
--ckpt_path ./pretrained_ckpts/uncondition_diffusion_stl10.pt \
--config configs/diffusion_stl10.yaml \
--save_dir ./synthetic_datasets/stl10/ --num_candidates 8 --batch_size 4096 \
--interpolation_weight 0.1 --sample_method ddim_interpolation \
--ddim_sampling_timesteps 200 --ddim_eta 0.1 --num_interpolation_layers 4
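For context, --ddim_sampling_timesteps and --ddim_eta parameterize the standard generalized DDIM update; one denoising step looks roughly like the sketch below. The ddim_interpolation method additionally mixes information from the real image (controlled by --interpolation_weight and --num_interpolation_layers) to form positives, which is not shown here:

```python
import torch

def ddim_step(model, x_t, t, t_prev, alphas_cumprod, eta):
    """One generalized DDIM update from timestep t to t_prev (eta = 0 gives deterministic DDIM)."""
    a_t, a_prev = alphas_cumprod[t], alphas_cumprod[t_prev]
    eps = model(x_t, torch.full((x_t.size(0),), t, device=x_t.device, dtype=torch.long))
    x0_pred = (x_t - (1 - a_t).sqrt() * eps) / a_t.sqrt()           # predicted clean image
    sigma = eta * ((1 - a_prev) / (1 - a_t)).sqrt() * (1 - a_t / a_prev).sqrt()
    dir_xt = (1 - a_prev - sigma ** 2).sqrt() * eps                 # direction pointing back to x_t
    return a_prev.sqrt() * x0_pred + dir_xt + sigma * torch.randn_like(x_t)
```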
| Dataset | Synthetic Dataset |
|---|---|
| Cifar10 | 🔗 |
| Cifar100 | 🔗 |
| STL10 | 🔗 |
Train a vanilla SimCLR model on the Cifar10 dataset:
CUDA_VISIBLE_DEVICES=0,1,2,3 python train_simclr.py --config configs/simclr_cifar10.yaml
Train a CLSP-SimCLR model on the Cifar10 dataset (make sure to set synthetic_data_path in configs/simclr_clsp_cifar10.yaml to the correct path). Use your own generated synthetic dataset or download one from Step 2.
CUDA_VISIBLE_DEVICES=0,1,2,3 python train_simclr_clsp.py --config configs/simclr_clsp_cifar10.yaml
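Conceptually, CLSP training consumes the real image together with one of its pre-generated synthetic positives. A hypothetical dataset sketch, assuming the .npy file stores uint8 candidates index-aligned with the original dataset; the actual dataset class in this repo may differ:

```python
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets import CIFAR10

class CIFAR10WithSyntheticPositives(Dataset):
    """Yields two augmented views of the real image plus one augmented synthetic positive."""
    def __init__(self, root, synthetic_data_path, transform):
        self.base = CIFAR10(root, train=True, download=True)
        # Assumed layout: (num_images, num_candidates, H, W, C) uint8, index-aligned with CIFAR10.
        self.synthetic = np.load(synthetic_data_path)
        self.transform = transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        img, label = self.base[idx]                      # PIL image, int label
        k = np.random.randint(self.synthetic.shape[1])   # pick one of the candidate positives
        syn = Image.fromarray(self.synthetic[idx, k])
        return self.transform(img), self.transform(img), self.transform(syn), label
```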
There are two ways to utilize class labels to help SSL training:
- Class-conditional Diffusion Model
- Supervised Contrastive Learning (SupContrast)
- Train a class-conditional diffusion model with classifier-free diffusion guidance (see the sketch after this list)
CUDA_VISIBLE_DEVICES=0,1,2,3 python train_conditional_diffusion.py --config configs/conditional_diffusion_cifar10.yaml
- Generate 8 synthetic positive images for each sample in the dataset
CUDA_VISIBLE_DEVICES=0,1,2,3 python generate_synthetic_dataset_conditional.py \
--ckpt_path your_model_path \
--config configs/conditional_diffusion_cifar10.yaml \
--save_dir your_save_dir --num_candidates 8 --batch_size 4096 \
--sample_method ddim --ddim_sampling_timesteps 100 --ddim_eta 0.1
- Train SSL with the synthetic positives (use the same code as CLSP)
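For reference, classifier-free guidance trains the conditional model with labels randomly dropped and blends the conditional and unconditional noise predictions at sampling time. A minimal sketch; the guidance scale and null-token handling here are illustrative, not this repo's exact interface:

```python
import torch

def cfg_eps(model, x_t, t, y, null_label, guidance_scale):
    """Classifier-free guidance: blend conditional and unconditional noise predictions."""
    eps_cond = model(x_t, t, y)                                  # conditioned on the class label
    eps_uncond = model(x_t, t, torch.full_like(y, null_label))   # label replaced by the null token
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)

# During training, the conditioning label is replaced by the null token with some
# probability (e.g. 10-20%), so the same network learns both predictions.
```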
Pre-trained class-conditional diffusion model
| Dataset | Epochs | Checkpoint | Synthetic Dataset |
|---|---|---|---|
| Cifar10 | 1000 | 🔗 | 🔗 |
| Cifar100 | 1000 | 🔗 | 🔗 |
Results (using synthetic positives generated by the class-conditional diffusion model)
| Dataset | Method | Backbone | Epochs | Offline Acc@1 |
|---|---|---|---|---|
| Cifar10 | SimCLR | ResNet18 | 1000 | 94.44 |
| Cifar10 | MoCo | ResNet18 | 1000 | 94.47 |
| Cifar100 | SimCLR | ResNet18 | 1000 | 75.45 |
| Cifar100 | MoCo | ResNet18 | 1000 | 73.06 |
Training command:
CUDA_VISIBLE_DEVICES=0,1,2,3 python train_simclr_supervised.py --config configs/simclr_supervised_cifar10.yaml
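train_simclr_supervised.py presumably optimizes a SupCon-style objective in which all same-class samples in a batch are treated as positives. A minimal sketch of that loss, not necessarily the exact implementation used here:

```python
import torch
import torch.nn.functional as F

def supcon_loss(features, labels, temperature=0.1):
    """Supervised contrastive loss: all same-label samples in the batch act as positives."""
    z = F.normalize(features, dim=1)
    sim = z @ z.t() / temperature                                 # (B, B) similarity logits
    self_mask = torch.eye(len(z), dtype=torch.bool, device=z.device)
    pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~self_mask
    sim = sim.masked_fill(self_mask, -1e9)                        # drop self-pairs from the softmax
    log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)
    # Mean log-likelihood of the positives for each anchor that has at least one positive.
    pos_counts = pos_mask.sum(dim=1).clamp(min=1)
    loss = -(log_prob * pos_mask).sum(dim=1) / pos_counts
    return loss[pos_mask.any(dim=1)].mean()

# Typical usage with two augmented views per image:
# loss = supcon_loss(torch.cat([z1, z2], dim=0), labels.repeat(2))
```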
Results
| Dataset | Backbone | Epochs | Offline Acc@1 |
|---|---|---|---|
| Cifar10 | ResNet18 | 1000 | 94.61 |
| Cifar100 | ResNet18 | 1000 | 73.89 |
If you find this code useful, please cite our paper:
@inproceedings{zeng2024contrastive,
title={Contrastive learning with synthetic positives},
author={Zeng, Dewen and Wu, Yawen and Hu, Xinrong and Xu, Xiaowei and Shi, Yiyu},
booktitle={European Conference on Computer Vision},
pages={430--447},
year={2024},
organization={Springer}
}


