diff --git a/minigpt4/models/mini_gpt4.py b/minigpt4/models/mini_gpt4.py index faed3d58..3d0a4915 100644 --- a/minigpt4/models/mini_gpt4.py +++ b/minigpt4/models/mini_gpt4.py @@ -118,6 +118,7 @@ def __init__( self.llama_model = LlamaForCausalLM.from_pretrained( llama_model, torch_dtype=torch.float16, + low_cpu_mem_usage=True, ) if lora_r > 0: