
Conversation

@JohannesGaessler
Collaborator

This PR adds automation for setting parameters in a way that maximizes memory utilization when the full model cannot be fit. The short version is that the code first tries reducing the context size and then starts moving weights from device memory to system memory. For MoE models, dense weights are prioritized for allocation in device memory since system memory is usually slower than device memory. Example log snippet:

llama_params_fit: projected memory use with initial parameters [MiB]:
llama_params_fit:   - ROCm0 (AMD Radeon Graphics)     :  16304 total,  39959 used,  24341 deficit
llama_params_fit:   - ROCm1 (AMD Radeon RX 6800)      :  16368 total,  42480 used,  26296 deficit
llama_params_fit:   - ROCm2 (AMD Instinct MI60 / MI50):  32752 total,  76200 used,  43626 deficit
llama_params_fit: projected to use 158641 MiB of device memory vs. a total of 65424 MiB
llama_params_fit: cannot fulfill margin of 1024 MiB on all devices, need to use 97337 MiB less in total
llama_params_fit: context size reduced from 65536 to 4096 -> need 13440 MiB less memory
llama_params_fit: with only dense weights in device memory there is a total surplus of 53432 MiB
llama_params_fit: set to use 35 dense-only and 22 full GPU layers in total, projected memory use:
llama_params_fit:   - ROCm0 (AMD Radeon Graphics)     :  0 dense-only layers,  5 full layers,  14373 MiB used,   1244 MiB free
llama_params_fit:   - ROCm1 (AMD Radeon RX 6800)      : 13 dense-only layers,  5 full layers,  14354 MiB used,   1829 MiB free
llama_params_fit:   - ROCm2 (AMD Instinct MI60 / MI50): 22 dense-only layers, 12 full layers,  30917 MiB used,   1656 MiB free

User Interface

  • The llama C API has a new function llama_params_fit that adjusts the provided llama_model_params and llama_context_params so that, when they are used to create a corresponding llama_model and llama_context, the program will not run out of memory (see the usage sketch after this list).
  • llama_model_params has a new flag no_alloc that is false by default; if set to true, the resulting llama_model and llama_context contain only metadata.
  • New CLI argument --fit [on|off] to control whether parameters should be fit to free device memory, enabled by default. The overall intent is to have optimistic defaults that would require a large amount of resources and to then cut down on resource use if insufficient resources are available.
  • New CLI argument --fit-ctx to control the minimum context size that the fitting code may set when reducing memory use, defaults to 4096.
  • New CLI argument --fit-margin to set the margin of free memory per device that should be left over after allocation, defaults to 1024 MiB.
  • The default context size is now 0, meaning models use their maximum context size by default.
  • If the context size is set manually it is not changed.
  • If the number of GPU layers, a tensor split, or tensor buffer type (buft) overrides are set, then the way tensors are allocated is not changed.
  • The log output of the dummy models and contexts is not shown unless the --verbose flag is set.
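
The exact signature of llama_params_fit is not spelled out above, so the call below is only a sketch of the intended usage pattern with an assumed argument list (model path, pointers to the two parameter structs, minimum context size, margin in MiB); the surrounding calls are the existing llama C API.

    // Usage sketch only: the llama_params_fit signature is assumed, not taken from this PR.
    #include "llama.h"

    #include <cstdio>

    int main(int argc, char ** argv) {
        const char * path_model = argv[1];
        llama_backend_init();

        llama_model_params   mparams = llama_model_default_params();
        llama_context_params cparams = llama_context_default_params();
        cparams.n_ctx = 0; // new optimistic default: use the model's maximum context size

        // Assumed signature: shrink the context (not below 4096) and/or move weights to
        // system memory until everything fits with ~1024 MiB left free on each device.
        if (!llama_params_fit(path_model, &mparams, &cparams, /*n_ctx_min =*/ 4096, /*margin_mib =*/ 1024)) {
            fprintf(stderr, "failed to fit parameters to device memory\n");
            return 1;
        }

        // With the adjusted parameters the real allocations should now succeed.
        llama_model   * model = llama_model_load_from_file(path_model, mparams);
        llama_context * ctx   = llama_init_from_model(model, cparams);

        // ... run inference ...

        llama_free(ctx);
        llama_model_free(model);
        llama_backend_free();
        return 0;
    }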

Implementation Details

  • No actual device memory is allocated when determining memory limits. Instead, the new no_alloc flag is used to create dummy models and contexts from which the optimal parameters can be determined. This makes use of the recently added memory_breakdown methods, which have been extended to handle dummy allocations.
  • The overhead in the simplest case, where the initial parameters do not need to be changed, is ~0.1 s for the creation of a single dummy model and context (determined by how the runtime changes with --fit on vs. --fit off). At most 6 dummy models and contexts will be created by the function, in the case of loading a MoE model where only the dense layers fit into memory. I think most of the overhead comes from loading the vocabulary. Initially I intended to skip loading the vocabulary entirely, but that seems to cause issues when then trying to construct the compute graph. I'm not sure how to proceed with this: on the one hand it would be nice to reduce the overhead if possible, but on the other hand one could possibly unify the vocab_only and no_alloc flags for a simpler interface.
  • When creating dummy objects the log is temporarily filtered to avoid spamming the console. I made it so that, by default, only error messages are shown while the dummy objects are created, because some models produce a large number of warnings that render the prints about memory use effectively unreadable (warnings and below are moved to the debug log). I am concerned that if the program were to crash during the creation of the dummy objects, the default console output would be less useful for determining the issue. My impression, though, is that crashes in llama.cpp itself are nowadays relatively rare, so maybe it's okay? In any case, we should adjust the issue template to instruct users to always provide a --verbose log.
  • Due to the temporary change in logging, the function llama_params_fit is not thread safe. I don't have a good understanding of the current state of thread safety for the llama C API, so I would appreciate guidance on how much of an issue this is.
  • For the tensor split and the tensor buffer type overrides, one needs to allocate memory and pass a pointer when creating a model via the llama C API. The way I've handled this in llama_params_fit is that the user needs to pass such pointers to the function, or else those properties cannot be modified (see the sketch after this list). I think this is preferable to allocating memory in the function itself. I've considered modifying the data pointed to by e.g. model_params::tensor_split directly, but given the risk of a segfault I think it's preferable to be explicit and require the user to provide buffers.
  • llama_context now tracks the maximum amount of memory that should be allocated for the compute graph over its lifetime (I'm using this to determine projected memory use). When the llama_context is destroyed, the size of the actually allocated buffers is compared to this expectation and a warning is issued if it was exceeded.
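
For illustration, here is a minimal sketch of the user side of the caller-provided buffers, under the same assumed llama_params_fit signature as in the earlier sketch. llama_max_devices and llama_model_params::tensor_split are existing llama API; that llama_params_fit writes the fitted split into the caller's buffer is an assumption based on the description above.

    // Sketch only: caller-owned storage for the tensor split; llama_params_fit signature assumed.
    #include "llama.h"

    #include <vector>

    int main(int argc, char ** argv) {
        const char * path_model = argv[1];
        llama_backend_init();

        // One entry per device; the caller owns this storage for as long as the params are used.
        std::vector<float> tensor_split(llama_max_devices(), 0.0f);

        llama_model_params   mparams = llama_model_default_params();
        llama_context_params cparams = llama_context_default_params();
        mparams.tensor_split = tensor_split.data(); // without this pointer the split cannot be modified

        // Assumed signature, as before; the fitted per-device split is written to tensor_split,
        // which mparams.tensor_split still points at when the model is created.
        llama_params_fit(path_model, &mparams, &cparams, /*n_ctx_min =*/ 4096, /*margin_mib =*/ 1024);

        llama_model * model = llama_model_load_from_file(path_model, mparams);
        // ...
        llama_model_free(model);
        llama_backend_free();
        return 0;
    }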

Backend Changes

  • The ggml API has been extended with a new function ggml_log_get to retrieve the current state of the logger (a save/restore sketch follows this list).
  • The ggml backend API has been extended with new functions which return the amount of memory that would be allocated without actually doing any allocations.
  • The ggml backend scheduler no longer tries to allocate tensors that have a buffer set but no data. This makes it possible to create a dummy buffer and assign it to the weight and KV cache tensors, preventing them from being allocated for the compute graph (or from being considered for allocation when determining how much memory the compute graph would need).
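
To illustrate the intended use of ggml_log_get together with the existing ggml_log_set for the temporary log filtering described under Implementation Details: the sketch below assumes ggml_log_get returns the current callback and user data through out-parameters, which is not spelled out above; ggml_log_set, ggml_log_callback and GGML_LOG_LEVEL_ERROR are existing ggml API.

    // Save/filter/restore sketch; the ggml_log_get signature is an assumption.
    #include "ggml.h"

    #include <cstdio>

    // Forward only error messages while dummy models/contexts are being created.
    static void log_errors_only(enum ggml_log_level level, const char * text, void * /*user_data*/) {
        if (level == GGML_LOG_LEVEL_ERROR) {
            fputs(text, stderr);
        }
    }

    static void create_dummy_objects_quietly() {
        ggml_log_callback prev_cb   = nullptr;
        void *            prev_data = nullptr;
        ggml_log_get(&prev_cb, &prev_data);     // assumed: query the currently installed logger

        ggml_log_set(log_errors_only, nullptr); // only errors while the dummy objects exist
        // ... create dummy model + context, read the projected memory use ...
        ggml_log_set(prev_cb, prev_data);       // restore the previous logger
    }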

@github-actions github-actions bot added the ggml changes relating to the ggml tensor library for machine learning label Oct 18, 2025
@JohannesGaessler JohannesGaessler changed the title from "llama: automatically fit parameters not set by the user to free device memory" to "llama: automatically set parameters not set by the user in such a way that maximizes GPU utilization" on Oct 18, 2025
@ggerganov
Member

@JohannesGaessler Could you rebase/merge to latest master as there are recent changes related to memory usage on macOS that I would like to have here?

@ark3

ark3 commented Oct 22, 2025

Not sure whether you're ready for feedback on this, but I'm very excited for this feature.

llama-server
      --model ${gguf}/DeepSeek-V3.1-Terminus-UD-Q4_K_XL-00001-of-00008.gguf
      --alias deepseek/deepseek-v3.1-terminus
      --jinja
      -fa on
      --reasoning-budget 0
      --reasoning-format deepseek
      --fit-ctx 131072
      --fit on
      --cache-ram 20000
      --temp 0.6 --top-p 0.95
      --threads 32
      --threads-batch 32

yields

ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA RTX 6000 Ada Generation, compute capability 8.9, VMM: yes
  Device 1: NVIDIA RTX 6000 Ada Generation, compute capability 8.9, VMM: yes
build: 6796 (653d762dc) with cc (Ubuntu 11.4.0-1ubuntu1~22.04.2) 11.4.0 for x86_64-linux-gnu
system info: n_threads = 32, n_threads_batch = 32, total_threads = 128
[...]
llama_params_fit: projected memory use with initial parameters [MiB]:
llama_params_fit:   - CUDA0 (NVIDIA RTX 6000 Ada Generation):  48508 total, 188634 used, 140560 deficit
llama_params_fit:   - CUDA1 (NVIDIA RTX 6000 Ada Generation):  48508 total, 201090 used, 153016 deficit
llama_params_fit: projected to use 389724 MiB of device memory vs. a total of 97017 MiB
llama_params_fit: cannot fulfill margin of 1024 MiB on all devices, need to use 295625 MiB less in total
llama_params_fit: context size reduced from 163840 to 131072 -> need 4148 MiB less memory
llama_params_fit: with only dense weights in device memory there is a total surplus of 62238 MiB
llama_params_fit: set to use 50 dense-only and 12 full GPU layers in total, projected memory use:
llama_params_fit:   - CUDA0 (NVIDIA RTX 6000 Ada Generation): 14 dense-only layers,  6 full layers,  46445 MiB used,   1627 MiB free
llama_params_fit:   - CUDA1 (NVIDIA RTX 6000 Ada Generation): 36 dense-only layers,  6 full layers,  31872 MiB used,  16200 MiB free
common_init_from_params: successfully fit parameters to device memory
llama_model_load_from_file_impl: using device CUDA0 (NVIDIA RTX 6000 Ada Generation) (0000:c1:00.0) - 48073 MiB free
llama_model_load_from_file_impl: using device CUDA1 (NVIDIA RTX 6000 Ada Generation) (0000:e1:00.0) - 48073 MiB free
[...]
load_tensors: offloading 61 repeating layers to GPU
load_tensors: offloading output layer to GPU
load_tensors: offloaded 62/62 layers to GPU
load_tensors:   CPU_Mapped model buffer size =  6074.39 MiB
load_tensors:        CUDA0 model buffer size = 39781.48 MiB
load_tensors:        CUDA1 model buffer size = 38807.08 MiB
[...]
llama_context:  CUDA_Host  output buffer size =     0.49 MiB
llama_kv_cache:      CUDA0 KV buffer size =  5440.00 MiB
ggml_backend_cuda_buffer_type_alloc_buffer: allocating 11152.00 MiB on device 1: cudaMalloc failed: out of memory
alloc_tensor_range: failed to allocate CUDA1 buffer of size 11693719552
llama_init_from_model: failed to initialize the context: failed to allocate buffer for kv cache

followed by exit.

Full logs
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA RTX 6000 Ada Generation, compute capability 8.9, VMM: yes
  Device 1: NVIDIA RTX 6000 Ada Generation, compute capability 8.9, VMM: yes
build: 6796 (653d762dc) with cc (Ubuntu 11.4.0-1ubuntu1~22.04.2) 11.4.0 for x86_64-linux-gnu
system info: n_threads = 32, n_threads_batch = 32, total_threads = 128

system_info: n_threads = 32 (n_threads_batch = 32) / 128 | CUDA : ARCHS = 890 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 128 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 | 

main: binding port with default address family
main: HTTP server is listening, hostname: localhost, port: 19308, http threads: 127
main: loading model
srv    load_model: loading model '/disk1/ai/gguf/DeepSeek-V3.1-Terminus-UD-Q4_K_XL-00001-of-00008.gguf'
llama_params_fit: projected memory use with initial parameters [MiB]:
llama_params_fit:   - CUDA0 (NVIDIA RTX 6000 Ada Generation):  48508 total, 188634 used, 140560 deficit
llama_params_fit:   - CUDA1 (NVIDIA RTX 6000 Ada Generation):  48508 total, 201090 used, 153016 deficit
llama_params_fit: projected to use 389724 MiB of device memory vs. a total of 97017 MiB
llama_params_fit: cannot fulfill margin of 1024 MiB on all devices, need to use 295625 MiB less in total
llama_params_fit: context size reduced from 163840 to 131072 -> need 4148 MiB less memory
llama_params_fit: with only dense weights in device memory there is a total surplus of 62238 MiB
llama_params_fit: set to use 50 dense-only and 12 full GPU layers in total, projected memory use:
llama_params_fit:   - CUDA0 (NVIDIA RTX 6000 Ada Generation): 14 dense-only layers,  6 full layers,  46445 MiB used,   1627 MiB free
llama_params_fit:   - CUDA1 (NVIDIA RTX 6000 Ada Generation): 36 dense-only layers,  6 full layers,  31872 MiB used,  16200 MiB free
common_init_from_params: successfully fit parameters to device memory
llama_model_load_from_file_impl: using device CUDA0 (NVIDIA RTX 6000 Ada Generation) (0000:c1:00.0) - 48073 MiB free
llama_model_load_from_file_impl: using device CUDA1 (NVIDIA RTX 6000 Ada Generation) (0000:e1:00.0) - 48073 MiB free
llama_model_loader: additional 7 GGUFs metadata loaded.
llama_model_loader: loaded meta data with 61 key-value pairs and 1086 tensors from /disk1/ai/gguf/DeepSeek-V3.1-Terminus-UD-Q4_K_XL-00001-of-00008.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = deepseek2
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Deepseek-V3.1-Terminus
llama_model_loader: - kv   3:                           general.basename str              = Deepseek-V3.1-Terminus
llama_model_loader: - kv   4:                       general.quantized_by str              = Unsloth
llama_model_loader: - kv   5:                         general.size_label str              = 256x20B
llama_model_loader: - kv   6:                            general.license str              = mit
llama_model_loader: - kv   7:                           general.repo_url str              = https://huggingface.co/unsloth
llama_model_loader: - kv   8:                   general.base_model.count u32              = 1
llama_model_loader: - kv   9:                  general.base_model.0.name str              = DeepSeek V3.1 Terminus
llama_model_loader: - kv  10:          general.base_model.0.organization str              = Deepseek Ai
llama_model_loader: - kv  11:              general.base_model.0.repo_url str              = https://huggingface.co/deepseek-ai/De...
llama_model_loader: - kv  12:                      deepseek2.block_count u32              = 61
llama_model_loader: - kv  13:                   deepseek2.context_length u32              = 163840
llama_model_loader: - kv  14:                 deepseek2.embedding_length u32              = 7168
llama_model_loader: - kv  15:              deepseek2.feed_forward_length u32              = 18432
llama_model_loader: - kv  16:             deepseek2.attention.head_count u32              = 128
llama_model_loader: - kv  17:          deepseek2.attention.head_count_kv u32              = 1
llama_model_loader: - kv  18:                   deepseek2.rope.freq_base f32              = 10000.000000
llama_model_loader: - kv  19: deepseek2.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  20:                deepseek2.expert_used_count u32              = 8
llama_model_loader: - kv  21:        deepseek2.leading_dense_block_count u32              = 3
llama_model_loader: - kv  22:                       deepseek2.vocab_size u32              = 129280
llama_model_loader: - kv  23:            deepseek2.attention.q_lora_rank u32              = 1536
llama_model_loader: - kv  24:           deepseek2.attention.kv_lora_rank u32              = 512
llama_model_loader: - kv  25:             deepseek2.attention.key_length u32              = 576
llama_model_loader: - kv  26:           deepseek2.attention.value_length u32              = 512
llama_model_loader: - kv  27:         deepseek2.attention.key_length_mla u32              = 192
llama_model_loader: - kv  28:       deepseek2.attention.value_length_mla u32              = 128
llama_model_loader: - kv  29:       deepseek2.expert_feed_forward_length u32              = 2048
llama_model_loader: - kv  30:                     deepseek2.expert_count u32              = 256
llama_model_loader: - kv  31:              deepseek2.expert_shared_count u32              = 1
llama_model_loader: - kv  32:             deepseek2.expert_weights_scale f32              = 2.500000
llama_model_loader: - kv  33:              deepseek2.expert_weights_norm bool             = true
llama_model_loader: - kv  34:               deepseek2.expert_gating_func u32              = 2
llama_model_loader: - kv  35:             deepseek2.rope.dimension_count u32              = 64
llama_model_loader: - kv  36:                deepseek2.rope.scaling.type str              = yarn
llama_model_loader: - kv  37:              deepseek2.rope.scaling.factor f32              = 40.000000
llama_model_loader: - kv  38: deepseek2.rope.scaling.original_context_length u32              = 4096
llama_model_loader: - kv  39: deepseek2.rope.scaling.yarn_log_multiplier f32              = 0.100000
llama_model_loader: - kv  40:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  41:                         tokenizer.ggml.pre str              = deepseek-v3
llama_model_loader: - kv  42:                      tokenizer.ggml.tokens arr[str,129280]  = ["<|begin▁of▁sentence|>", "<�...
llama_model_loader: - kv  43:                  tokenizer.ggml.token_type arr[i32,129280]  = [3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  44:                      tokenizer.ggml.merges arr[str,127741]  = ["Ġ t", "Ġ a", "i n", "Ġ Ġ", "h e...
llama_model_loader: - kv  45:                tokenizer.ggml.bos_token_id u32              = 0
llama_model_loader: - kv  46:                tokenizer.ggml.eos_token_id u32              = 1
llama_model_loader: - kv  47:            tokenizer.ggml.padding_token_id u32              = 2
llama_model_loader: - kv  48:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  49:               tokenizer.ggml.add_sep_token bool             = false
llama_model_loader: - kv  50:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  51:                    tokenizer.chat_template str              = {#- Unsloth template fixes #}{% if no...
llama_model_loader: - kv  52:               general.quantization_version u32              = 2
llama_model_loader: - kv  53:                          general.file_type u32              = 15
llama_model_loader: - kv  54:                      quantize.imatrix.file str              = DeepSeek-V3.1-Terminus-GGUF/imatrix_u...
llama_model_loader: - kv  55:                   quantize.imatrix.dataset str              = unsloth_calibration_DeepSeek-V3.1-Ter...
llama_model_loader: - kv  56:             quantize.imatrix.entries_count u32              = 781
llama_model_loader: - kv  57:              quantize.imatrix.chunks_count u32              = 84
llama_model_loader: - kv  58:                                   split.no u16              = 0
llama_model_loader: - kv  59:                        split.tensors.count i32              = 1086
llama_model_loader: - kv  60:                                split.count u16              = 8
llama_model_loader: - type  f32:  361 tensors
llama_model_loader: - type q8_0:  198 tensors
llama_model_loader: - type q4_K:  453 tensors
llama_model_loader: - type q5_K:   30 tensors
llama_model_loader: - type q6_K:   44 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = Q4_K - Medium
print_info: file size   = 357.89 GiB (4.58 BPW) 
load: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect
load: printing all EOG tokens:
load:   - 1 ('<|end▁of▁sentence|>')
load: special tokens cache size = 818
load: token to piece cache size = 0.8223 MB
print_info: arch             = deepseek2
print_info: vocab_only       = 0
print_info: no_alloc         = 0
print_info: n_ctx_train      = 163840
print_info: n_embd           = 7168
print_info: n_layer          = 61
print_info: n_head           = 128
print_info: n_head_kv        = 1
print_info: n_rot            = 64
print_info: n_swa            = 0
print_info: is_swa_any       = 0
print_info: n_embd_head_k    = 576
print_info: n_embd_head_v    = 512
print_info: n_gqa            = 128
print_info: n_embd_k_gqa     = 576
print_info: n_embd_v_gqa     = 512
print_info: f_norm_eps       = 0.0e+00
print_info: f_norm_rms_eps   = 1.0e-06
print_info: f_clamp_kqv      = 0.0e+00
print_info: f_max_alibi_bias = 0.0e+00
print_info: f_logit_scale    = 0.0e+00
print_info: f_attn_scale     = 0.0e+00
print_info: n_ff             = 18432
print_info: n_expert         = 256
print_info: n_expert_used    = 8
print_info: causal attn      = 1
print_info: pooling type     = 0
print_info: rope type        = 0
print_info: rope scaling     = yarn
print_info: freq_base_train  = 10000.0
print_info: freq_scale_train = 0.025
print_info: n_ctx_orig_yarn  = 4096
print_info: rope_finetuned   = unknown
print_info: model type       = 671B
print_info: model params     = 671.03 B
print_info: general.name     = Deepseek-V3.1-Terminus
print_info: n_layer_dense_lead   = 3
print_info: n_lora_q             = 1536
print_info: n_lora_kv            = 512
print_info: n_embd_head_k_mla    = 192
print_info: n_embd_head_v_mla    = 128
print_info: n_ff_exp             = 2048
print_info: n_expert_shared      = 1
print_info: expert_weights_scale = 2.5
print_info: expert_weights_norm  = 1
print_info: expert_gating_func   = sigmoid
print_info: rope_yarn_log_mul    = 0.1000
print_info: vocab type       = BPE
print_info: n_vocab          = 129280
print_info: n_merges         = 127741
print_info: BOS token        = 0 '<|begin▁of▁sentence|>'
print_info: EOS token        = 1 '<|end▁of▁sentence|>'
print_info: EOT token        = 1 '<|end▁of▁sentence|>'
print_info: PAD token        = 2 '<|▁pad▁|>'
print_info: LF token         = 201 'Ċ'
print_info: FIM PRE token    = 128801 '<|fim▁begin|>'
print_info: FIM SUF token    = 128800 '<|fim▁hole|>'
print_info: FIM MID token    = 128802 '<|fim▁end|>'
print_info: EOG token        = 1 '<|end▁of▁sentence|>'
print_info: max token length = 256
load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors: offloading 61 repeating layers to GPU
load_tensors: offloading output layer to GPU
load_tensors: offloaded 62/62 layers to GPU
load_tensors:   CPU_Mapped model buffer size =  6074.39 MiB
load_tensors:        CUDA0 model buffer size = 39781.48 MiB
load_tensors:        CUDA1 model buffer size = 38807.08 MiB
..............................................................................srv  log_server_r: request: GET /health ::1 503
...............srv  log_server_r: request: GET /health ::1 503
.......
llama_init_from_model: model default pooling_type is [0], but [-1] was specified
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 131072
llama_context: n_ctx_per_seq = 131072
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = enabled
llama_context: kv_unified    = false
llama_context: freq_base     = 10000.0
llama_context: freq_scale    = 0.025
llama_context: n_ctx_per_seq (131072) < n_ctx_train (163840) -- the full capacity of the model will not be utilized
llama_context:  CUDA_Host  output buffer size =     0.49 MiB
llama_kv_cache:      CUDA0 KV buffer size =  5440.00 MiB
ggml_backend_cuda_buffer_type_alloc_buffer: allocating 11152.00 MiB on device 1: cudaMalloc failed: out of memory
alloc_tensor_range: failed to allocate CUDA1 buffer of size 11693719552
llama_init_from_model: failed to initialize the context: failed to allocate buffer for kv cache
common_init_from_params: failed to create context with model '/disk1/ai/gguf/DeepSeek-V3.1-Terminus-UD-Q4_K_XL-00001-of-00008.gguf', try reducing --n-gpu-layers if you're running out of VRAM
srv    load_model: failed to load model, '/disk1/ai/gguf/DeepSeek-V3.1-Terminus-UD-Q4_K_XL-00001-of-00008.gguf'
srv    operator(): operator(): cleaning up before exit...
main: exiting due to model loading error

@ehoogeveen-medweb

Out of curiosity, does this supersede #14067?

@JohannesGaessler
Collaborator Author

@ark3 does it work without --cache-ram?

@ehoogeveen-medweb yes.

@ark3

ark3 commented Oct 22, 2025

does it work without --cache-ram?

No. Same error, down to the number: failed to allocate CUDA1 buffer of size 11693719552

@JohannesGaessler
Collaborator Author

I think what's happening is that the projected KV cache size is being calculated incorrectly. Should be fixed by #16746; I'll push a rebased version after that.
