|
| 1 | +import json |
| 2 | +import os |
| 3 | + |
| 4 | +from maxdiffusion import pyconfig |
| 5 | +from maxdiffusion.configuration_utils import ConfigMixin |
| 6 | +from maxdiffusion import __version__ |
| 7 | + |
| 8 | +class DummyConfigMixin(ConfigMixin): |
| 9 | + config_name = "config.json" |
| 10 | + |
| 11 | + def __init__(self, **kwargs): |
| 12 | + self.register_to_config(**kwargs) |
| 13 | + |
| 14 | +def test_to_json_string_with_config(): |
| 15 | + # Load the YAML config file |
| 16 | + config_path = os.path.join(os.path.dirname(__file__), "..", "configs", "base_wan_14b.yml") |
| 17 | + |
| 18 | + # Initialize pyconfig with the YAML config |
| 19 | + pyconfig.initialize([None, config_path], unittest=True) |
| 20 | + config = pyconfig.config |
| 21 | + |
| 22 | + # Create a DummyConfigMixin instance |
| 23 | + dummy_config = DummyConfigMixin(**config.get_keys()) |
| 24 | + |
| 25 | + # Get the JSON string |
| 26 | + json_string = dummy_config.to_json_string() |
| 27 | + |
| 28 | + # Parse the JSON string |
| 29 | + parsed_json = json.loads(json_string) |
| 30 | + |
| 31 | + # Assertions |
| 32 | + assert parsed_json["_class_name"] == "DummyConfigMixin" |
| 33 | + assert parsed_json["_diffusers_version"] == __version__ |
| 34 | + |
| 35 | + # Check a few values from the config |
| 36 | + assert parsed_json["run_name"] == config.run_name |
| 37 | + assert parsed_json["pretrained_model_name_or_path"] == config.pretrained_model_name_or_path |
| 38 | + assert parsed_json["flash_block_sizes"]["block_q"] == config.flash_block_sizes["block_q"] |
| 39 | + |
| 40 | + # The following keys are explicitly removed in to_json_string, so we assert they are not present |
| 41 | + assert "weights_dtype" not in parsed_json |
| 42 | + assert "precision" not in parsed_json |
0 commit comments