
Model init with HuggingFace model #743


Open
neeldani opened this issue Dec 16, 2024 · 17 comments
Labels
bug Something isn't working huggingface integration module: checkpoint question Further information is requested

Comments

@neeldani

neeldani commented Dec 16, 2024

I am writing a simple script to run FSDP2 (fully_shard) on the pythia-1b model available on HuggingFace. I am currently running the model on 1 node with 2 devices, following the meta-device initialisation from the FSDP2 docs. However, I think something is wrong with my implementation, since the peak memory usage with FSDP is the same as without FSDP (~1 GB). Further, I get an OOM on my device when I try the pythia-2.8b model. Here is a snippet showing how I am initialising the model on a meta device using HuggingFace APIs:

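from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from torch.distributed._composable.fsdp import fully_shard
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXLayer

model_name = "EleutherAI/pythia-14m"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
config = AutoConfig.from_pretrained(model_name)

# Build the model on the meta device (no weight storage is allocated).
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# Shard each transformer block, then the root module.
for module in model.modules():
    if isinstance(module, GPTNeoXLayer):
        fully_shard(module)

model = fully_shard(model, reshard_after_forward=True)

# This call fails: the sharded parameters are DTensors.
model = load_checkpoint_and_dispatch(
    model, path_to_safe_tensors
)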

This is not very straightforward, since the shards expect DTensors when the weights are loaded via load_checkpoint_and_dispatch. I am looking for suggestions on a good way to make FSDP2 work with HuggingFace models. I don't think accelerate supports FSDP2 yet.
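One part that can be made model-agnostic is the choice of submodules to pass to fully_shard. The sketch below selects submodules by class name. The toy DecoderLayer and VisionBlock classes are hypothetical stand-ins for HF blocks. With a real HF model, the list of names could come from the model's `_no_split_modules` attribute, which HF models use to tell accelerate which blocks must not be split across devices. Its contents vary by model and transformers version, so print it before relying on it.

```python
import torch.nn as nn


def modules_to_shard(model: nn.Module, class_names) -> list:
    """Submodules whose class name is in `class_names`, in model.modules() order."""
    wanted = set(class_names)
    return [m for m in model.modules() if type(m).__name__ in wanted]


# Hypothetical toy classes standing in for HF transformer blocks.
class DecoderLayer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.mlp = nn.Linear(4, 4)


class VisionBlock(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(4, 4)


toy = nn.ModuleDict({
    "vision": nn.ModuleList([VisionBlock(), VisionBlock()]),
    "decoder": nn.ModuleList([DecoderLayer() for _ in range(3)]),
})
# With a real HF model: getattr(model, "_no_split_modules", None) or []
blocks = modules_to_shard(toy, ["DecoderLayer", "VisionBlock"])
print(len(blocks))
```

In a torchrun script, each returned module would be passed to fully_shard, followed by fully_shard on the root module, as in the snippet above.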

@awgu
Collaborator

awgu commented Dec 16, 2024

cc: @weifengpy @mori360

@neeldani
Author

👋 Gentle bump on this - mainly to see if there is some workaround for the above issue 👀

@mori360
Contributor

mori360 commented Dec 20, 2024

However, I think there is something wrong with my implementation since the peak memory usage with FSDP is same as without FSDP (~ 1GB).

It depends on where the peak memory occurs. If it is during fully_shard, the full_state_dict is sharded into a local_state_dict while the full copy is still alive, which increases memory (full_state_dict + local_state_dict > full_state_dict).
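Here is a rough way to size that transient peak. The sketch assumes every rank holds a full copy of the weights and keeps it alive until its local shard exists, so both coexist at the peak. The parameter count in the usage line is a round number chosen for illustration, not the size of a particular pythia checkpoint.

```python
def load_peak_bytes(num_params: int, bytes_per_param: int, world_size: int) -> dict:
    """Rough per-rank memory when every rank holds a full copy and then shards it.

    Assumes the full tensors stay alive until the local shards have been
    created, so both coexist at the peak.
    """
    full = num_params * bytes_per_param
    local = -(-full // world_size)  # ceiling division approximates the largest shard
    return {"full": full, "local": local, "peak": full + local}


est = load_peak_bytes(num_params=1_000_000_000, bytes_per_param=4, world_size=2)
print({name: value / 1024**3 for name, value in est.items()})
```

With 2 ranks, the peak is about 1.5x the full fp32 model, which is higher than the unsharded footprint.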

I get an OOM on my device when I try with pythia-2.8b model

Could you give more details on the safe_tensors, so that I can reproduce the high memory cost?
Also, could you describe the device flow, so that I can follow up when you switch your devices to GPU?

@neeldani
Author

neeldani commented Dec 23, 2024

It depends on where the peak memory occurs. If it is during fully_shard, the full_state_dict is sharded into a local_state_dict while the full copy is still alive, which increases memory (full_state_dict + local_state_dict > full_state_dict).

I see. Ideally I am looking for an approach which allows me to load the sharded models on each GPU without loading the full_state_dict

Could you give more details on the safe_tensors, so that I can reproduce the high memory cost?

I downloaded the model.safetensors for the pythia-1b model from here. These weights are not sharded.
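A .safetensors file begins with an 8-byte little-endian header length. That is followed by a JSON header mapping each tensor name to its dtype, shape, and data_offsets, plus an optional `__metadata__` entry. The stdlib-only sketch below reads just that header. It reports how many tensors, elements, and bytes a single unsharded model.safetensors holds, without loading any weights. The helper names are illustrative.

```python
import json
import struct
from math import prod


def read_safetensors_header(path: str) -> dict:
    """Parse only the JSON header of a .safetensors file (no tensor data)."""
    with open(path, "rb") as f:
        # First 8 bytes: header length as a little-endian unsigned 64-bit integer.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional string-to-string metadata
    return header


def summarize_safetensors(path: str):
    """Return (number of tensors, total elements, total data bytes)."""
    header = read_safetensors_header(path)
    num_elements = 0
    num_bytes = 0
    for info in header.values():
        num_elements += prod(info["shape"])  # prod([]) == 1 for scalars
        start, end = info["data_offsets"]
        num_bytes += end - start
    return len(header), num_elements, num_bytes
```

For example, `summarize_safetensors("model.safetensors")` gives the numbers needed to compare against the ~1 GB peak reported above.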

Also, could you describe the device flow, so that I can follow up when you switch your devices to GPU?

I am trying to mimic TorchTitan's implementation, but with a HuggingFace model:

  1. Load the empty model on the meta device
  2. Apply FSDP, move the sharded weights to their respective GPUs, and materialise them
  3. Re-initialise the sharded weights on each GPU
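Here is a single-process sketch of that three-step flow on a toy model. TinyBlock is hypothetical. The fully_shard step is omitted because it needs a process group, and a comment marks where it would go. `reset_parameters()` stands in for TorchTitan's own weight initialisation. For a pretrained HF model, step 3 would load checkpoint weights instead.

```python
import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    """Hypothetical stand-in for a transformer layer such as GPTNeoXLayer."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))


def build_and_materialize(dim: int = 16, num_layers: int = 2) -> nn.Module:
    # Step 1: build on the meta device; parameters get shapes/dtypes but no storage.
    with torch.device("meta"):
        model = nn.Sequential(*[TinyBlock(dim) for _ in range(num_layers)])

    # Step 2 (distributed runs only): fully_shard(block) for each block, then
    # fully_shard(model), would go here. Omitted so this runs in one process.

    # Allocate real but uninitialized storage on the target device.
    model.to_empty(device="cpu")

    # Step 3: initialize values. reset_parameters() stands in for a model-specific
    # init; a pretrained model would load checkpoint weights here instead.
    for module in model.modules():
        if isinstance(module, nn.Linear):
            module.reset_parameters()
    return model


model = build_and_materialize()
print(all(not p.is_meta for p in model.parameters()))
```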

This is a simple repro of my implementation which can be run using:

torchrun --nnodes=1 --nproc_per_node=2 reproduce.py

import os

import torch
from torch.distributed import init_process_group, destroy_process_group
from torch.distributed._composable.fsdp import fully_shard
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXLayer
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
    num_params = sum(p.numel() for p in model.parameters())
    if exclude_embedding:
        num_params -= model.tok_embeddings.weight.numel()
    return num_params

def setup(local_rank, world_size):
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    init_process_group("nccl", rank=local_rank, world_size=world_size)

def load():
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    setup(local_rank, world_size)

    model_name = "EleutherAI/pythia-2.8b"
    config = AutoConfig.from_pretrained(model_name)
    
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)
    
    if local_rank == 0:
        print("Load models with empty weights")
        print("Device: ", model.device)
        print("Params: ", get_num_params(model))
        print("Peak mem: ", torch.cuda.max_memory_allocated() / (1024 ** 3))

    for module in model.modules():
        if isinstance(module, GPTNeoXLayer):
            fully_shard(module)
    
    model = fully_shard(model, reshard_after_forward=True)

    if local_rank == 0:
        print("Applied FSDP to the model")
        print("Device: ", model.device)
        print("Peak mem: ", torch.cuda.max_memory_allocated() / (1024 ** 3))
        print("# of params: ", get_num_params(model))

    model.to_empty(device='cuda')

    if local_rank == 0:
        print("Materialized the sharded tensors")
        print("Device: ", model.device)
        print("Peak mem: ", torch.cuda.max_memory_allocated() / (1024 ** 3))
        print("# of params: ", get_num_params(model))

    model = load_checkpoint_and_dispatch(model, "model.safetensors", device_map="auto", no_split_module_classes="GPTNeoXLayer")

if __name__ == "__main__":
    load()

The flow is very similar to TorchTitan's, except that TorchTitan makes an explicit call to re-initialise the weights after materialising them. Since I want to load weights from a pretrained HF model, this is a bit challenging. The above code throws an error where I call load_checkpoint_and_dispatch, since the model expects DTensors as inputs.

@mori360
Contributor

mori360 commented Dec 27, 2024

Ideally I am looking for an approach which allows me to load the sharded models on each GPU without loading the full_state_dict

torch.distributed.checkpoint.state_dict.set_model_state_dict can load the sharded model without materializing the full_state_dict all at once. It loads parameter by parameter, which avoids the memory peak and helps avoid the OOM.

However, accelerate.load_checkpoint_and_dispatch does not support sharded models right now. It has no branch for param_cls.__name__ in ["DTensor"] that would distribute the loaded tensor.
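For a model sharded with fully_shard, each rank must end up with only its slice of every full tensor from the checkpoint. Assuming FSDP2's default dim-0 sharding (Shard(0)), that slice follows torch.chunk semantics. The sketch below shows only this arithmetic, in a single process. In a real run, one would pass the full tensor to `torch.distributed.tensor.distribute_tensor` (under `torch.distributed._tensor` in older releases) with the parameter's device mesh and placements, or use set_model_state_dict as suggested above, rather than slicing by hand. FSDP2's internal padding of uneven shards is not modeled.

```python
import torch


def local_dim0_shard(full: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Rows of `full` that `rank` would own under dim-0 chunking.

    Mirrors torch.chunk semantics, which DTensor's Shard(0) placement uses; with
    uneven sizes the trailing ranks get fewer rows (possibly zero).
    """
    chunks = torch.chunk(full, world_size, dim=0)
    if rank < len(chunks):
        return chunks[rank]
    # torch.chunk can return fewer than world_size chunks; those ranks hold nothing.
    return full.new_empty((0, *full.shape[1:]))


full_weight = torch.arange(10 * 3, dtype=torch.float32).reshape(10, 3)
shards = [local_dim0_shard(full_weight, r, 4) for r in range(4)]
print([s.shape[0] for s in shards])
```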

@fegin Please correct me if I'm wrong. Also, should we update the torchtitan flow from model.init_weight() to checkpoint.load() so that weights are initialized param by param?

@tianyu-l tianyu-l added question Further information is requested bug Something isn't working labels Jan 7, 2025
@fegin
Contributor

fegin commented Jan 8, 2025

Yes, @mori360. Since you implemented this feature, the OOM should be avoidable with set_model_state_dict. But the state_dict will need to be loaded with DCP and set_model_state_dict.

@Hannibal046

Hi, any progress here? What is the best practice to continue pretraining a HF model with torchtitan?

@yzhangcs
Contributor

yzhangcs commented Feb 18, 2025

@neeldani Regarding your original issue, for now the easiest approach would be to:

  1. Convert your HF model weights to the DCP format and save them in <path>/checkpoint/step-0. You can follow the instructions in this guide: How to Convert a LLaMA 3 Checkpoint for Use in TorchTitan. Replace <path> with your desired save location.
  2. Once the weights are converted, you can resume training directly by setting --training.load_step 0, similar to how you would with a seed checkpoint.

Does this make sense? @mori360 @fegin @tianyu-l @huyiwen, please correct me if I missed anything.

@tianyu-l
Contributor

tianyu-l commented Feb 19, 2025

Thanks @yzhangcs

What is the best practice to continue pretraining a HF model with torchtitan?

I think the key thing to do is to convert a HF checkpoint into a DCP checkpoint, like what this script does #305 (comment)

I heard that DCP is going to support the HF checkpoint format, but it may take some time.
related PR for the non-distributed use case: pytorch/pytorch#146352
cc: @fegin @kwen2501 to confirm

@yzhangcs
Contributor

@tianyu-l I just wrote one for medium/small-sized models, https://github.com/fla-org/flame/blob/main/convert_hf_to_dcp.py, similar to https://github.com/pytorch/torchtitan/blob/main/scripts/convert_llama_to_dcp.py.
I'm using the converted DCP checkpoints to fine-tune the Qwen model on fineweb-edu, and everything appears to be working as expected so far.


@mingdianliu

mingdianliu commented Apr 5, 2025

@neeldani @fegin @yzhangcs @awgu @Hannibal046 @tianyu-l @mori360

Dear All,

Thanks for making FSDP2 compatible with HuggingFace models. However, I ran into an issue while running the repro code. Do you have any insights into this issue?

import os

import torch
from torch.distributed import init_process_group, destroy_process_group
from torch.distributed._composable.fsdp import fully_shard
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXLayer
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
    num_params = sum(p.numel() for p in model.parameters())
    if exclude_embedding:
        num_params -= model.tok_embeddings.weight.numel()
    return num_params

def setup(local_rank, world_size):
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    init_process_group("nccl", rank=local_rank, world_size=world_size)

def load():
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    setup(local_rank, world_size)

    model_name = "EleutherAI/pythia-2.8b"
    config = AutoConfig.from_pretrained(model_name)
    
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)
    
    for module in model.modules():
        if isinstance(module, GPTNeoXLayer):
            fully_shard(module)
    
    model = fully_shard(model, reshard_after_forward=True)
    model.to_empty(device='cuda')


if __name__ == "__main__":
    load()

The error is below:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/NCCL/report_issue.py", line 41, in <module>
[rank0]:     load()
[rank0]:   File "/workspace/NCCL/report_issue.py", line 34, in load
[rank0]:     fully_shard(module)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/contract.py", line 107, in wrapper
[rank0]:     updated = func(module, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/fully_shard.py", line 114, in fully_shard
[rank0]:     _move_states_to_device(params, buffers, device, mesh_info)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/_fsdp_init.py", line 143, in _move_states_to_device
[rank0]:     tensor.data = tensor.to(device)
[rank0]: NotImplementedError: Cannot copy out of meta tensor; no data!

Python command:
torchrun --nnodes=1 --nproc_per_node=8 reproduce.py
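The last frame shows fully_shard calling `tensor.to(device)` on a parameter that still lives on the meta device. The same NotImplementedError can be reproduced without FSDP. It also shows why `nn.Module.to_empty()`, which the script calls after fully_shard, is the way to give meta parameters real storage: it allocates fresh storage via `torch.empty_like(t, device=...)` instead of copying. `try_copy` is an illustrative helper.

```python
import torch

# A meta tensor records shape and dtype but owns no storage.
meta_weight = torch.empty(4, 4, device="meta")


def try_copy(t: torch.Tensor, device: str = "cpu"):
    """Return (succeeded, exception type name) for copying `t` to `device`."""
    try:
        t.to(device)
        return True, None
    except NotImplementedError as exc:  # "Cannot copy out of meta tensor; no data!"
        return False, type(exc).__name__


print(try_copy(meta_weight))

# Module.to_empty() does the equivalent of this for every parameter/buffer:
# allocate fresh, uninitialized storage with the same shape and dtype.
real_weight = torch.empty_like(meta_weight, device="cpu")
print(real_weight.shape, real_weight.device)
```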

@awgu
Collaborator

awgu commented Apr 5, 2025

@mingdianliu Which version of PyTorch are you using? Maybe you need a newer version.

@mingdianliu

mingdianliu commented Apr 5, 2025

@mingdianliu Which version of PyTorch are you using? Maybe you need a newer version.

@awgu Thank you very much! After upgrading pytorch to 2.6.0, the code is working on my side. I have one more follow-up question.

I followed your instructions to convert the HF checkpoint to a DCP checkpoint. However, loading the DCP checkpoint with torch.distributed.checkpoint.load(state_dict, checkpoint_id=None, storage_reader=None) takes a long time: 540 seconds for the Qwen2-VL-7B model on 2 nodes with 16 GPUs. Is there a better method I can use to speed up checkpoint loading?

In the code, I am using model.load_state_dict() to load the state_dict(), which has latency comparable to set_model_state_dict().

import os

import torch
from torch.distributed import init_process_group, destroy_process_group
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.device_mesh import init_device_mesh

from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq
from transformers import (
    AutoConfig,
    Qwen2VLForConditionalGeneration,
)
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLDecoderLayer, Qwen2VLVisionBlock


def load():
    
    distributed_backend = "nccl" # gloo for cpu
    init_process_group(distributed_backend)

    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    model_name = "Qwen/Qwen2-VL-2B-Instruct"
    revision = "895c3a49bc3fa70a340399125c650a463535e71c"
    # model_name = "Qwen/Qwen2-VL-7B-Instruct"
    # revision = "a28a094eb66a9f2ac70eef346f040d8a79977472"
    # model_name = "Qwen/Qwen2-VL-72B-Instruct"
    # revision = "f9b556a74d58e6d9915f73227c21045c87342b42"

    config = AutoConfig.from_pretrained(
        model_name, 
        revision=revision, 
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2")
    
    device_mesh = init_device_mesh("cuda", (world_size,))

    with torch.device("meta"):
        model = AutoModelForVision2Seq.from_config(config)

    for module in model.modules():
        if isinstance(module, Qwen2VLDecoderLayer):
            fully_shard(module, mesh=device_mesh, reshard_after_forward=True)
    
    model = fully_shard(model, mesh=device_mesh, reshard_after_forward=True)

    model.to_empty(device='cuda')

    model_state_dict = model.state_dict()
    model_dir = "path_to_DCP_ckpt_dir/2B"

    print("start torch.distributed.checkpoint.load")
    fs_storage_reader = FileSystemReader(model_dir)
    torch.distributed.checkpoint.load(
        state_dict=model_state_dict,
        storage_reader=fs_storage_reader,
        )

    model.load_state_dict(model_state_dict)

    print("Model loaded")

if __name__ == "__main__":
    load()
    destroy_process_group()

Python command:
torchrun --nnodes=2 --nproc_per_node=8 reproduce.py

I also tried model = load_checkpoint_and_dispatch(model, checkpoint=model_dir, device_map="auto", no_split_module_classes=["Qwen2VLDecoderLayer"], dtype=torch.bfloat16,), but I ran into the following error:

[rank7]: Traceback (most recent call last):                                                                                                                                                       
[rank7]:   File "/workspace/NCCL/reproduce.py", line 204, in <module>
[rank7]:     load()
[rank7]:   File "/workspace/NCCL/reproduce.py", line 159, in load
[rank7]:     model = load_checkpoint_and_dispatch(
[rank7]:   File "/usr/local/lib/python3.10/dist-packages/accelerate/big_modeling.py", line 620, in load_checkpoint_and_dispatch
[rank7]:     load_checkpoint_in_model(
[rank7]:   File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/modeling.py", line 1982, in load_checkpoint_in_model
[rank7]:     set_module_tensor_to_device(
[rank7]:   File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/modeling.py", line 377, in set_module_tensor_to_device
[rank7]:     new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device)
[rank7]:   File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 32, in inner
[rank7]:     return disable_fn(*args, **kwargs)
[rank7]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
[rank7]:     return fn(*args, **kwargs)
[rank7]: TypeError: DTensor.__new__() missing 1 required positional argument: 'spec'

@mingdianliu

Dear community,

Thanks for your replies. This issue has been resolved. Loading was slow because I had saved the DCP checkpoint on a low-performance disk. After switching to a faster disk, the 72B model loads, although it still takes a while. I will try to optimize the loading time and will post any progress here.
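To check whether storage is the bottleneck before tuning anything else, compare raw sequential read speed against the checkpoint size and the observed load time. The stdlib-only sketch below does that. The helper names are illustrative. A second run may be inflated by the OS page cache, so measure on a cold cache or a file set larger than RAM.

```python
import os
import time


def directory_bytes(path: str) -> int:
    """Total size of all files under `path` (a DCP checkpoint is a directory)."""
    return sum(
        os.path.getsize(os.path.join(root, name))
        for root, _dirs, files in os.walk(path)
        for name in files
    )


def read_throughput_gib_s(path: str, chunk_size: int = 64 * 1024 * 1024) -> float:
    """Sequentially read every file under `path`; return GiB per second."""
    total_read = 0
    start = time.perf_counter()
    for root, _dirs, files in os.walk(path):
        for name in files:
            with open(os.path.join(root, name), "rb") as f:
                while chunk := f.read(chunk_size):
                    total_read += len(chunk)
    elapsed = max(time.perf_counter() - start, 1e-9)
    return total_read / elapsed / 1024**3
```

Dividing `directory_bytes(ckpt_dir) / 1024**3` by the measured throughput gives a lower bound on load time from that disk. If this bound is close to the observed load time, the disk is the limit.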

@fegin
Contributor

fegin commented Apr 15, 2025

@mingdianliu We are exploring an offline resharding converter to speed up the loading time, #1104.

@mingdianliu

mingdianliu commented Apr 21, 2025

@neeldani @fegin @yzhangcs @awgu @Hannibal046 @tianyu-l @mori360
Dear everyone,

I have one more issue while fine-tuning the Qwen2-VL model with fully_shard(). GPU memory usage stays high (around 50 GB to 60 GB) even as I scale up the number of GPUs. In addition, I run out of memory when I try to fine-tune the 72B model with 128 GPUs.

I'm wondering if there might be any issues with my code or configuration. I'd really appreciate any insights or suggestions you might have. Thanks in advance!

My code:

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor, AutoModelForVision2Seq, AutoConfig
from qwen_vl_utils import process_vision_info
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
import numpy as np
from PIL import Image
import io
import logging
import os

from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.device_mesh import init_device_mesh
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLDecoderLayer, Qwen2VLVisionBlock
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed import init_process_group, destroy_process_group
from torch.distributed.checkpoint import DefaultLoadPlanner, DefaultSavePlanner
from torch.distributed._composable.fsdp import (
    CPUOffloadPolicy,
    fully_shard,
    MixedPrecisionPolicy,
)


# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# init dist
distributed_backend = "nccl" # gloo for cpu
dist.init_process_group(distributed_backend)

local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)


# model_name = "Qwen/Qwen2-VL-2B-Instruct"
# revision = "895c3a49bc3fa70a340399125c650a463535e71c"
model_name = "Qwen/Qwen2-VL-7B-Instruct"
revision = "a28a094eb66a9f2ac70eef346f040d8a79977472"
# model_name = "Qwen/Qwen2-VL-72B-Instruct"
# revision = "f9b556a74d58e6d9915f73227c21045c87342b42"

dataset_id = "HuggingFaceM4/ChartQA"
processor = Qwen2VLProcessor.from_pretrained(model_name, 
                                             revision=revision,
                                             )


# Configuration
class Config:
    dataset_id = "HuggingFaceM4/ChartQA"
    output_dir = "/tmp_ckpt"
    batch_size = 2
    num_epochs = 3
    learning_rate = 5e-5
    max_seq_length = 512
    lora_rank = 32
    lora_alpha = 64
    lora_dropout = 0.1
    device = "cuda" if torch.cuda.is_available() else "cpu"




system_message = """You are a Vision Language Model specialized in interpreting visual data from chart images.
Your task is to analyze the provided chart image and respond to queries with concise answers, usually a single word, number, or short phrase.
The charts include a variety of types (e.g., line charts, bar charts) and contain colors, labels, and text.
Focus on delivering accurate, succinct answers based on the visual information. Avoid additional explanation unless absolutely necessary."""

def format_data(sample):
    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_message}],
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": sample["image"],
                },
                {
                    "type": "text",
                    "text": sample["query"],
                },
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": sample["label"][0]}],
        },
    ]

# Training function
def train_model(model, train_loader, optimizer, config):
    model.train()
    total_steps = len(train_loader) * config.num_epochs
    step = 0

    scaler = torch.amp.GradScaler("cuda", enabled=True)

    for epoch in range(config.num_epochs):
        total_loss = 0
        for batch_idx, batch in enumerate(train_loader):

            inputs, labels = batch
            inputs = inputs.to(config.device)
            labels = labels.to(config.device)

            # Mixed precision training
            loss = model(**inputs, labels=labels).loss
            loss.backward() # no scaler
            optimizer.step()
            optimizer.zero_grad()
            
            step += 1
            logger.info(f"Epoch {epoch+1}/{config.num_epochs}, Step {step}/{total_steps}, Loss: {loss.item():.4f}")

            del loss



# Create a data collator to encode text and image pairs
def collate_fn(examples):
    # Get the texts and images, and apply the chat template
    texts = [
        processor.apply_chat_template(example, tokenize=False) for example in examples
    ]  # Prepare texts for processing
    image_inputs = [process_vision_info(example)[0] for example in examples]  # Process the images to extract inputs

    # Tokenize the texts and process the images
    batch = processor(
        text=texts, images=image_inputs, return_tensors="pt", padding=True
    )  # Encode texts and images into tensors

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()  # Clone input IDs for labels
    labels[labels == processor.tokenizer.pad_token_id] = -100  # Mask padding tokens in labels

    # Ignore the image token index in the loss computation (model specific)
    if isinstance(processor, Qwen2VLProcessor):  # Check if the processor is Qwen2VLProcessor
        image_tokens = [151652, 151653, 151655]  # Specific image token IDs for Qwen2VLProcessor
    else:
        image_tokens = [processor.tokenizer.convert_tokens_to_ids(processor.image_token)]  # Convert image token to ID

    # Mask image token IDs in the labels
    for image_token_id in image_tokens:
        labels[labels == image_token_id] = -100  # Mask image token IDs in labels

    return batch, labels



# Main function
def main():

    config = Config()

    # Load model and processor
    logger.info("Loading model and processor...")

    hf_config = AutoConfig.from_pretrained(
                model_name, 
                revision=revision, 
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
                )

    with torch.device("meta"):
        model = AutoModelForVision2Seq.from_config(hf_config, torch_dtype=torch.bfloat16)

    mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16, 
                                   reduce_dtype=torch.bfloat16, 
                                   output_dtype=torch.bfloat16, 
                                   cast_forward_inputs=True)
    offload_policy = CPUOffloadPolicy(pin_memory=False)

    # apply FSDP2
    device_mesh = init_device_mesh("cuda", (world_size,))
    for module in model.modules():
        if isinstance(module, Qwen2VLDecoderLayer):
            fully_shard(module, 
                        mesh=device_mesh, 
                        reshard_after_forward=True,
                        mp_policy=mp_policy,
                        # offload_policy=offload_policy,
                        )
    
    model = fully_shard(model, 
                        mesh=device_mesh, 
                        reshard_after_forward=True,
                        mp_policy=mp_policy,
                        # offload_policy=offload_policy,
                        )

    model.to_empty(device='cuda')

    model_state_dict = model.state_dict()

    model_dir = "/cache/fsdp_test/72B_8_files"

    # load qwen2-vl model
    dcp.load(
        state_dict=model_state_dict,
        checkpoint_id=model_dir,
        planner=DefaultLoadPlanner(allow_partial_load=True),
    )

    model = model.to(torch.bfloat16).cuda()
    

    # Load dataset
    logger.info("Loading dataset...")

    train_dataset, eval_dataset, test_dataset = load_dataset(
        config.dataset_id, split=['train[:10%]', 'val[:10%]', 'test[:10%]'])
    train_dataset = [format_data(sample) for sample in train_dataset]

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=1,
        collate_fn=collate_fn,
        shuffle=True,
    )

    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

    # Create output directory
    os.makedirs(config.output_dir, exist_ok=True)

    # Train
    logger.info("Starting training...")
    train_model(model, train_dataloader, optimizer, config)


if __name__ == "__main__":
    main()
    destroy_process_group()
    logger.info("Training completed.")

Running command:
torchrun --nnodes=2 --nproc_per_node=8 qwenvl_train_fsdp.py
torchrun --nnodes=4 --nproc_per_node=8 qwenvl_train_fsdp.py
torchrun --nnodes=8 --nproc_per_node=8 qwenvl_train_fsdp.py

The following are screenshots of the nvidia-smi output:

16 GPUs: [screenshot]

32 GPUs: [screenshot]

64 GPUs: [screenshot]
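Here is a back-of-the-envelope check for the scaling question. With full sharding, the per-rank memory for parameters, gradients, and optimizer state should shrink roughly as 1/world_size. The defaults below assume bf16 parameters and gradients plus two bf16 AdamW moments, since torch.optim.AdamW allocates its moments in the parameter's dtype. These are assumptions, not values measured from the script.

If the estimate is far below what nvidia-smi shows, the rest is likely memory that does not shrink with more GPUs:

- activations for each rank's batch, including the visual tokens;
- the all-gathered parameters of whichever group is currently unsharded;
- memory held by PyTorch's caching allocator, which nvidia-smi counts as used.

In the script above, only Qwen2VLDecoderLayer is passed to fully_shard, although Qwen2VLVisionBlock is imported. The vision blocks' parameters therefore belong to the root module's group and are all-gathered together. This is worth checking, not a confirmed cause.

```python
def sharded_state_gib(num_params: float, world_size: int,
                      param_bytes: int = 2, grad_bytes: int = 2,
                      optim_bytes: int = 4) -> float:
    """Per-rank GiB for parameters, gradients and optimizer state when all
    three are fully sharded across `world_size` ranks.

    Defaults assume bf16 parameters and gradients and two bf16 AdamW moments
    (2 x 2 bytes); these are assumptions, not values read from the script.
    """
    per_param = param_bytes + grad_bytes + optim_bytes
    return num_params * per_param / world_size / 1024**3


# Nominal sizes (7B, 72B), not exact parameter counts of the checkpoints.
for n_params, gpus in [(7e9, 16), (7e9, 64), (72e9, 128)]:
    print(f"{n_params / 1e9:.0f}B on {gpus} GPUs: "
          f"{sharded_state_gib(n_params, gpus):.2f} GiB of sharded state per rank")
```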
