llama: automatically set parameters not set by the user in such a way that maximizes GPU utilization #16653
Conversation
@JohannesGaessler Could you rebase/merge to latest?
Not sure whether you're ready for feedback on this, but I'm very excited for this feature. Running it yields an error followed by exit. Full logs:
Out of curiosity, does this supersede #14067?
@ark3 does it work without …? @ehoogeveen-medweb yes.
No. Same error, down to the number:
Force-pushed from 653d762 to 00fb12b.
I think what's happening is that the projected KV cache size is being calculated incorrectly. Should be fixed by #16746, I'll push a rebased version after that.
This PR adds automation for setting parameters in such a way that memory utilization is maximized when the full model cannot be fit. The short version is that the code first tries reducing the context size and then starts moving weights from device memory to system memory. For MoE models, dense weights are prioritized for allocation in device memory since system memory is usually slower than device memory. Example log snippet:
User Interface
- New function `llama_params_fit` that adjusts the provided `llama_model_params` and `llama_context_params` in such a way that, when they are used to create a corresponding `llama_model` and `llama_context`, the program will not run out of memory (a usage sketch follows after this list).
- `llama_model_params` has a new flag `no_alloc` that is false by default but, if set to true, results in a `llama_model` and `llama_context` holding only metadata.
- New CLI argument `--fit [on|off]` to control whether parameters should be fit to free device memory, enabled by default. The overall intent is to have optimistic defaults that would require a large amount of resources and to then cut down on resource use if insufficient resources are available.
- New CLI argument `--fit-ctx` to control the minimum context size that can be set by the code in order to reduce memory use, defaults to 4096.
- New CLI argument `--fit-margin` to set the margin in free MiB per device that should be left over after allocation, defaults to 1024 MiB.
- Additional details are logged if the `--verbose` flag is set.
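A rough sketch of the intended call order, assuming a signature like the one below for `llama_params_fit` (the actual signature and trailing arguments may differ); the surrounding calls (`llama_model_default_params`, `llama_context_default_params`, `llama_model_load_from_file`, `llama_init_from_model`) are existing llama.cpp API:

```cpp
#include "llama.h"

int main() {
    llama_model_params   mparams = llama_model_default_params();
    llama_context_params cparams = llama_context_default_params();

    // optimistic defaults: offload all layers, use the model's full training context
    mparams.n_gpu_layers = 999;
    cparams.n_ctx        = 0;

    // buffer the fitting code can write a per-device split into;
    // the caller has to provide such pointers (see Implementation Details)
    float tensor_split[16] = {0};
    mparams.tensor_split = tensor_split;

    // assumed call: shrink cparams.n_ctx and/or move weights to system memory
    // until the projected allocations fit into free device memory minus a margin
    llama_params_fit("model.gguf", &mparams, &cparams /*, min. context, margin, ... */);

    llama_model   * model = llama_model_load_from_file("model.gguf", mparams);
    llama_context * ctx   = llama_init_from_model(model, cparams);

    // ... use model/ctx as usual ...

    llama_free(ctx);
    llama_model_free(model);
    return 0;
}
```

With `--fit on`, the CLI tools would presumably do the equivalent of this automatically before loading.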
Implementation Details
- The `no_alloc` flag is used to create dummy models and contexts from which the optimal parameters can be determined. This makes use of the recently added `memory_breakdown` methods, which have been extended to handle dummy allocations.
- Fitting the parameters adds some overhead (compare `--fit on` vs. `--fit off`). At most 6 dummy models and contexts will be created by the function when loading a MoE model where only the dense layers fit into memory. Most of the overhead comes, I think, from loading the vocabulary. Initially I intended to skip loading the vocabulary entirely, but that seems to cause issues when then trying to construct the compute graph. I'm not sure how to proceed with this: on the one hand it would be nice to reduce the overhead if possible, but on the other hand one could possibly unify the `vocab_only` and `no_alloc` flags for a simpler interface.
- Details of the decisions made while fitting are printed in the `--verbose` log.
- `llama_params_fit` is not thread safe. I don't have a good understanding of the current state of thread safety for the llama C API, so I would appreciate guidance regarding how much of an issue this is.
- Some of the parameters (e.g. `model_params::tensor_split`) are pointers; a consequence of the interface of `llama_params_fit` is that the user needs to pass such pointers to the function or else those properties cannot be modified. I think this is preferable over allocating memory in the function itself. I've considered modifying the data pointed at by e.g. `model_params::tensor_split` directly, but given the risk of a segfault I think it's preferable to be explicit about the user having to provide buffers.
- `llama_context` now tracks how much memory should be allocated at most for the compute graph over its lifetime (I'm using this to determine projected memory use). On destruction of `llama_context`, the size of the actually allocated buffers is compared to the expectation and a warning is issued if it was exceeded.
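To make the fitting order from the top of the description concrete, here is a self-contained toy sketch; all sizes, and the `fit_toy`/`fit_result` names, are made up for illustration, whereas the actual implementation works on `llama_model_params`/`llama_context_params` and the measured memory breakdown of the dummy models described above:

```cpp
#include <cstdint>
#include <cstdio>

struct fit_result {
    uint32_t n_ctx;        // context size after fitting
    int      n_gpu_layers; // layers whose per-layer (expert) weights stay in device memory
};

// Toy version of the fitting order: halve the context until it fits or hits the
// minimum, then move per-layer weights to system memory; the dense weights stay
// in device memory throughout, mirroring the prioritization described above.
static fit_result fit_toy(
        uint64_t free_mem,          // free device memory minus safety margin
        uint64_t dense_bytes,       // dense weights, always kept in device memory here
        uint64_t bytes_per_layer,   // per-layer (expert) weights
        int      n_layers,
        uint64_t kv_bytes_per_token,
        uint32_t n_ctx_max,         // optimistic default context size
        uint32_t n_ctx_min) {       // lower bound, cf. --fit-ctx
    fit_result res = {n_ctx_max, n_layers};
    auto projected = [&]() {
        return dense_bytes + (uint64_t) res.n_gpu_layers*bytes_per_layer
                           + (uint64_t) res.n_ctx*kv_bytes_per_token;
    };
    while (projected() > free_mem && res.n_ctx/2 >= n_ctx_min) {
        res.n_ctx /= 2;
    }
    while (projected() > free_mem && res.n_gpu_layers > 0) {
        res.n_gpu_layers--;
    }
    return res;
}

int main() {
    // made-up example: 23 GiB usable, 2 GiB dense weights, 48 layers of 1 GiB experts,
    // 1 MiB of KV cache per token, default context 131072, minimum context 4096
    const uint64_t GiB = 1024ull*1024ull*1024ull;
    const fit_result r = fit_toy(23*GiB, 2*GiB, 1*GiB, 48, 1024ull*1024ull, 131072, 4096);
    std::printf("n_ctx = %u, n_gpu_layers = %d\n", r.n_ctx, r.n_gpu_layers);
    return 0;
}
```

With these made-up numbers the sketch settles on the minimum context size of 4096 and keeps 17 of the 48 layers in device memory.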
Backend Changes
- New function `ggml_log_get` to retrieve the current state of the logger.
- Tensors for which `buffer` but not `data` has already been set are now treated as allocated. This enables creating a dummy buffer and then setting that dummy buffer for the weight and KV cache tensors to prevent them from being allocated for the compute graph (or being considered for allocation when trying to determine how much memory would need to be allocated for the compute graph).
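For the second point, a minimal sketch of the condition as a hypothetical helper (the actual change presumably lives in the ggml graph allocator and may look different):

```cpp
#include "ggml.h"

// hypothetical helper: a tensor no longer needs its own allocation if its data
// pointer is set, if it is a view, or if a (possibly dummy) backend buffer has
// already been assigned to it via tensor->buffer
static bool tensor_needs_alloc(const struct ggml_tensor * t) {
    return t->data == nullptr && t->buffer == nullptr && t->view_src == nullptr;
}
```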