dewenzeng/clsp

Contrastive Learning with Synthetic Positives

This is a PyTorch implementation of the CLSP paper (ECCV'24).

🚀 Main Results

  • The linear evaluation accuracy we get during training is similar to the offline setting.
  • For offline kNN evaluation, we do a grid search on k and report the best top-1 accuracy.
  • Accuracy may differ from the numbers in the paper because of random initialization.
  • Most offline linear evaluation results for baselines are copied directly from solo-learn. For offline kNN, we run their pre-trained checkpoints through our kNN evaluation code.
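
The grid search over k mentioned above can be sketched as follows. This is a minimal illustration on pre-extracted features, not the repository's evaluation code: the function names `knn_top1` and `grid_search_k`, the cosine-similarity metric, the unweighted majority vote, and the candidate values of k are all assumptions.

```python
import numpy as np

def knn_top1(train_x, train_y, test_x, test_y, k):
    """Top-1 accuracy of a cosine-similarity kNN classifier (majority vote)."""
    train_n = train_x / np.linalg.norm(train_x, axis=1, keepdims=True)
    test_n = test_x / np.linalg.norm(test_x, axis=1, keepdims=True)
    sim = test_n @ train_n.T                      # (num_test, num_train)
    nn_idx = np.argsort(-sim, axis=1)[:, :k]      # k most similar training samples
    nn_labels = train_y[nn_idx]                   # (num_test, k)
    num_classes = int(train_y.max()) + 1
    # Count votes per class, then pick the class with the most votes.
    votes = np.stack(
        [(nn_labels == c).sum(axis=1) for c in range(num_classes)], axis=1
    )
    pred = votes.argmax(axis=1)
    return float((pred == test_y).mean())

def grid_search_k(train_x, train_y, test_x, test_y, ks=(1, 5, 10, 20)):
    """Evaluate every candidate k and return (best_k, best_acc, all_scores)."""
    scores = {k: knn_top1(train_x, train_y, test_x, test_y, k) for k in ks}
    best_k = max(scores, key=scores.get)
    return best_k, scores[best_k], scores
```

Here `train_x` and `test_x` would be backbone features of the training and test sets, and the reported number would be `best_acc`.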

Cifar10

| Method | Backbone | Epochs | Offline Linear Acc@1 | Offline kNN | Tensorboard | Checkpoint |
| --- | --- | --- | --- | --- | --- | --- |
| CLSP-SimCLR | ResNet18 | 1000 | 94.45 | 94.10 | 🏃 | 🔗 |
| CLSP-MoCo | ResNet18 | 1000 | 94.62 | 94.01 | 🏃 | 🔗 |
| CLSP-SimCLR | ResNet50 | 1000 | 95.57 | 95.14 | 🏃 | 🔗 |
| CLSP-MoCo | ResNet50 | 1000 | 96.04 | 95.33 | 🏃 | 🔗 |
| SimCLR | ResNet18 | 1000 | 91.45 | 91.05 | 🏃 | 🔗 |
| MoCo | ResNet18 | 1000 | 92.29 | 91.74 | 🏃 | 🔗 |
| SimCLR | ResNet50 | 1000 | 93.03 (94.00) | 92.17 | 🏃 | 🔗 |
| MoCo | ResNet50 | 1000 | 94.62 | 94.33 | 🏃 | 🔗 |

Cifar100

| Method | Backbone | Epochs | Offline Linear Acc@1 | Offline kNN | Tensorboard | Checkpoint |
| --- | --- | --- | --- | --- | --- | --- |
| CLSP-SimCLR | ResNet18 | 1000 | 72.41 | 70.88 | 🏃 | 🔗 |
| CLSP-MoCo | ResNet18 | 1000 | 72.33 | 70.13 | 🏃 | 🔗 |
| CLSP-SimCLR | ResNet50 | 1000 | 74.65 | 72.81 | 🏃 | 🔗 |
| CLSP-MoCo | ResNet50 | 1000 | 76.43 | 73.76 | 🏃 | 🔗 |
| SimCLR | ResNet18 | 1000 | 68.10 | 66.24 | 🏃 | 🔗 |
| MoCo | ResNet18 | 1000 | 68.91 | 68.11 | 🏃 | 🔗 |
| SimCLR | ResNet50 | 1000 | 70.16 | 69.21 | 🏃 | 🔗 |
| MoCo | ResNet50 | 1000 | 72.68 | 72.25 | 🏃 | 🔗 |

STL10

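| Method | Backbone | Epochs | Offline Linear Acc@1 | Tensorboard | Checkpoint |
| --- | --- | --- | --- | --- | --- |
| CLSP-SimCLR | ResNet18 | 1000 | 94.74 | 🏃 | 🔗 |
| CLSP-MoCo | ResNet18 | 1000 | 94.31 | 🏃 | 🔗 |
| SimCLR | ResNet18 | 1000 | 91.31 | 🏃 | 🔗 |
| MoCo | ResNet18 | 1000 | 92.61 | 🏃 | 🔗 |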

1️⃣ Diffusion Model Training

Our DDPM training and sampling code is adapted from this repo; the DDIM sampling code is adapted from this repo.

To train diffusion models for the Cifar10 and Cifar100 datasets:

CUDA_VISIBLE_DEVICES=0,1,2,3 python train_diffusion.py --config configs/diffusion_cifar10.yaml
CUDA_VISIBLE_DEVICES=0,1,2,3 python train_diffusion.py --config configs/diffusion_cifar100.yaml

To compute the FID of a pre-trained diffusion model (Cifar10, for example), first use the model to randomly sample 50000 images with generate_synthetic_dataset.py, then download the numpy version of Cifar10 from link, and finally run compute_fid_score.py to compute the FID score.

Generate random samples

CUDA_VISIBLE_DEVICES=0,1,2,3 python generate_synthetic_dataset.py \
--ckpt_path ./pretrained_ckpts/uncondition_diffusion_cifar10.pt \
--config configs/diffusion_cifar10.yaml \
--save_dir ./synthetic_datasets/cifar10_tmp/ --num_candidates 1 --batch_size 4096 \
--sample_method ddpm

Compute FID score

CUDA_VISIBLE_DEVICES=0 python compute_fid_score.py \
--real_samples ./data_numpy/cifar10.npy \
--fake_samples ./synthetic_datasets/cifar10_tmp/generated_ddpm_0.01_000.npy

| Dataset | Epochs | FID score | Checkpoint |
| --- | --- | --- | --- |
| Cifar10 | 2000 | 3.58 | 🔗 |
| Cifar100 | 2000 | 6.06 | 🔗 |
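
For reference, FID is the Fréchet distance between two Gaussians fitted to Inception features of the real and generated images. The sketch below shows only that final distance computation on feature arrays that are assumed to be already extracted; the function name `frechet_distance` is hypothetical, and compute_fid_score.py may be organized differently.

```python
import numpy as np

def frechet_distance(feats_real, feats_fake):
    """Frechet distance between Gaussians fitted to two (N, D) feature arrays."""
    mu_r, mu_f = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    cov_r = np.cov(feats_real, rowvar=False)
    cov_f = np.cov(feats_fake, rowvar=False)
    # Tr((cov_r cov_f)^(1/2)) is the sum of the square roots of the eigenvalues
    # of cov_r @ cov_f, which are real and non-negative for PSD covariances.
    eigvals = np.linalg.eigvals(cov_r @ cov_f)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    diff = mu_r - mu_f
    return float(diff @ diff + np.trace(cov_r) + np.trace(cov_f) - 2.0 * tr_sqrt)
```

Identical feature sets give a distance of approximately zero, and a pure shift of the mean contributes the squared norm of that shift.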

2️⃣ Generate Synthetic Positives

We use DDIM to speed up sampling and generate 8 synthetic positives for each data point in the dataset.

On Cifar10 or Cifar100

CUDA_VISIBLE_DEVICES=0,1,2,3 python generate_synthetic_dataset.py \
--ckpt_path ./pretrained_ckpts/uncondition_diffusion_cifar10.pt \
--config configs/diffusion_cifar10.yaml \
--save_dir ./synthetic_datasets/cifar10/ --num_candidates 8 --batch_size 4096 \
--interpolation_weight 0.1 --sample_method ddim_interpolation \
--ddim_sampling_timesteps 100 --ddim_eta 0.1 --num_interpolation_layers 1

On STL10

CUDA_VISIBLE_DEVICES=0,1,2,3 python generate_synthetic_dataset.py \
--ckpt_path ./pretrained_ckpts/uncondition_diffusion_stl10.pt \
--config configs/diffusion_stl10.yaml \
--save_dir ./synthetic_datasets/stl10/ --num_candidates 8 --batch_size 4096 \
--interpolation_weight 0.1 --sample_method ddim_interpolation \
--ddim_sampling_timesteps 200 --ddim_eta 0.1 --num_interpolation_layers 4

| Dataset | Synthetic Dataset |
| --- | --- |
| Cifar10 | 🔗 |
| Cifar100 | 🔗 |
| STL10 | 🔗 |
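
To make the `--ddim_sampling_timesteps` and `--ddim_eta` flags in the commands above more concrete, here is a minimal sketch of the standard DDIM update. It is not the repository's sampler: the function names are hypothetical, the noise prediction `eps_pred` is assumed to come from the trained model, and the feature interpolation performed by the `ddim_interpolation` sampling method is not shown.

```python
import numpy as np

def ddim_timesteps(num_train_steps, num_sampling_steps):
    """Evenly spaced subsequence of training timesteps, in descending order."""
    steps = np.linspace(0, num_train_steps - 1, num_sampling_steps)
    return steps.round().astype(int)[::-1]

def ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev, eta, rng):
    """One DDIM update from timestep t to the previous timestep in the subsequence."""
    # Estimate the clean image from the noisy sample and the predicted noise.
    x0_pred = (x_t - np.sqrt(1.0 - alpha_bar_t) * eps_pred) / np.sqrt(alpha_bar_t)
    # eta = 0 gives deterministic DDIM; eta = 1 matches DDPM-like stochasticity.
    sigma = (
        eta
        * np.sqrt((1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t))
        * np.sqrt(1.0 - alpha_bar_t / alpha_bar_prev)
    )
    direction = np.sqrt(1.0 - alpha_bar_prev - sigma**2) * eps_pred
    noise = rng.standard_normal(x_t.shape)
    return np.sqrt(alpha_bar_prev) * x0_pred + direction + sigma * noise
```

For example, assuming a model trained with 1000 timesteps, `ddim_timesteps(1000, 100)` visits only 100 of them, which is where the sampling speedup comes from.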

3️⃣ Self-supervised Training

Train a vanilla SimCLR model on the Cifar10 dataset

CUDA_VISIBLE_DEVICES=0,1,2,3 python train_simclr.py --config configs/simclr_cifar10.yaml

Train the CLSP-SimCLR model on the Cifar10 dataset. Make sure synthetic_data_path in configs/simclr_clsp_cifar10.yaml points to the correct path. Use your own generated synthetic dataset or download one from Step 2.

CUDA_VISIBLE_DEVICES=0,1,2,3 python train_simclr_clsp.py --config configs/simclr_clsp_cifar10.yaml
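
As a rough illustration of how a synthetic positive can enter a contrastive objective, the sketch below computes an InfoNCE-style loss for one anchor with several positives (for example, an augmented view plus a synthetic image) against a set of negatives. This is an assumption for illustration only: the function name, the temperature, and the equal averaging over positives are not taken from train_simclr_clsp.py, whose loss may be weighted or structured differently.

```python
import numpy as np

def _normalize(x):
    """L2-normalize along the last axis."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def info_nce_with_positives(anchor, positives, negatives, temperature=0.2):
    """InfoNCE loss for one anchor, averaged over all of its positives."""
    a = _normalize(anchor)
    pos_logits = _normalize(positives) @ a / temperature   # (num_pos,)
    neg_logits = _normalize(negatives) @ a / temperature   # (num_neg,)
    # Numerically stable log-sum-exp over the negatives.
    m = neg_logits.max()
    neg_lse = m + np.log(np.exp(neg_logits - m).sum())
    # Per positive: -log(exp(pos) / (exp(pos) + sum_neg exp(neg))).
    per_positive = np.logaddexp(pos_logits, neg_lse) - pos_logits
    return float(per_positive.mean())
```

The loss shrinks as the positives move closer to the anchor than the negatives are, so an extra, harder positive adds a term that pulls its embedding toward the anchor.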

Using Class Label for Training

There are two ways to use class labels to help SSL training:

  • Class-conditional Diffusion Model
  • Supervised Contrastive Learning (SupContrast)

Class-conditional Diffusion Model

  1. Train a class-conditional diffusion model with classifier-free diffusion guidance
CUDA_VISIBLE_DEVICES=0,1,2,3 python train_conditional_diffusion.py --config configs/conditional_diffusion_cifar10.yaml
  2. Generate 8 synthetic positive images for each sample in the dataset
CUDA_VISIBLE_DEVICES=0,1,2,3 python generate_synthetic_dataset_conditional.py \
--ckpt_path your_model_path \
--config configs/conditional_diffusion_cifar10.yaml \
--save_dir your_save_dir --num_candidates 8 --batch_size 4096 \
--sample_method ddim --ddim_sampling_timesteps 100 --ddim_eta 0.1
  3. Train SSL with the synthetic positives (use the same code as CLSP)
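
The two ingredients of classifier-free guidance used in the first step can be sketched as follows: labels are randomly replaced by a "null" label during training so that one network learns both conditional and unconditional noise predictions, and the two predictions are combined at sampling time. The function names, `null_label`, `p_uncond`, and `guidance_weight` are hypothetical and do not refer to options in configs/conditional_diffusion_cifar10.yaml.

```python
import numpy as np

def drop_labels(labels, null_label, p_uncond, rng):
    """Training time: replace each label with null_label with probability p_uncond."""
    mask = rng.random(labels.shape) < p_uncond
    return np.where(mask, null_label, labels)

def guided_eps(eps_cond, eps_uncond, guidance_weight):
    """Sampling time: combine predictions as (1 + w) * eps_cond - w * eps_uncond."""
    return (1.0 + guidance_weight) * eps_cond - guidance_weight * eps_uncond
```

With `guidance_weight = 0` the combined prediction reduces to the plain conditional one; larger weights push samples further toward the conditioning class.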

Pre-trained class-conditional diffusion model

| Dataset | Epochs | Checkpoint | Synthetic Dataset |
| --- | --- | --- | --- |
| Cifar10 | 1000 | 🔗 | 🔗 |
| Cifar100 | 1000 | 🔗 | 🔗 |

Results (Using synthetic positives generated by class-conditional diffusion model)

| Dataset | Method | Backbone | Epochs | Offline Acc@1 |
| --- | --- | --- | --- | --- |
| Cifar10 | SimCLR | ResNet18 | 1000 | 94.44 |
| Cifar10 | MoCo | ResNet18 | 1000 | 94.47 |
| Cifar100 | SimCLR | ResNet18 | 1000 | 75.45 |
| Cifar100 | MoCo | ResNet18 | 1000 | 73.06 |

Supervised Contrastive Learning

Training command

CUDA_VISIBLE_DEVICES=0,1,2,3 python train_simclr_supervised.py --config configs/simclr_supervised_cifar10.yaml

Results

| Dataset | Backbone | Epochs | Offline Acc@1 |
| --- | --- | --- | --- |
| Cifar10 | ResNet18 | 1000 | 94.61 |
| Cifar100 | ResNet18 | 1000 | 73.89 |
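
For readers unfamiliar with SupContrast, the following is a minimal NumPy sketch of the supervised contrastive loss, in which every other sample with the same label counts as a positive for the anchor. It is an illustration only: the function name and temperature are assumptions, and train_simclr_supervised.py may differ in details such as multi-view batching.

```python
import numpy as np

def supcon_loss(features, labels, temperature=0.1):
    """Supervised contrastive loss over a batch of (N, D) features and (N,) labels."""
    z = features / np.linalg.norm(features, axis=1, keepdims=True)
    logits = z @ z.T / temperature
    n = len(labels)
    not_self = ~np.eye(n, dtype=bool)
    # Log-softmax over all other samples; the anchor itself is excluded.
    logits = logits - logits.max(axis=1, keepdims=True)
    exp_logits = np.exp(logits) * not_self
    log_prob = logits - np.log(exp_logits.sum(axis=1, keepdims=True))
    # Positives are the other samples that share the anchor's label.
    pos_mask = (labels[:, None] == labels[None, :]) & not_self
    pos_count = pos_mask.sum(axis=1)
    valid = pos_count > 0  # skip anchors that have no positive in the batch
    mean_log_prob_pos = (log_prob * pos_mask).sum(axis=1)[valid] / pos_count[valid]
    return float(-mean_log_prob_pos.mean())
```

The sketch assumes at least one anchor in the batch has a same-label partner; otherwise the final mean is taken over an empty array.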

T-SNE Visualization on Cifar10

Citation

If you find this code useful, please cite our paper:

@inproceedings{zeng2024contrastive,
  title={Contrastive learning with synthetic positives},
  author={Zeng, Dewen and Wu, Yawen and Hu, Xinrong and Xu, Xiaowei and Shi, Yiyu},
  booktitle={European Conference on Computer Vision},
  pages={430--447},
  year={2024},
  organization={Springer}
}

❤️ Reference
