Skip to content
Open

DPO #30

Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
71c3014
Add base trainer for any accelerate model
shahbuland Sep 17, 2023
af358b7
Add PickaPic pipeline for DPO
shahbuland Sep 27, 2023
b140623
add skeleton for DPO trainer
shahbuland Sep 27, 2023
704a85c
Pipeline for DPO
Jan 23, 2024
68ec789
Allow for list of images instead of just list of np arrays for sample…
Jan 23, 2024
e70e537
Add sampler for DPO
Jan 23, 2024
2202d9f
Add method config for DPO
Jan 23, 2024
9304bac
Add DPO trainer initial version
Jan 23, 2024
1752621
basic debugs
Jan 24, 2024
b11e873
Remove streaming
Jan 24, 2024
9201d6a
minor bug fixes
Jan 24, 2024
35dd03d
Moved saving from DDPO trainer to base accelerate
Jan 24, 2024
e16526f
LoRA, refactorings, quick bug fixes
Jan 25, 2024
14fe254
small bug fixes
Jan 25, 2024
e121257
bug fixes
Jan 25, 2024
6d9e03d
Fix import errors and checkpointing
Jan 25, 2024
c2350cb
Add base model loss deviation to sampling as metric
Jan 26, 2024
765b9f6
Add base model loss deviation to trainer logging as metric
Jan 26, 2024
74012cc
Add non-lora training with memory saving options in config
Jan 26, 2024
ef91f92
some refactorings to sampling, add rmsprop
Jan 28, 2024
38847c5
Delete old DPO example, push new one
Jan 28, 2024
be05515
Rename DPO2 to DPO
Feb 13, 2024
e6023a3
Move DPO and DDPO sampler to their own files for better organiation
Feb 13, 2024
5253473
prepare for adding SDXL
Feb 13, 2024
54f6ec1
Fix issue with modularizing samplers
Feb 13, 2024
44f163d
Add SDXL support and reorganize config for model
Feb 13, 2024
565efe6
Remove mandatory gradient clipping and fix model saving with new config
Feb 13, 2024
4324932
Update dpo_pickapic.yml
shahbuland Feb 13, 2024
dde1265
Update dpo_pickapic.yml
shahbuland Feb 13, 2024
70f1827
Update dpo_pickapic.yml
shahbuland Feb 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions configs/dpo_pickapic.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
method:
name : "DPO"

model:
model_path: "stabilityai/stable-diffusion-2-1-base"
model_arch_type: "LDMUnet"
attention_slicing: True
xformers_memory_efficient: False
gradient_checkpointing: True

sampler:
guidance_scale: 7.5
num_inference_steps: 50

optimizer:
name: "adamw"
kwargs:
lr: 1.0e-5
weight_decay: 1.0e-4
betas: [0.9, 0.999]

scheduler:
name: "linear" # Name of learning rate scheduler
kwargs:
start_factor: 1.0
end_factor: 1.0

logging:
run_name: 'dpo_pickapic'
#wandb_entity: None
#wandb_project: None

train:
num_epochs: 500
num_samples_per_epoch: 256
batch_size: 4
sample_batch_size: 32
grad_clip: 1.0
checkpoint_interval: 50
tf32: True
suppress_log_keywords: "diffusers.pipelines,transformers"
54 changes: 54 additions & 0 deletions examples/DPO/download_pickapic_wds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from datasets import load_dataset
import requests
import os
from tqdm import tqdm
import tarfile
from multiprocessing import Pool, cpu_count

"""
This script takes the filtered version of the PickAPic prompt dataset
and downloads the associated images, then tars them. This tar file can then
be moved to S3 or loaded directly if needed. Number of samples can be specified
"""

n_samples = 1000
data_root = "./pickapic_sample"
url = "CarperAI/pickapic_v1_no_images_training_sfw"
n_cpus = cpu_count() # Detect the number of CPUs

base_name = os.path.basename(data_root).replace('.', '').replace('/', '')

def make_tarfile(output_filename, source_dir):
with tarfile.open(output_filename, "w") as tar:
tar.add(source_dir, arcname=os.path.basename(source_dir))

def download_image(args):
url, filename = args
response = requests.get(url)
with open(filename, 'wb') as f:
f.write(response.content)

if __name__ == "__main__":
ds = load_dataset("CarperAI/pickapic_v1_no_images_training_sfw")['train']
os.makedirs(data_root, exist_ok = True)

id_counter = 0
with Pool(n_cpus) as p:
for row in tqdm(ds, total = n_samples):
if id_counter >= n_samples:
break
if row['has_label']:
id_str = str(id_counter).zfill(8)
with open(os.path.join(data_root, f'{id_str}.prompt.txt'), 'w', encoding='utf-8') as f:
# Ensure the caption is in UTF-8 format
caption = row['caption'].encode('utf-8').decode('utf-8')
f.write(caption)
if row['label_0']:
p.map(download_image, [(row['image_0_url'], os.path.join(data_root, f'{id_str}.chosen.png')),
(row['image_1_url'], os.path.join(data_root, f'{id_str}.rejected.png'))])
else:
p.map(download_image, [(row['image_1_url'], os.path.join(data_root, f'{id_str}.chosen.png')),
(row['image_0_url'], os.path.join(data_root, f'{id_str}.rejected.png'))])
id_counter += 1

