diff --git a/configs/training/lora_c_1b_bfloat16.yaml b/configs/training/lora_c_1b_bfloat16.yaml new file mode 100644 index 0000000..4bd56a5 --- /dev/null +++ b/configs/training/lora_c_1b_bfloat16.yaml @@ -0,0 +1,38 @@ +# GLOBAL STUFF +experiment_id: stage_c_1b_lora +checkpoint_path: /tmp/cascade/chk +output_path: /tmp/cascade/lora_sample +model_version: 1B + +# TRAINING PARAMS +lr: 1.0e-4 +batch_size: 40 +image_size: 768 +multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] +grad_accum_steps: 4 +updates: 10000 +backup_every: 1000 +save_every: 100 +warmup_updates: 1 +# use_fsdp: True -> FSDP doesn't work at the moment for LoRA +use_fsdp: False + +# GDF +# adaptive_loss_weight: True + +# LoRA specific. 'No Defect Train Railcar Wheel' +module_filters: ['.attn'] +rank: 4 +train_tokens: + # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized + - ['[fernando]', '^dog'] # custom token [fernando], initialized as the avg of the tokens starting with "dog" + + +# ema_start_iters: 5000 +# ema_iters: 100 +# ema_beta: 0.9 + +webdataset_path: file:/home/asutermo/cascade/data/dataset.tar +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_lite_bf16.safetensors diff --git a/core/__init__.py b/core/__init__.py index 03af283..c799c72 100644 --- a/core/__init__.py +++ b/core/__init__.py @@ -34,6 +34,8 @@ class Config(Base): wandb_project: str = None wandb_entity: str = None + single_gpu: bool = False + @dataclass() # not frozen, means that fields are mutable class Info(): # not inheriting from Base, because we don't want to enforce the default fields wandb_run_id: str = None @@ -141,6 +143,7 @@ def setup_config(self, config_file_path=None, config_dict=None, training=True) - return self.Config(training=training) def setup_ddp(self, experiment_id, single_gpu=False): + self.single_gpu = single_gpu if not single_gpu: local_rank = 
int(os.environ.get("SLURM_LOCALID")) process_id = int(os.environ.get("SLURM_PROCID")) @@ -297,7 +300,7 @@ def __call__(self, single_gpu=False): if self.is_main_node: print() - print("**STARTIG JOB WITH CONFIG:**") + print("**STARTING JOB WITH CONFIG:**") print(yaml.dump(self.config.to_dict(), default_flow_style=False)) print("------------------------------------") print() diff --git a/core/templates/diffusion.py b/core/templates/diffusion.py index f36dc3f..9925f55 100644 --- a/core/templates/diffusion.py +++ b/core/templates/diffusion.py @@ -218,7 +218,8 @@ def models_to_save(self): return ['generator', 'generator_ema'] def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None): - barrier() + if not self.single_gpu: + barrier() suffix = '' if suffix is None else suffix self.save_info(self.info, suffix=suffix) models_dict = models.to_dict() diff --git a/train/__init__.py b/train/__init__.py index 2a65075..484f69b 100755 --- a/train/__init__.py +++ b/train/__init__.py @@ -1,4 +1,5 @@ from .train_b import WurstCore as WurstCoreB from .train_c import WurstCore as WurstCoreC from .train_c_controlnet import WurstCore as ControlNetCore -from .train_c_lora import WurstCore as LoraCore \ No newline at end of file +from .train_c_lora import WurstCore as LoraCore + diff --git a/train/base.py b/train/base.py index 4e8a6ef..a732d24 100755 --- a/train/base.py +++ b/train/base.py @@ -195,6 +195,9 @@ class Config(DataCore.Config, WarpCore.Config): use_fsdp: bool = None + # Optimizer Params + use_8bit_adam: bool = None + @dataclass() # not frozen, means that fields are mutable. 
Doesn't support EXPECTED class Info(WarpCore.Info): ema_loss: float = None @@ -310,7 +313,8 @@ def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, op self.sample(models, data, extras) def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None): - barrier() + if not self.single_gpu: + barrier() suffix = '' if suffix is None else suffix self.save_info(self.info, suffix=suffix) models_dict = models.to_dict() diff --git a/train/train_b.py b/train/train_b.py index 02b7b6e..8ab2da2 100755 --- a/train/train_b.py +++ b/train/train_b.py @@ -189,7 +189,7 @@ def dummy_context(): generator_ema = self.load_model(generator_ema, 'generator_ema') generator_ema.to(dtype).to(self.device).eval().requires_grad_(False) - if self.config.use_fsdp: + if not self.single_gpu and self.config.use_fsdp: fsdp_auto_wrap_policy = ModuleWrapPolicy([ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock]) generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) if generator_ema is not None: @@ -209,7 +209,15 @@ def dummy_context(): ) def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: - optimizer = optim.AdamW(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) + if self.config.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
+ ) + optimizer = bnb.optim.AdamW8bit(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) + else: + optimizer = optim.AdamW(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) optimizer = self.load_optimizer(optimizer, 'generator_optim', fsdp_model=models.generator if self.config.use_fsdp else None) return self.Optimizers(generator=optimizer) @@ -294,11 +302,13 @@ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, ext if __name__ == '__main__': print("Launching Script") + device = torch.device(int(os.environ.get('SLURM_LOCALID')) if 'SLURM_LOCALID' in os.environ else "cuda" if torch.cuda.is_available() else "cpu") warpcore = WurstCore( config_file_path=sys.argv[1] if len(sys.argv) > 1 else None, - device=torch.device(int(os.environ.get("SLURM_LOCALID"))) + device=device ) # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD # RUN TRAINING - warpcore() + use_single_gpu = torch.cuda.device_count() == 1 + warpcore(single_gpu=use_single_gpu) diff --git a/train/train_c.py b/train/train_c.py index 87c6608..1f748bb 100755 --- a/train/train_c.py +++ b/train/train_c.py @@ -174,7 +174,7 @@ def dummy_context(): generator_ema = self.load_model(generator_ema, 'generator_ema') generator_ema.to(dtype).to(self.device).eval().requires_grad_(False) - if self.config.use_fsdp: + if not self.single_gpu and self.config.use_fsdp: fsdp_auto_wrap_policy = ModuleWrapPolicy([ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock]) generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) if generator_ema is not None: @@ -192,7 +192,15 @@ def dummy_context(): ) def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: - optimizer = optim.AdamW(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) + if self.config.use_8bit_adam: + try: + import bitsandbytes as bnb + except 
ImportError: + raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + optimizer = bnb.optim.AdamW8bit(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) + else: + optimizer = optim.AdamW(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) optimizer = self.load_optimizer(optimizer, 'generator_optim', fsdp_model=models.generator if self.config.use_fsdp else None) return self.Optimizers(generator=optimizer) @@ -256,11 +264,15 @@ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, ext if __name__ == '__main__': print("Launching Script") + + device = torch.device(int(os.environ.get('SLURM_LOCALID')) if 'SLURM_LOCALID' in os.environ else "cuda" if torch.cuda.is_available() else "cpu") warpcore = WurstCore( config_file_path=sys.argv[1] if len(sys.argv) > 1 else None, - device=torch.device(int(os.environ.get("SLURM_LOCALID"))) + device=device ) # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD # RUN TRAINING - warpcore() + use_single_gpu = torch.cuda.device_count() == 1 + warpcore(single_gpu=use_single_gpu) + diff --git a/train/train_c_controlnet.py b/train/train_c_controlnet.py index 59d58eb..075fbfa 100755 --- a/train/train_c_controlnet.py +++ b/train/train_c_controlnet.py @@ -15,9 +15,8 @@ from modules import EfficientNetEncoder from modules import StageC -from modules import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock from modules import Previewer -from modules import ControlNet, ControlNetDeliverer +from modules import ControlNet from modules import controlnet_filters from train.base import DataCore, TrainingCore @@ -26,7 +25,6 @@ from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy -from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.distributed.fsdp.wrap import 
size_based_auto_wrap_policy import functools from accelerate import init_empty_weights @@ -223,7 +221,7 @@ def dummy_context(): controlnet = self.load_model(controlnet, 'controlnet') controlnet.backbone.eval().requires_grad_(True) - if self.config.use_fsdp: + if not self.single_gpu and self.config.use_fsdp: fsdp_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=3000) controlnet = FSDP(controlnet, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) @@ -235,7 +233,15 @@ def dummy_context(): ) def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers: - optimizer = optim.AdamW(models.controlnet.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) + if self.config.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + optimizer = bnb.optim.AdamW8bit(models.controlnet.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) + else: + optimizer = optim.AdamW(models.controlnet.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) optimizer = self.load_optimizer(optimizer, 'controlnet_optim', fsdp_model=models.controlnet if self.config.use_fsdp else None) return self.Optimizers(generator=None, controlnet=optimizer) @@ -372,11 +378,14 @@ def sample(self, models: Models, data: WarpCore.Data, extras: Extras): if __name__ == '__main__': print("Launching Script") + device = torch.device(int(os.environ.get('SLURM_LOCALID')) if 'SLURM_LOCALID' in os.environ else "cuda" if torch.cuda.is_available() else "cpu") warpcore = WurstCore( config_file_path=sys.argv[1] if len(sys.argv) > 1 else None, - device=torch.device(int(os.environ.get("SLURM_LOCALID"))) + device=device ) warpcore.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD # RUN TRAINING - warpcore() + use_single_gpu = torch.cuda.device_count() == 1 + 
warpcore(single_gpu=use_single_gpu) + diff --git a/train/train_c_lora.py b/train/train_c_lora.py index 8b83eee..5aaffb6 100755 --- a/train/train_c_lora.py +++ b/train/train_c_lora.py @@ -15,7 +15,6 @@ from modules.effnet import EfficientNetEncoder from modules.stage_c import StageC -from modules.stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock from modules.previewer import Previewer from modules.lora import apply_lora, apply_retoken, LoRA, ReToken @@ -26,8 +25,6 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy from torch.distributed.fsdp.wrap import ModuleWrapPolicy -from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy -import functools from accelerate import init_empty_weights from accelerate.utils import set_module_tensor_to_device from contextlib import contextmanager @@ -166,7 +163,6 @@ def dummy_context(): yield None loading_context = dummy_context if self.config.training else init_empty_weights - with loading_context(): # Diffusion models if self.config.model_version == '3.6B': @@ -185,7 +181,7 @@ def dummy_context(): generator = generator.to(dtype).to(self.device) generator = self.load_model(generator, 'generator') - # if self.config.use_fsdp: + # if not self.single_gpu and self.config.use_fsdp: # fsdp_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=3000) # generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) @@ -239,7 +235,7 @@ def dummy_context(): lora = self.load_model(lora, 'lora') lora.to(self.device).train().requires_grad_(True) - if self.config.use_fsdp: + if not self.single_gpu and self.config.use_fsdp: # fsdp_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=3000) fsdp_auto_wrap_policy = ModuleWrapPolicy([LoRA, ReToken]) lora = FSDP(lora, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) @@ -252,7 +248,15 @@ def 
dummy_context(): ) def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers: - optimizer = optim.AdamW(models.lora.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) + if self.config.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + optimizer = bnb.optim.AdamW8bit(models.lora.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) + else: + optimizer = optim.AdamW(models.lora.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) optimizer = self.load_optimizer(optimizer, 'lora_optim', fsdp_model=models.lora if self.config.use_fsdp else None) return self.Optimizers(generator=None, lora=optimizer) @@ -320,11 +324,13 @@ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, ext if __name__ == '__main__': print("Launching Script") + device = torch.device(int(os.environ.get('SLURM_LOCALID')) if 'SLURM_LOCALID' in os.environ else "cuda" if torch.cuda.is_available() else "cpu") warpcore = WurstCore( config_file_path=sys.argv[1] if len(sys.argv) > 1 else None, - device=torch.device(int(os.environ.get("SLURM_LOCALID"))) + device=device ) warpcore.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD # RUN TRAINING - warpcore() + use_single_gpu = torch.cuda.device_count() == 1 + warpcore(single_gpu=use_single_gpu)