Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions create_nvfp4_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.compress import compress, is_real_quantized
from modelopt.torch.quantization.config import CompressConfig
import modelopt.torch.opt as mto

from kandinsky import get_T2V_pipeline
from kandinsky import get_I2V_pipeline


def quant_mpt(model, mode: str = "nvfp4", save_path: str = "K5Pro_nvfp4.pth"):
    """Apply NVFP4 weights-only post-training quantization to a DiT model.

    Every ``nn.Linear`` weight is quantized to NVFP4 — (2, 1) num_bits with
    dynamic 16-element blocks along the last dim and (4, 3) (FP8 E4M3) block
    scales — while input (activation) quantizers stay disabled.  Embedding and
    final output layers are excluded entirely, and a few sensitive attention /
    feed-forward output projections fall back to per-tensor (4, 3) weight
    quantization instead of NVFP4.  After fake-quant PTQ the model is
    real-compressed and saved to ``save_path``.

    Args:
        model: The module to quantize in place (e.g. ``pipe.dit``).
        mode: Reserved quantization-mode selector; currently unused.
            NOTE(review): the original signature read ``mode = str`` — the
            builtin type object as default — almost certainly a typo for an
            annotation, fixed here to a string default.
        save_path: Destination file for the compressed quantized checkpoint
            (previously hard-coded; default preserves the old behavior).

    Returns:
        The quantized and compressed model.
    """
    config = {
        "quant_cfg": {
            # Weights-only NVFP4 for all nn.Linear layers.
            "nn.Linear": {
                "*weight_quantizer": {
                    "num_bits": (2, 1),  # NVFP4 element format
                    "block_sizes": {
                        -1: 16,                # 16-element blocks on the last dim
                        "type": "dynamic",
                        "scale_bits": (4, 3),  # FP8 E4M3 block scales
                    },
                    "enable": True,
                    "pass_through_bwd": False,
                },
                # Activations stay in full precision (weights-only PTQ).
                "*input_quantizer": {"enable": False},
            },
            # Disable quantization for embedding and output layers.
            "*visual_embeddings.in_layer*": {"enable": False},
            "*time_embeddings*": {"enable": False},
            "*text_embeddings*": {"enable": False},
            "*text_embeddings.in_layer*": {"enable": False},
            "*out_layer.modulation.out_layer*": {"enable": False},
            "*out_layer.out_layer*": {"enable": False},
            # Sensitive projections: per-tensor (4, 3) weights instead of NVFP4.
            "*visual_transformer_blocks.*.feed_forward.out_layer.weight_quantizer*": {"num_bits": (4, 3), "axis": None},
            "*visual_transformer_blocks.*.self_attention.out_layer.weight_quantizer": {"num_bits": (4, 3), "axis": None},
            "*visual_transformer_blocks.*.cross_attention.out_layer.weight_quantizer": {"num_bits": (4, 3), "axis": None},
            "*visual_transformer_blocks.*.cross_attention.to_key.weight_quantizer": {"num_bits": (4, 3), "axis": None},
        },
        "algorithm": "max",
    }

    # PTQ: insert quantizers and calibrate ("max" needs no calibration data
    # for a weights-only setup).
    model = mtq.quantize(model, config)
    mtq.print_quant_summary(model)

    # Replace fake-quant weights with real compressed storage, then save.
    ccfg = CompressConfig()
    ccfg.compress = {"default": True}
    compress(model, ccfg)
    mto.save(model, save_path)  # persist quantized weights

    print("Real-quantized?", is_real_quantized(model))

    return model

