diff --git a/create_fp8_full.py b/create_fp8_full.py
new file mode 100644
index 0000000..ce06c96
--- /dev/null
+++ b/create_fp8_full.py
@@ -0,0 +1,55 @@
+import torch
+
+from torchao.quantization import (
+    quantize_,
+    ModuleFqnToConfig,
+    Float8DynamicActivationFloat8WeightConfig,
+)
+
+from kandinsky import get_video_pipeline
+
+
+def build_dit_fp8_config() -> ModuleFqnToConfig:
+    base_cfg = Float8DynamicActivationFloat8WeightConfig()
+
+    # Quantize every module not listed below with the base config
+    module_cfg = {"_default": base_cfg}
+
+    # Keep numerically sensitive layers in full precision
+    disabled_fqns = [
+        "time_embeddings.in_layer",
+        "time_embeddings.out_layer",
+        "text_embeddings.in_layer",
+        "visual_embeddings.in_layer",
+        "out_layer.modulation.out_layer",
+        "out_layer.out_layer",
+    ]
+
+    for i in range(4):
+        disabled_fqns.append(
+            f"text_transformer_blocks.{i}.feed_forward.out_layer"
+        )
+
+    for i in range(60):
+        disabled_fqns.append(
+            f"visual_transformer_blocks.{i}.feed_forward.out_layer"
+        )
+
+    # A None config disables quantization for the given FQN
+    for fqn in set(disabled_fqns):
+        module_cfg[fqn] = None
+
+    return ModuleFqnToConfig(module_cfg)
+
+
+if __name__ == "__main__":
+    pipe = get_video_pipeline(
+        device_map={"dit": "cuda:0", "vae": "cuda:0", "text_embedder": "cuda:0"},
+        conf_path="/data/kandinsky-5/configs/k5_pro_t2v_5s_sft_sd.yaml",
+        model_type="base",
+        mode="t2v",
+    )
+
+    ao_cfg = build_dit_fp8_config()
+    quantize_(pipe.dit, ao_cfg)
+    torch.save(pipe.dit.state_dict(), "/data/kandinsky-5/weights/K5_pro_5s_ao.pt")
diff --git a/kandinsky/utils.py b/kandinsky/utils.py
index 229eb19..b4b622f 100644
--- a/kandinsky/utils.py
+++ b/kandinsky/utils.py
@@ -50,6 +50,8 @@ def get_video_pipeline(
     text_token_padding: bool = False,
     attention_engine: str = "auto",
     mode: str = None,
+    model_type: str = "base",  # "base" or "fp8"
+    quantized_model_path: str = "/data/kandinsky-5/weights/K5_pro_5s_ao.pt",  # path to the quantized state dict
 ):
     if not isinstance(device_map, dict):
         device_map = {"dit": device_map, "vae": device_map, "text_embedder": device_map}
@@ -136,8 +138,16 @@ def get_video_pipeline(
         no_cfg = True
     set_magcache_params(dit, mag_ratios, num_steps, no_cfg)
 
-    state_dict = load_file(conf.model.checkpoint_path, device='cpu')
-    dit.load_state_dict(state_dict, assign=True)
+    if model_type == "base":
+        print(f"Loading BASE model weights from: {conf.model.checkpoint_path}")
+        state_dict = load_file(conf.model.checkpoint_path, device='cpu')
+        dit.load_state_dict(state_dict, assign=True)
+    elif model_type == "fp8":
+        print(f"Loading QUANTIZED FP8 model weights from: {quantized_model_path}")
+        # weights_only=False: the checkpoint stores torchao tensor subclasses,
+        # which PyTorch >= 2.6 rejects under the default weights_only=True
+        state_dict = torch.load(quantized_model_path, map_location='cpu', weights_only=False)
+        dit.load_state_dict(state_dict, assign=True)
+    else:
+        raise ValueError(f"Unknown model_type: {model_type!r}, expected 'base' or 'fp8'")
 
     if not offload and world_size == 1:
         dit = dit.to(device_map["dit"])
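
As a quick sanity check after `quantize_` (a sketch, not part of the patch): torchao swaps the weights of quantized `nn.Linear` modules for Float8 tensor subclasses, while the FQNs disabled in `build_dit_fp8_config` should keep plain `Parameter` weights. The `report_quantization` helper below is hypothetical and assumes a `pipe` built as in `create_fp8_full.py`:

```python
import torch

# Hypothetical helper: list each Linear layer and the type of its weight.
# Quantized layers show a torchao tensor-subclass type name; skipped ones
# (the disabled FQNs) should still show a plain Parameter.
def report_quantization(model: torch.nn.Module) -> None:
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            print(f"{name}: {type(module.weight).__name__}")

report_quantization(pipe.dit)
```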
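
Once `create_fp8_full.py` has written the quantized checkpoint, the new `model_type` flag selects the FP8 load path in `get_video_pipeline`. A minimal usage sketch, assuming the same config path and single-GPU `device_map` as above (`quantized_model_path` falls back to its new default):

```python
from kandinsky import get_video_pipeline

# Load the DiT from the torchao FP8 checkpoint instead of the base safetensors
pipe = get_video_pipeline(
    device_map={"dit": "cuda:0", "vae": "cuda:0", "text_embedder": "cuda:0"},
    conf_path="/data/kandinsky-5/configs/k5_pro_t2v_5s_sft_sd.yaml",
    model_type="fp8",  # flag introduced in this patch
    mode="t2v",
)
```

Note that `load_state_dict(..., assign=True)` replaces the parameter objects outright, so the freshly built `dit` does not need to be re-wrapped with `quantize_` before loading the FP8 state dict.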