make_tarfile(f"{base_name}.tar", data_root)
14 changes: 14 additions & 0 deletions examples/DPO/train_pickapic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import sys
sys.path.append("./src")

from drlx.pipeline.pickapic_wds import PickAPicPipeline
from drlx.trainer.dpo_trainer import DPOTrainer
from drlx.configs import DRLXConfig

pipe = PickAPicPipeline()
resume = False

config = DRLXConfig.load_yaml("configs/dpo_pickapic.yml")
trainer = DPOTrainer(config)

trainer.train(pipe)
16 changes: 16 additions & 0 deletions src/drlx/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,22 @@ class DDPOConfig(MethodConfig):
buffer_size: int = 32 # Set to None to avoid using per prompt stat tracker
min_count: int = 16

@register_method("DPO")
@dataclass
class DPOConfig(MethodConfig):
"""
Config for DPO-related hyperparams

:param beta: Deviation from initial model
:type beta: float

:param ref_mem_strategy: Strategy for managing reference model on memory. By default, puts it in 16 bit.
:type ref_mem_strategy: str
"""
name : str = "DPO"
beta : float = 0.9
ref_mem_strategy : str = None # None or "half"

@dataclass
class TrainConfig(ConfigClass):
"""
Expand Down
11 changes: 11 additions & 0 deletions src/drlx/denoisers/ldm_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,5 +165,16 @@ def forward(
encoder_hidden_states = text_embeds
).sample

@property
def device(self):
return self.unet.device

def enable_adapters(self):
if self.config.lora_rank:
self.unet.enable_adapters()

def disable_adapters(self):
if self.config.lora_rank:
self.unet.disable_adapters()


30 changes: 30 additions & 0 deletions src/drlx/pipeline/dpo_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from abc import abstractmethod
from typing import Tuple, Callable

from PIL import Image

from drlx.pipeline import Pipeline

class DPOPipeline(Pipeline):
"""
Pipeline for training with DPO. Returns prompts, chosen images, and rejected images
"""
def __init__(self, *args):
super().__init__(*args)

@abstractmethod
def __getitem__(self, index : int) -> Tuple[str, Image.Image, Image.Image]:
pass

def make_default_collate(self, prep : Callable):
def collate(batch : Iterable[Tuple[str, Image.Image, Image.Image]]):
prompts = [d[0] for d in batch]
chosen = [d[1] for d in batch]
rejected = [d[2] for d in batch]

return prep(prompts, chosen, rejected)

return collate



65 changes: 65 additions & 0 deletions src/drlx/pipeline/pickapic_dpo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from datasets import load_dataset
import io

from drlx.pipeline.dpo_pipeline import DPOPipeline

import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image

def convert_bytes_to_image(image_bytes, id):
try:
image = Image.open(io.BytesIO(image_bytes))
image = image.resize((512, 512))
return image
except Exception as e:
print(f"An error occurred: {e}")

def create_train_dataset():
ds = load_dataset("yuvalkirstain/pickapic_v2",split='train')
ds = ds.filter(lambda example: example['has_label'] == True and example['label_0'] != 0.5)
return ds

class Collator:
def __call__(self, batch):
# Batch is list of rows which are dicts
image_0_bytes = [b['jpg_0'] for b in batch]
image_1_bytes = [b['jpg_1'] for b in batch]
uid_0 = [b['image_0_uid'] for b in batch]
uid_1 = [b['image_1_uid'] for b in batch]

label_0s = [b['label_0'] for b in batch]

for i in range(len(batch)):
if not label_0s[i]: # label_1 is 1 => jpg_1 is the chosen one
image_0_bytes[i], image_1_bytes[i] = image_1_bytes[i], image_0_bytes[i]
# Swap so image_0 is always the chosen one

prompts = [b['caption'] for b in batch]

images_0 = [convert_bytes_to_image(i, id) for (i, id) in zip(image_0_bytes, uid_0)]
images_1 = [convert_bytes_to_image(i, id) for (i, id) in zip(image_1_bytes, uid_1)]

images_0 = torch.stack([transforms.ToTensor()(image) for image in images_0])
images_0 = images_0 * 2 - 1

images_1 = torch.stack([transforms.ToTensor()(image) for image in images_1])
images_1 = images_1 * 2 - 1

return {
"chosen_pixel_values" : images_0,
"rejected_pixel_values" : images_1,
"prompts" : prompts
}

class PickAPicDPOPipeline(DPOPipeline):
"""
Pipeline for training LDM with DPO
"""
def __init__(self):
self.train_ds = create_train_dataset()
self.dc = Collator()

def create_loader(self, **kwargs):
return DataLoader(self.train_ds, collate_fn = self.dc, **kwargs)
Loading