def _main():
    """Build the base T2V pipeline and quantize its DiT to NVFP4."""
    # Load the full-precision text-to-video pipeline entirely on GPU 0.
    pipe = get_T2V_pipeline(
        device_map={"dit": "cuda:0", "vae": "cuda:0", "text_embedder": "cuda:0"},
        conf_path="/kandinsky-5/configs/k5_pro_t2v_5s_sft_sd.yaml",
        model_type="base",
    )

    # Image-to-video variant, kept for reference:
    # pipe = get_I2V_pipeline(
    #     device_map={"dit": "cuda:0", "vae": "cuda:0", "text_embedder": "cuda:0"},
    #     conf_path="/kandinsky-5/configs/k5_pro_i2v_5s_sft_sd.yaml",
    #     model_type="base",
    # )

    # Quantize the diffusion transformer and swap it back into the pipeline.
    pipe.dit = quant_mpt(pipe.dit)


if __name__ == "__main__":
    _main()

2 changes: 1 addition & 1 deletion kandinsky/models/dit.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __init__(

self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size)

#@torch.compile()
@torch.compile()
def before_text_transformer_blocks(self, text_embed, time, pooled_text_embed, x,
text_rope_pos):
text_embed = self.text_embeddings(text_embed)
Expand Down
30 changes: 26 additions & 4 deletions kandinsky/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import torch
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

import modelopt.torch.opt as mto

from huggingface_hub import hf_hub_download, snapshot_download
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
Expand Down Expand Up @@ -50,6 +52,8 @@ def get_T2V_pipeline(
quantized_qwen: bool = False,
text_token_padding: bool = False,
attention_engine: str = "auto",
model_type: str = "base", # "base" or "quantized"
quantized_model_path: str = "kandinsky-5/K5Pro_nvfp4.pth", # Путь к квантованным весам
) -> Kandinsky5T2VPipeline:
if not isinstance(device_map, dict):
device_map = {"dit": device_map, "vae": device_map, "text_embedder": device_map}
Expand Down Expand Up @@ -141,8 +145,16 @@ def get_T2V_pipeline(
no_cfg = True
set_magcache_params(dit, mag_ratios, num_steps, no_cfg)

state_dict = load_file(conf.model.checkpoint_path, device='cpu')
dit.load_state_dict(state_dict, assign=True)
if model_type == "base":
print(f"Loading BASE model weights from: {conf.model.checkpoint_path}")
state_dict = load_file(conf.model.checkpoint_path, device='cpu')
dit.load_state_dict(state_dict, assign=True)
elif model_type == "quantized":
print(f"Loading QUANTIZED model weights from: {quantized_model_path}")
dit = mto.restore(dit, quantized_model_path, map_location='cpu')
else:
raise ValueError(f"Unknown model_type: {model_type}. Must be 'base' or 'quantized'")
torch.cuda.empty_cache()

if world_size > 1:
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
Expand Down Expand Up @@ -188,6 +200,8 @@ def get_I2V_pipeline(
quantized_qwen: bool = False,
text_token_padding: bool = False,
attention_engine: str = "auto",
model_type: str = "base", # "base" or "quantized"
quantized_model_path: str = "kandinsky-5/K5Pro_nvfp4.pth", # Путь к квантованным весам
) -> Kandinsky5T2VPipeline:
if not isinstance(device_map, dict):
device_map = {"dit": device_map, "vae": device_map, "text_embedder": device_map}
Expand Down Expand Up @@ -278,8 +292,16 @@ def get_I2V_pipeline(
no_cfg = True
set_magcache_params(dit, mag_ratios, num_steps, no_cfg)

state_dict = load_file(conf.model.checkpoint_path, device='cpu')
dit.load_state_dict(state_dict, assign=True)
if model_type == "base":
print(f"Loading BASE model weights from: {conf.model.checkpoint_path}")
state_dict = load_file(conf.model.checkpoint_path, device='cpu')
dit.load_state_dict(state_dict, assign=True)
elif model_type == "quantized":
print(f"Loading QUANTIZED model weights from: {quantized_model_path}")
dit = mto.restore(dit, quantized_model_path, map_location='cpu')
else:
raise ValueError(f"Unknown model_type: {model_type}. Must be 'base' or 'quantized'")
torch.cuda.empty_cache()

if world_size > 1:
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
Expand Down