
possibly to avoid from_single_file loading in fp32 to save RAM #10679

Open
asomoza opened this issue Jan 29, 2025 · 6 comments

@asomoza
Member

asomoza commented Jan 29, 2025

Describe the bug

When loading a model using from_single_file(), RAM usage is really high, possibly because the weights are loaded in FP32 before being converted to the requested dtype.

Reproduction

import threading
import time

import psutil
import torch
from huggingface_hub import hf_hub_download

from diffusers import UNet2DConditionModel


filename = hf_hub_download("stable-diffusion-v1-5/stable-diffusion-v1-5", filename="v1-5-pruned-emaonly.safetensors")

stop_monitoring = False


def log_memory_usage():
    process = psutil.Process()
    mem_info = process.memory_info()
    return mem_info.rss / (1024**2)  # Convert to MB


def monitor_memory(interval, peak_memory):
    while not stop_monitoring:
        current_memory = log_memory_usage()
        peak_memory[0] = max(peak_memory[0], current_memory)
        time.sleep(interval)


def load_model(filename, dtype):
    global stop_monitoring

    peak_memory = [0]  # Use a list to store peak memory so it can be updated in the thread
    initial_memory = log_memory_usage()
    print(f"Initial memory usage: {initial_memory:.2f} MB")

    monitor_thread = threading.Thread(target=monitor_memory, args=(0.01, peak_memory))
    monitor_thread.start()

    start_time = time.time()
    UNet2DConditionModel.from_single_file(filename, torch_dtype=dtype)
    end_time = time.time()

    stop_monitoring = True
    monitor_thread.join()  # Wait for the monitoring thread to finish

    print(f"Peak memory usage: {peak_memory[0]:.2f} MB")
    print(f"Time taken: {end_time - start_time:.2f} seconds")
    final_memory = log_memory_usage()
    print(f"Final memory usage: {final_memory:.2f} MB")


load_model(filename, torch.float8_e4m3fn)

Logs

Initial memory usage: 737.19 MB
Peak memory usage: 4867.43 MB
Time taken: 0.92 seconds
Final memory usage: 1578.99 MB

System Info

not relevant here

Who can help?

@DN6

@asomoza asomoza added the bug Something isn't working label Jan 29, 2025
@yiyixuxu
Collaborator

@asomoza can you test under this PR? #10604

@asomoza
Member Author

asomoza commented Jan 29, 2025

@yiyixuxu I switched to Windows and a mobile 4090, since the original issue was reported on Windows.

Without the PR:

Initial memory usage: 548.29 MB
Peak memory usage: 4667.13 MB
Time taken: 1.49 seconds
Final memory usage: 1387.07 MB

With the PR:

Initial memory usage: 548.34 MB
Peak memory usage: 4668.90 MB
Time taken: 3.08 seconds
Final memory usage: 1388.89 MB

So RAM usage is almost the same and not related to this issue, but that PR doubles the model loading time, which is not good.

ccing @SunMarc just in case

@Nerogar
Contributor

Nerogar commented Jan 29, 2025

Adding a bit more context from the original discussion:
The overhead is mostly caused by this issue: huggingface/safetensors#542, but there are other problems that make it harder to reduce RAM usage when loading a model. (See the feature request below.)

The conversion itself uses a lot of RAM. Combined with the fact that there is no way (that I know of) to load the model in its original format from the file, there will be an overhead in most situations. See this example:

I'm using this file https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors, because it shows the issue a bit better. Instead of UNet2DConditionModel it needs to be loaded as FluxTransformer2DModel. The file mostly contains weights in bf16 format:

loading in fp8 format from a bf16 file: load_model(filename, torch.float8_e4m3fn)

Initial memory usage: 562.95 MB
Peak memory usage: 41511.62 MB
Time taken: 8.86 seconds
Final memory usage: 11774.39 MB

loading in fp32 format from a bf16 file: load_model(filename, torch.float32)

Initial memory usage: 563.26 MB
Peak memory usage: 75561.89 MB
Time taken: 10.72 seconds
Final memory usage: 45595.88 MB

loading in bf16 format from a bf16 file: load_model(filename, torch.bfloat16). No conversion happens in this case, which is good, and it drastically reduces the peak RAM usage.

Initial memory usage: 562.68 MB
Peak memory usage: 14339.55 MB
Time taken: 2.94 seconds
Final memory usage: 7459.44 MB (not sure if these numbers are correct. It probably didn't load all the tensors because there is no read access. The final usage should be closer to 24GB)

loading in an unspecified format from a bf16 file: load_model(filename, None). The model is then loaded in fp32 format, which triggers a conversion and again uses a lot of RAM. I would expect the model to be loaded with its original bf16 weights (a possible workaround is sketched after these numbers).

Initial memory usage: 562.57 MB
Peak memory usage: 75545.73 MB
Time taken: 10.43 seconds
Final memory usage: 45595.22 MB
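
Until the default behavior changes, one possible workaround is to read the checkpoint's native dtype up front and pass it explicitly as torch_dtype, so no conversion is triggered. A minimal sketch (not part of diffusers; it assumes all weights in the file share a single dtype):

import torch
from huggingface_hub import hf_hub_download
from safetensors import safe_open

from diffusers import FluxTransformer2DModel

filename = hf_hub_download("black-forest-labs/FLUX.1-dev", filename="flux1-dev.safetensors")


def native_dtype(path):
    # safe_open maps the file lazily; reading a single tensor only materializes
    # that one tensor, not the whole checkpoint.
    with safe_open(path, framework="pt", device="cpu") as f:
        first_key = next(iter(f.keys()))
        return f.get_tensor(first_key).dtype


dtype = native_dtype(filename)  # torch.bfloat16 for flux1-dev.safetensors
model = FluxTransformer2DModel.from_single_file(filename, torch_dtype=dtype)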

Feature request for a new parameter to remove any conversion overhead

Additionally, it would be great to have an option to not load the weights at all. This can be done by removing any read access to the tensors. The safetensors library already supports lazy tensor loading out of the box: only tensors with read access are actually loaded from the file. At the moment this read access is triggered by the .to() call that converts the weights. Having this option would make it possible to manually convert each tensor to a custom data type without any overhead (apart from the issue linked above).
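
For reference, a minimal sketch of what such a manual per-tensor conversion could look like with the current safetensors API (illustrative only; filename is the checkpoint path from the example above, and the resulting state dict would still need to be loaded into a model):

import torch
from safetensors import safe_open

target_dtype = torch.float8_e4m3fn
state_dict = {}

with safe_open(filename, framework="pt", device="cpu") as f:
    for key in f.keys():
        tensor = f.get_tensor(key)  # read access loads only this tensor
        state_dict[key] = tensor.to(target_dtype)
        del tensor  # drop the original-dtype copy right away

Peak RAM then stays around one tensor's worth of overhead instead of the whole state dict in fp32, apart from the safetensors copy issue linked above.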

@elismasilva
Contributor

I've tested loading this flux_dev model on my machine with 46GB of free memory. With from_single_file(), the state_dict is first loaded into memory in bfloat16 before converting, so my free memory dropped to 32GB. Then, in the step where it is converted to the diffusers format, my free memory dropped again to 20GB. Finally, when it is moved onto the meta device, only then is it converted to FP8, so my free memory drops to less than 4GB, and only afterwards is the memory freed again.

@al-swaiti

al-swaiti commented Feb 3, 2025

The RAM used depends on the size of the loaded file, so you could convert the model to a smaller size before loading it, or add a call to free the RAM after loading.
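
For what it's worth, a small sketch of freeing residual memory after loading (this does not help with the peak during conversion itself; filename is the same checkpoint path as in the reproduction above):

import gc

import torch

from diffusers import UNet2DConditionModel

model = UNet2DConditionModel.from_single_file(filename, torch_dtype=torch.float8_e4m3fn)

# Drop any leftover references and force a collection pass so objects that
# are no longer referenced are released promptly.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # only relevant once weights have been moved to the GPU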

@SunMarc
Member

SunMarc commented Feb 4, 2025

So RAM usage is almost the same and not related to this issue, but that PR doubles the model loading time, which is not good.

ccing @SunMarc just in case

It was due to a small mistake on my part! Sorry for that, I fixed it in the latest commit. Also, the PR should only speed up loading for diffusion models for now.
