
Loading bigger models is very slow using AutoModelForCausalLM.from_pretrained #562

Open
vibhas-singh opened this issue Nov 19, 2024 · 15 comments
Labels
bug Something isn't working

Comments

@vibhas-singh

vibhas-singh commented Nov 19, 2024

System Info

  • transformers version: 4.45.0
  • Platform: Linux-5.10.227-219.884.amzn2.x86_64-x86_64-with-glibc2.26
  • Python version: 3.10.14
  • Huggingface_hub version: 0.26.2
  • Safetensors version: 0.4.5
  • Accelerate version: 0.34.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.2 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: Yes
  • Using GPU in script?: Yes
  • GPU type: NVIDIA A10G

Who can help?

@ArthurZucker @SunMarc

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I am spawning a g5.12xlarge GPU machine on AWS SageMaker and loading a locally saved model using this script:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

import torch
from transformers import AutoModelForCausalLM, AutoProcessor

model_id_or_path = "<local_path>"

model = AutoModelForCausalLM.from_pretrained(model_id_or_path, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)

This problem occurs with almost all the models I am trying - rhymes-ai/Aria can be used to reproduce it.

Expected behavior

The last line takes forever to load the model (>40-50 mins). I have observed the same behaviour for multiple other models as well.

Things I have tried/observed:

  • The behaviour is observed on the first load after an instance restart. Once I have loaded the model (by waiting 40-50 mins) and then restart the notebook kernel, all subsequent model loads are very fast (almost instant).
  • However, if I restart the instance, the problem appears again on the first load.
  • I suspected it was taking time to figure out which layer to put on which GPU, as I am using a cluster of 4 GPUs. To rule this out, I saved the device_map of an already loaded model and passed it to from_pretrained as device_map instead of using "auto", but it did not solve the issue.
  • I also suspected slow memory read/write speeds, so I benchmarked that by loading the model on CPU - it loaded almost instantly, so memory and I/O are not the bottleneck.
  • The shards follow the same behaviour. For example, if I load a model with 10 shards and restart the notebook after the first 4 shards are loaded, loading the model again takes much less time for those first 4 shards. I have verified that before the second load the GPU memory usage is 0, so no leftover layers remain from the first load.
  • The time taken to load the shards is also not uniform - some shards take 3 minutes, others more than 10-15 mins (see the timing sketch below).
  • The model is already saved in BF16 format, so typecasting also doesn't seem to be an issue here.
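
For reference, the non-uniform shard times can be measured directly with safetensors. A minimal sketch (the directory path is the same <local_path> placeholder as in the script above):

import glob
import time

from safetensors.torch import load_file

shard_files = sorted(glob.glob("<local_path>/*.safetensors"))

for shard in shard_files:
    start = time.perf_counter()
    state_dict = load_file(shard)  # loads one shard into CPU memory
    elapsed = time.perf_counter() - start
    print(f"{shard}: {len(state_dict)} tensors in {elapsed:.1f}s")
    del state_dict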
@vibhas-singh vibhas-singh added the bug Something isn't working label Nov 19, 2024
@Rocketknight1
Member

This feels like an accelerate issue, so pinging @SunMarc and @muellerzr once again, but yell if I should ping someone else!

@SunMarc
Member

SunMarc commented Nov 21, 2024

It shouldn't take that much time to load these models. Can you try to reproduce this on Colab with a smaller model, or is it only happening on your machine? From what you said, loading a 7B model might take you 10 min, which is way too long when using device_map.

@vibhas-singh
Author

vibhas-singh commented Nov 22, 2024

I have benchmarked the behaviour in Colab using meta-llama/Llama-3.2-3B.
(I did not use a 7B model because some of its layers were being offloaded to CPU, which would not give an accurate picture.)

I am first downloading the model to my local drive using huggingface-cli

!huggingface-cli download "meta-llama/Llama-3.2-3B" --local-dir drive/MyDrive/models/llama-3b

Then I am loading the model using two different methods:

On CPU Machine

[screenshots: load timing on the CPU runtime]

On GPU Machine

[screenshots: load timing on the GPU runtime]

When I load the model immediately after downloading it from the HF Hub to disk, the load is instant.
When I load the same model from the same disk after restarting the instance, it takes a lot of time.

This is creating issues as the model will always be saved locally in prod-like settings.
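
For reference, a sketch of the timing cell (the path matches the download command above; the CPU runtime simply omits device_map):

import time

import torch
from transformers import AutoModelForCausalLM

model_path = "drive/MyDrive/models/llama-3b"

start = time.perf_counter()
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, device_map="auto"
)
print(f"Load time: {time.perf_counter() - start:.1f}s")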

@ArthurZucker

cc @Wauplin 🤗

@muellerzr

muellerzr commented Nov 25, 2024

FYI this is not a transformers issue but is in fact a safetensors issue (somehow). Take a look:

(I'm as confused as you are)
[screenshot: timing output]

@Wauplin
Contributor

Wauplin commented Nov 25, 2024

cc @Narsil

@vibhas-singh
Author

vibhas-singh commented Dec 1, 2024

Hi @Narsil, @Wauplin & @ArthurZucker
Trying to gently highlight this issue again - apologies for the spam.

Please let me know if any fixes are possible for this.

I suspect the same thing is happening when I load the model with vLLM as well - so there might be a broader issue.

@vibhas-singh
Author

Hi Team, 🙋🏻

Circling back on this issue again - please let me know of any potential fixes or any direction to explore further.

@richardm1

Circling back on this issue again - please let me know of any potential fixes or any direction to explore further.

Have you seen this?

@ArthurZucker ArthurZucker transferred this issue from huggingface/transformers Jan 16, 2025
@vibhas-singh
Author

The problem goes away when I convert the weights to .pth, so it's indeed a safetensors issue.
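
For anyone who wants to try the same workaround, a minimal sketch of the conversion (paths are placeholders; how the resulting .pth is wired back into your loading code depends on your setup):

import glob

import torch
from safetensors.torch import load_file

# merge all safetensors shards into a single state dict
state_dict = {}
for shard in sorted(glob.glob("<local_path>/*.safetensors")):
    state_dict.update(load_file(shard))

# save as a regular PyTorch checkpoint
torch.save(state_dict, "<local_path>/model.pth")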

@youkaichao

The problem goes away when I convert the weights to .pth, so it's indeed a safetensors issue.

I thought safetensors was strictly better than .pth. It's surprising that safetensors can perform worse.

@Narsil
Collaborator

Narsil commented Feb 4, 2025

Problem

The issue for everyone here seems to be network-mounted disks.

A lot of providers (AWS, Colab, etc.) actually use network disks mounted as if they were local, and these have very high latency.
The problem with a network-mounted disk is that the OS doesn't know it is a network mount, so it issues reads as if they were optimal for a local disk. For a modern NVMe that means reading page by page (4k by default when hugepages are not activated).

The network is extremely slow, so issuing many small reads is about the slowest possible access pattern. There are ways to mitigate this when you are in charge of the mounting (essentially by forcing buffered reads).

The reason this doesn't occur right after you download the model is that the file is still kept in RAM by the OS page cache (the OS does pretty much everything possible to avoid reading the disk, since that is the slowest operation it can do after the network).
When the file is already in RAM, no network reads are issued, and because the file is memory mapped, the copies are extremely fast.

The issue also exists on Windows WSL, where the filesystem (or the memory mapping, I'm not sure which) is also extremely slow.

Solution

1/ It is very easy to solve this issue: just replace load_file with load(open(...).read()):

from safetensors.torch import load

# open in binary *read* mode and pull the whole file into RAM in one read
with open(filename, "rb") as f:
    tensors = load(f.read())

What this does is force a single read for the entire file: read everything into RAM first, then move it wherever you want. This is what loading PTH files does by default and why they are "faster" in this use case (the reason we don't want to do this by default is explained in the note below).

2/ You could even just do a preliminary open(...).read() and then load normally. The only purpose of that read is to make the network reads efficient and warm the page cache. Once the file is in RAM, everything will be memory mapped normally (provided the host has enough RAM to fit the model). It will still be slower than a good local disk, but not as slow as what you're currently experiencing.
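
A minimal sketch of that warm-up read for a sharded transformers checkpoint (the glob pattern is just an illustration):

import glob

from transformers import AutoModelForCausalLM

model_path = "<local_path>"

# Read every shard once so the OS page cache holds the files;
# the subsequent memory-mapped accesses then hit RAM instead of the network mount.
for shard in glob.glob(f"{model_path}/*.safetensors"):
    with open(shard, "rb") as f:
        f.read()

model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")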

3/ Another option, which I also recommend, is to save the models on the local NVMe drives that AWS provides with its larger machines. Those are actually local and very fast; even if you do have to read from disk, the loads will be quick.
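
For example, downloading straight to an instance-store NVMe mount and loading from there (the mount point /opt/dlami/nvme is just an assumption; use whatever local NVMe path your instance exposes):

from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM

# local_dir is assumed to live on the instance-store NVMe, not on a network mount
local_path = snapshot_download(
    "rhymes-ai/Aria", local_dir="/opt/dlami/nvme/models/aria"
)

model = AutoModelForCausalLM.from_pretrained(
    local_path, device_map="auto", trust_remote_code=True
)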

Note

Now, there's an issue with always loading the file directly into RAM. Most large models (30B+) are usually too large to fit on a single GPU (and even when they do fit, it's usually cheaper to run them on multiple smaller GPUs).
When you shard the model that way, reading the file directly into memory means every GPU process needs to load the full model into CPU RAM. That means loading 8x the model into RAM, which is extremely expensive.
When safetensors was created, loading a 170B model went from 10 min to 1 min. Loading DeepSeek-R1 takes 70s from an AWS NVMe.

Those speeds are only possible because we do not load everything at once, and load only what we need.
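
To illustrate that lazy-loading pattern with the public safetensors API (shard and tensor names are placeholders): each rank opens the file and materializes only the tensors it actually owns.

from safetensors import safe_open

# tensors owned by this rank (placeholder name)
needed = {"model.layers.0.self_attn.q_proj.weight"}

tensors = {}
with safe_open("<shard>.safetensors", framework="pt", device="cuda:0") as f:
    for name in f.keys():
        if name in needed:
            # only the bytes for this tensor are read and moved to the GPU
            tensors[name] = f.get_tensor(name)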

The problems here are the direct result of various parts of the system masquerading as something else.
A network-mounted disk masquerades as a local filesystem; it's convenient, but it makes it impossible for the host OS to make the right decisions.
WSL masquerades as Linux (ext4) when it's in fact Windows (NTFS).
Masquerading here just means that safetensors has no real way of knowing what kind of drive/OS setup it is actually running on. Since it's impossible to know, it's very hard to make the right choice on behalf of users.

@Narsil
Collaborator

Narsil commented Feb 4, 2025

An example of a downstream fix: huggingface/diffusers#10305

@richardm1

The issue for everyone here seems to be network-mounted disks.

On Windows (at least) the problem exists when reading from anything that isn't a locally-attached NVMe drive. Even SATA SSDs with a typical throughput of ~520MB/s will read safetensors at perhaps 160-220MB/s. That is the best case, and it gets worse from there.

Last summer I grabbed a Procmon64 capture of storage activity during a safetensors load and it wasn't pretty. NVMe is the ideal medium for dealing with this choppy, seeky, noncontiguous, and arguably pathological small-block I/O pattern; with everything else, the micro- (or milli-) seconds of extra latency on each I/O really add up.

Spinning hard disk is the worst case, as what should be a long, steady stream of large sequential I/Os (which hard drives handle quite well) is instead an I/O blender that confounds read-ahead caching and wastes time on excessive seek activity.

I maintain that good development practice will accommodate HDDs, along with any reasonable amount of latency to and from secondary storage. It's clear that many devs are blessed with entirely local low-latency storage, given the plethora of arm-twisting, gaslighting, and outright dismissal that surrounded this issue throughout 2024.
