diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py index be398fa1295..e9cf889daad 100644 --- a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py +++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py @@ -2,6 +2,8 @@ import torch +from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor + class CachedModelOnlyFullLoad: """A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device. @@ -76,7 +78,15 @@ def full_load_to_vram(self) -> int: for k, v in self._cpu_state_dict.items(): new_state_dict[k] = v.to(self._compute_device, copy=True) self._model.load_state_dict(new_state_dict, assign=True) - self._model.to(self._compute_device) + + check_for_gguf = hasattr(self._model, "state_dict") and self._model.state_dict().get("img_in.weight") + if isinstance(check_for_gguf, GGMLTensor): + old_value = torch.__future__.get_overwrite_module_params_on_conversion() + torch.__future__.set_overwrite_module_params_on_conversion(True) + self._model.to(self._compute_device) + torch.__future__.set_overwrite_module_params_on_conversion(old_value) + else: + self._model.to(self._compute_device) self._is_in_vram = True return self._total_bytes @@ -92,7 +102,15 @@ def full_unload_from_vram(self) -> int: if self._cpu_state_dict is not None: self._model.load_state_dict(self._cpu_state_dict, assign=True) - self._model.to(self._offload_device) + + check_for_gguf = hasattr(self._model, "state_dict") and self._model.state_dict().get("img_in.weight") + if isinstance(check_for_gguf, GGMLTensor): + old_value = torch.__future__.get_overwrite_module_params_on_conversion() + torch.__future__.set_overwrite_module_params_on_conversion(True) + self._model.to(self._offload_device) + torch.__future__.set_overwrite_module_params_on_conversion(old_value) + else: + self._model.to(self._offload_device) self._is_in_vram = False return self._total_bytes