Commit 4fa2459
create a script to train autoencoderkl (#10605)
* create a script to train vae
* update main.py
* update train_autoencoderkl.py
* update train_autoencoderkl.py
* add a check of --pretrained_model_name_or_path and --model_config_name_or_path
* remove the comment, remove diffusers in requirements.txt, add validation_image note
* update autoencoderkl.py
* quality

---------

Co-authored-by: Sayak Paul <[email protected]>

1 parent 4f3ec53 commit 4fa2459

3 files changed: +1127 -0 lines changed

@@ -0,0 +1,59 @@
# AutoencoderKL training example

## Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, because the example scripts are updated frequently and rely on some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```

Then `cd` into the example folder and run:
```bash
pip install -r requirements.txt
```


And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

## Training on CIFAR10

Replace the `--validation_image` paths below with your own images.

```bash
accelerate launch train_autoencoderkl.py \
  --pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \
  --dataset_name=cifar10 \
  --image_column=img \
  --validation_image images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg \
  --num_train_epochs 100 \
  --gradient_accumulation_steps 2 \
  --learning_rate 4.5e-6 \
  --lr_scheduler cosine \
  --report_to wandb
```
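
The `--learning_rate 4.5e-6` and `--lr_scheduler cosine` flags above suggest a learning rate that decays along a cosine curve. As a rough illustration only, the sketch below computes such a curve in plain Python. It assumes a single half-cosine decay to zero with no warmup. The function name `cosine_lr` and the step counts are hypothetical, and the schedule the training script actually builds may differ (for example, by adding warmup steps).

```python
import math


def cosine_lr(step, total_steps, base_lr=4.5e-6):
    """Half-cosine decay from base_lr to 0 (assumption: no warmup phase)."""
    # Clamp progress to [0, 1] so out-of-range steps stay well-defined.
    progress = min(max(step / total_steps, 0.0), 1.0)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    total = 1000  # hypothetical number of optimizer steps
    for step in (0, 250, 500, 750, 1000):
        print(f"step {step:4d}: lr = {cosine_lr(step, total):.3e}")
```

With `--gradient_accumulation_steps 2`, one optimizer step corresponds to two forward/backward passes, so a schedule like this would typically advance once per accumulated pair of batches.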

## Training on ImageNet

```bash
accelerate launch train_autoencoderkl.py \
  --pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \
  --num_train_epochs 100 \
  --gradient_accumulation_steps 2 \
  --learning_rate 4.5e-6 \
  --lr_scheduler cosine \
  --report_to wandb \
  --mixed_precision bf16 \
  --train_data_dir /path/to/ImageNet/train \
  --validation_image ./image.png \
  --decoder_only
```
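
Both commands pass `--pretrained_model_name_or_path`, and the commit message mentions a check involving that flag and `--model_config_name_or_path`. The exact check is not visible in this diff, so the following is only a minimal sketch of one plausible form, assuming the two flags are meant to be mutually exclusive. The helper name `parse_model_source` is hypothetical and is not taken from the script.

```python
import argparse


def parse_model_source(argv):
    """Parse the two model-source flags and reject using both at once.

    Assumption: the flags are mutually exclusive. The real script may
    enforce a different rule.
    """
    parser = argparse.ArgumentParser(description="hypothetical sketch")
    parser.add_argument("--pretrained_model_name_or_path", type=str, default=None)
    parser.add_argument("--model_config_name_or_path", type=str, default=None)
    args = parser.parse_args(argv)

    if (
        args.pretrained_model_name_or_path is not None
        and args.model_config_name_or_path is not None
    ):
        raise ValueError(
            "Pass only one of --pretrained_model_name_or_path "
            "and --model_config_name_or_path."
        )
    return args


if __name__ == "__main__":
    args = parse_model_source(
        ["--pretrained_model_name_or_path", "stabilityai/sd-vae-ft-mse"]
    )
    print(args.pretrained_model_name_or_path)
```

Raising `ValueError` rather than relying on an argparse mutually exclusive group keeps the error easy to catch in tests, though either approach would work for a check like this.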
@@ -0,0 +1,15 @@
accelerate>=0.16.0
bitsandbytes
datasets
huggingface_hub
lpips
numpy
packaging
Pillow
taming_transformers
torch
torchvision
tqdm
transformers
wandb
xformers
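
After `pip install -r requirements.txt`, it can help to confirm which of these packages are actually present in the environment. The sketch below is one way to do that with the standard library only. The helper names `requirement_names` and `installed_versions` are hypothetical, the requirement list is copied from the hunk above, and the parsing handles only the simple `name` and `name>=version` forms used here.

```python
import re
from importlib import metadata

# Copied from the requirements hunk above.
REQUIREMENTS = """\
accelerate>=0.16.0
bitsandbytes
datasets
huggingface_hub
lpips
numpy
packaging
Pillow
taming_transformers
torch
torchvision
tqdm
transformers
wandb
xformers
"""


def requirement_names(text):
    """Extract bare package names, ignoring version specifiers and comments."""
    names = []
    for line in text.splitlines():
        line = line.split("#", 1)[0].strip()
        if not line:
            continue
        match = re.match(r"[A-Za-z0-9_.\-]+", line)
        if match:
            names.append(match.group(0))
    return names


def installed_versions(names):
    """Map each name to its installed version, or None if it is not found."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(requirement_names(REQUIREMENTS)).items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```

This reports presence only; it does not verify that `accelerate` satisfies the `>=0.16.0` constraint.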