Skip to content

Commit f41036d

Browse files
authored
Wan checkpointing (#246)
* Conversion of dataclass objects and others, not raise error * Add test for config conversion for checkpointing * Skip jax.distributed initialize in utnittest * Temporary skip pipeline tests
1 parent 95afb77 commit f41036d

File tree

5 files changed

+61
-6
lines changed

5 files changed

+61
-6
lines changed

.github/workflows/UnitTests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ jobs:
5454
ruff check .
5555
- name: PyTest
5656
run: |
57-
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py
57+
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py --deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py -x
5858
# add_pull_ready:
5959
# if: github.ref != 'refs/heads/main'
6060
# permissions:

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,16 @@ attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ri
6060
flash_min_seq_length: 4096
6161
dropout: 0.1
6262

63-
flash_block_sizes: {}
63+
flash_block_sizes: {
64+
"block_q" : 1024,
65+
"block_kv_compute" : 256,
66+
"block_kv" : 1024,
67+
"block_q_dkv" : 1024,
68+
"block_kv_dkv" : 1024,
69+
"block_kv_dkv_compute" : 256,
70+
"block_q_dq" : 1024,
71+
"block_kv_dq" : 1024
72+
}
6473
# Use on v6e
6574
# flash_block_sizes: {
6675
# "block_q" : 3024,

src/maxdiffusion/configuration_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from typing import Any, Dict, Tuple, Union
2727
from . import max_logging
2828
import numpy as np
29+
from dataclasses import asdict, is_dataclass
2930

3031
from huggingface_hub import create_repo, hf_hub_download
3132
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
@@ -54,16 +55,17 @@ class CustomEncoder(json.JSONEncoder):
5455
"""
5556

5657
def default(self, o):
57-
# This will catch the `dtype[bfloat16]` object and convert it to the string "bfloat16"
5858
if isinstance(o, type(jnp.dtype("bfloat16"))):
5959
return str(o)
60-
# Add fallbacks for other numpy types if needed
6160
if isinstance(o, np.integer):
6261
return int(o)
6362
if isinstance(o, np.floating):
6463
return float(o)
65-
# Let the base class default method raise the TypeError for other types
66-
return super().default(o)
64+
if is_dataclass(o):
65+
return asdict(o)
66+
else:
67+
max_logging.log(f"Warning: {o} of type {type(o)} is not JSON serializable")
68+
return None
6769

6870

6971
class FrozenDict(OrderedDict):
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import json
2+
import os
3+
4+
from maxdiffusion import pyconfig
5+
from maxdiffusion.configuration_utils import ConfigMixin
6+
from maxdiffusion import __version__
7+
8+
class DummyConfigMixin(ConfigMixin):
9+
config_name = "config.json"
10+
11+
def __init__(self, **kwargs):
12+
self.register_to_config(**kwargs)
13+
14+
def test_to_json_string_with_config():
15+
# Load the YAML config file
16+
config_path = os.path.join(os.path.dirname(__file__), "..", "configs", "base_wan_14b.yml")
17+
18+
# Initialize pyconfig with the YAML config
19+
pyconfig.initialize([None, config_path], unittest=True)
20+
config = pyconfig.config
21+
22+
# Create a DummyConfigMixin instance
23+
dummy_config = DummyConfigMixin(**config.get_keys())
24+
25+
# Get the JSON string
26+
json_string = dummy_config.to_json_string()
27+
28+
# Parse the JSON string
29+
parsed_json = json.loads(json_string)
30+
31+
# Assertions
32+
assert parsed_json["_class_name"] == "DummyConfigMixin"
33+
assert parsed_json["_diffusers_version"] == __version__
34+
35+
# Check a few values from the config
36+
assert parsed_json["run_name"] == config.run_name
37+
assert parsed_json["pretrained_model_name_or_path"] == config.pretrained_model_name_or_path
38+
assert parsed_json["flash_block_sizes"]["block_q"] == config.flash_block_sizes["block_q"]
39+
40+
# The following keys are explicitly removed in to_json_string, so we assert they are not present
41+
assert "weights_dtype" not in parsed_json
42+
assert "precision" not in parsed_json

src/maxdiffusion/tests/input_pipeline_interface_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from absl.testing import absltest
2424

2525
import numpy as np
26+
import pytest
2627
import tensorflow as tf
2728
import tensorflow.experimental.numpy as tnp
2829
import jax
@@ -69,6 +70,7 @@ class InputPipelineInterface(unittest.TestCase):
6970
def setUp(self):
7071
InputPipelineInterface.dummy_data = {}
7172

73+
@pytest.mark.skip(reason="Debug segfault")
7274
def test_make_dreambooth_train_iterator(self):
7375

7476
instance_class_gcs_dir = "gs://maxdiffusion-github-runner-test-assets/datasets/dreambooth/instance_class"

0 commit comments

Comments
 (0)