From fa69704dcaa5cda1725557502612fcdb3d93753c Mon Sep 17 00:00:00 2001 From: David Burnett Date: Tue, 22 Apr 2025 09:21:43 +0100 Subject: [PATCH 01/12] Add to overload for GGMLTensor, so calling to on the model moves the quantized data as well --- invokeai/backend/quantization/gguf/ggml_tensor.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/invokeai/backend/quantization/gguf/ggml_tensor.py b/invokeai/backend/quantization/gguf/ggml_tensor.py index d48948dcfa9..631709b6e37 100644 --- a/invokeai/backend/quantization/gguf/ggml_tensor.py +++ b/invokeai/backend/quantization/gguf/ggml_tensor.py @@ -119,6 +119,13 @@ def size(self, dim: int | None = None): return self.tensor_shape[dim] return self.tensor_shape + @overload + def to(self, *args, **kwargs) -> torch.Tensor: ... + + def to(self, *args, **kwargs): + self.quantized_data = self.quantized_data.to(*args, **kwargs) + return self + @property def shape(self) -> torch.Size: # pyright: ignore[reportIncompatibleVariableOverride] pyright doesn't understand this for some reason. """The shape of the tensor after dequantization. I.e. the shape that will be used in any math ops.""" From 40f5614a38ed08c0eceb7544f2d08f96347fff1e Mon Sep 17 00:00:00 2001 From: David Burnett Date: Tue, 22 Apr 2025 11:05:20 +0100 Subject: [PATCH 02/12] raise expected exception when attempting to change dtype --- invokeai/backend/quantization/gguf/ggml_tensor.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/invokeai/backend/quantization/gguf/ggml_tensor.py b/invokeai/backend/quantization/gguf/ggml_tensor.py index 631709b6e37..208d0f396b8 100644 --- a/invokeai/backend/quantization/gguf/ggml_tensor.py +++ b/invokeai/backend/quantization/gguf/ggml_tensor.py @@ -123,6 +123,12 @@ def size(self, dim: int | None = None): def to(self, *args, **kwargs) -> torch.Tensor: ... 
def to(self, *args, **kwargs): + for func_arg in args: + if isinstance(func_arg, torch.dtype) and func_arg != self.quantized_data.dtype: + raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.") + if 'dtype' in kwargs.keys(): + if kwargs['dtype'] != self.quantized_data.dtype: + raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly." self.quantized_data = self.quantized_data.to(*args, **kwargs) return self From 1abde91ead81b5a1349e2e34f622a61046e7ad82 Mon Sep 17 00:00:00 2001 From: David Burnett Date: Tue, 22 Apr 2025 11:08:58 +0100 Subject: [PATCH 03/12] fix missing bracket --- invokeai/backend/quantization/gguf/ggml_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/backend/quantization/gguf/ggml_tensor.py b/invokeai/backend/quantization/gguf/ggml_tensor.py index 208d0f396b8..23b7f1ab86f 100644 --- a/invokeai/backend/quantization/gguf/ggml_tensor.py +++ b/invokeai/backend/quantization/gguf/ggml_tensor.py @@ -128,7 +128,7 @@ def to(self, *args, **kwargs): raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.") if 'dtype' in kwargs.keys(): if kwargs['dtype'] != self.quantized_data.dtype: - raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly." 
+ raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.") self.quantized_data = self.quantized_data.to(*args, **kwargs) return self From 3d4ea85642ea0e8a7b286ec09a204275c564bc71 Mon Sep 17 00:00:00 2001 From: David Burnett Date: Tue, 22 Apr 2025 11:16:08 +0100 Subject: [PATCH 04/12] fix picky ruff issue --- invokeai/backend/quantization/gguf/ggml_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/quantization/gguf/ggml_tensor.py b/invokeai/backend/quantization/gguf/ggml_tensor.py index 23b7f1ab86f..405abbc008b 100644 --- a/invokeai/backend/quantization/gguf/ggml_tensor.py +++ b/invokeai/backend/quantization/gguf/ggml_tensor.py @@ -126,8 +126,8 @@ def to(self, *args, **kwargs): for func_arg in args: if isinstance(func_arg, torch.dtype) and func_arg != self.quantized_data.dtype: raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.") - if 'dtype' in kwargs.keys(): - if kwargs['dtype'] != self.quantized_data.dtype: + if "dtype" in kwargs.keys(): + if kwargs["dtype"] != self.quantized_data.dtype: raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.") self.quantized_data = self.quantized_data.to(*args, **kwargs) return self From 71b9d1049d7ac539455a8111698e3919a5593b32 Mon Sep 17 00:00:00 2001 From: David Burnett Date: Tue, 29 Apr 2025 11:03:31 +0100 Subject: [PATCH 05/12] revert to overload due to failing tests, use Torch futures instead --- .../cached_model_only_full_load.py | 21 +++++++++++++++++-- .../backend/quantization/gguf/ggml_tensor.py | 13 ------------ 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py index be398fa1295..c63e52a527f 100644 --- a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py +++ 
b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py @@ -1,3 +1,4 @@ +from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor from typing import Any import torch @@ -76,7 +77,15 @@ def full_load_to_vram(self) -> int: for k, v in self._cpu_state_dict.items(): new_state_dict[k] = v.to(self._compute_device, copy=True) self._model.load_state_dict(new_state_dict, assign=True) - self._model.to(self._compute_device) + + check_for_gguf = self._model.state_dict().get("img_in.weight") + if isinstance(check_for_gguf, GGMLTensor): + old_value = torch.__future__.get_overwrite_module_params_on_conversion() + torch.__future__.set_overwrite_module_params_on_conversion(True) + self._model.to(self._compute_device) + torch.__future__.set_overwrite_module_params_on_conversion(old_value) + else: + self._model.to(self._compute_device) self._is_in_vram = True return self._total_bytes @@ -92,7 +101,15 @@ def full_unload_from_vram(self) -> int: if self._cpu_state_dict is not None: self._model.load_state_dict(self._cpu_state_dict, assign=True) - self._model.to(self._offload_device) + + check_for_gguf = self._model.state_dict().get("img_in.weight") + if isinstance(check_for_gguf, GGMLTensor): + old_value = torch.__future__.get_overwrite_module_params_on_conversion() + torch.__future__.set_overwrite_module_params_on_conversion(True) + self._model.to(self._compute_device) + torch.__future__.set_overwrite_module_params_on_conversion(old_value) + else: + self._model.to(self._compute_device) self._is_in_vram = False return self._total_bytes diff --git a/invokeai/backend/quantization/gguf/ggml_tensor.py b/invokeai/backend/quantization/gguf/ggml_tensor.py index 405abbc008b..d48948dcfa9 100644 --- a/invokeai/backend/quantization/gguf/ggml_tensor.py +++ b/invokeai/backend/quantization/gguf/ggml_tensor.py @@ -119,19 +119,6 @@ def size(self, dim: int | None = None): return self.tensor_shape[dim] return self.tensor_shape - @overload - def to(self, 
*args, **kwargs) -> torch.Tensor: ... - - def to(self, *args, **kwargs): - for func_arg in args: - if isinstance(func_arg, torch.dtype) and func_arg != self.quantized_data.dtype: - raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.") - if "dtype" in kwargs.keys(): - if kwargs["dtype"] != self.quantized_data.dtype: - raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.") - self.quantized_data = self.quantized_data.to(*args, **kwargs) - return self - @property def shape(self) -> torch.Size: # pyright: ignore[reportIncompatibleVariableOverride] pyright doesn't understand this for some reason. """The shape of the tensor after dequantization. I.e. the shape that will be used in any math ops.""" From 5fbc412d1670056b823a6df84312b36d2e65c3e0 Mon Sep 17 00:00:00 2001 From: David Burnett Date: Tue, 29 Apr 2025 11:24:43 +0100 Subject: [PATCH 06/12] fix offload device --- .../model_cache/cached_model/cached_model_only_full_load.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py index c63e52a527f..ddea254599f 100644 --- a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py +++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py @@ -106,10 +106,10 @@ def full_unload_from_vram(self) -> int: if isinstance(check_for_gguf, GGMLTensor): old_value = torch.__future__.get_overwrite_module_params_on_conversion() torch.__future__.set_overwrite_module_params_on_conversion(True) - self._model.to(self._compute_device) + self._model.to(self._offload_device) torch.__future__.set_overwrite_module_params_on_conversion(old_value) else: - self._model.to(self._compute_device) + self._model.to(self._offload_device) self._is_in_vram = False return self._total_bytes From 
597f7c1a0240f7d24a2759a2c77cf647d9d51696 Mon Sep 17 00:00:00 2001 From: David Burnett Date: Tue, 29 Apr 2025 11:43:28 +0100 Subject: [PATCH 07/12] add check for state_dict, required to load TI's --- .../model_cache/cached_model/cached_model_only_full_load.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py index ddea254599f..095e8281958 100644 --- a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py +++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py @@ -78,7 +78,8 @@ def full_load_to_vram(self) -> int: new_state_dict[k] = v.to(self._compute_device, copy=True) self._model.load_state_dict(new_state_dict, assign=True) - check_for_gguf = self._model.state_dict().get("img_in.weight") + + check_for_gguf = hasattr(self._model, 'state_dict') and self._model.state_dict().get("img_in.weight") if isinstance(check_for_gguf, GGMLTensor): old_value = torch.__future__.get_overwrite_module_params_on_conversion() torch.__future__.set_overwrite_module_params_on_conversion(True) @@ -102,7 +103,7 @@ def full_unload_from_vram(self) -> int: if self._cpu_state_dict is not None: self._model.load_state_dict(self._cpu_state_dict, assign=True) - check_for_gguf = self._model.state_dict().get("img_in.weight") + check_for_gguf = hasattr(self._model, 'state_dict') and self._model.state_dict().get("img_in.weight") if isinstance(check_for_gguf, GGMLTensor): old_value = torch.__future__.get_overwrite_module_params_on_conversion() torch.__future__.set_overwrite_module_params_on_conversion(True) From 654038ccabfd6222fc3bb894a9cf74f554b4ade7 Mon Sep 17 00:00:00 2001 From: David Burnett Date: Tue, 22 Apr 2025 09:21:43 +0100 Subject: [PATCH 08/12] Add to overload for GGMLTensor, so calling to on the model moves 
the quantized data as well --- invokeai/backend/quantization/gguf/ggml_tensor.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/invokeai/backend/quantization/gguf/ggml_tensor.py b/invokeai/backend/quantization/gguf/ggml_tensor.py index d48948dcfa9..631709b6e37 100644 --- a/invokeai/backend/quantization/gguf/ggml_tensor.py +++ b/invokeai/backend/quantization/gguf/ggml_tensor.py @@ -119,6 +119,13 @@ def size(self, dim: int | None = None): return self.tensor_shape[dim] return self.tensor_shape + @overload + def to(self, *args, **kwargs) -> torch.Tensor: ... + + def to(self, *args, **kwargs): + self.quantized_data = self.quantized_data.to(*args, **kwargs) + return self + @property def shape(self) -> torch.Size: # pyright: ignore[reportIncompatibleVariableOverride] pyright doesn't understand this for some reason. """The shape of the tensor after dequantization. I.e. the shape that will be used in any math ops.""" From 3f14b60e13aad7f35b82e6edd055e0ca192e1f31 Mon Sep 17 00:00:00 2001 From: David Burnett Date: Tue, 22 Apr 2025 11:05:20 +0100 Subject: [PATCH 09/12] raise expected exception when attempting to change dtype --- invokeai/backend/quantization/gguf/ggml_tensor.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/invokeai/backend/quantization/gguf/ggml_tensor.py b/invokeai/backend/quantization/gguf/ggml_tensor.py index 631709b6e37..208d0f396b8 100644 --- a/invokeai/backend/quantization/gguf/ggml_tensor.py +++ b/invokeai/backend/quantization/gguf/ggml_tensor.py @@ -119,6 +119,13 @@ def size(self, dim: int | None = None): return self.tensor_shape[dim] return self.tensor_shape + @overload + def to(self, *args, **kwargs) -> torch.Tensor: ... def to(self, *args, **kwargs): + for func_arg in args: + if isinstance(func_arg, torch.dtype) and func_arg != self.quantized_data.dtype: + raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.") + if 'dtype' in kwargs.keys(): + if kwargs['dtype'] != self.quantized_data.dtype: + raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly." 
self.quantized_data = self.quantized_data.to(*args, **kwargs) return self From c20e52aec4cba97bcab996d0d95e7719a685fe1f Mon Sep 17 00:00:00 2001 From: David Burnett Date: Tue, 22 Apr 2025 11:08:58 +0100 Subject: [PATCH 10/12] fix missing bracket --- invokeai/backend/quantization/gguf/ggml_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/backend/quantization/gguf/ggml_tensor.py b/invokeai/backend/quantization/gguf/ggml_tensor.py index 208d0f396b8..23b7f1ab86f 100644 --- a/invokeai/backend/quantization/gguf/ggml_tensor.py +++ b/invokeai/backend/quantization/gguf/ggml_tensor.py @@ -128,7 +128,7 @@ def to(self, *args, **kwargs): raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.") if 'dtype' in kwargs.keys(): if kwargs['dtype'] != self.quantized_data.dtype: - raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly." + raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.") self.quantized_data = self.quantized_data.to(*args, **kwargs) return self From 4e237a26f454fd10a968d65e000db71420f63060 Mon Sep 17 00:00:00 2001 From: David Burnett Date: Tue, 22 Apr 2025 11:16:08 +0100 Subject: [PATCH 11/12] fix picky ruff issue --- invokeai/backend/quantization/gguf/ggml_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/quantization/gguf/ggml_tensor.py b/invokeai/backend/quantization/gguf/ggml_tensor.py index 23b7f1ab86f..405abbc008b 100644 --- a/invokeai/backend/quantization/gguf/ggml_tensor.py +++ b/invokeai/backend/quantization/gguf/ggml_tensor.py @@ -126,8 +126,8 @@ def to(self, *args, **kwargs): for func_arg in args: if isinstance(func_arg, torch.dtype) and func_arg != self.quantized_data.dtype: raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.") - if 'dtype' in kwargs.keys(): - if kwargs['dtype'] != self.quantized_data.dtype: + if "dtype" in kwargs.keys(): + if kwargs["dtype"] != self.quantized_data.dtype: raise 
ValueError("Operation changed the dtype of GGMLTensor unexpectedly.") self.quantized_data = self.quantized_data.to(*args, **kwargs) return self From 3f34789860259726fe471396f3308a9726025f70 Mon Sep 17 00:00:00 2001 From: David Burnett Date: Tue, 29 Apr 2025 12:10:57 +0100 Subject: [PATCH 12/12] fix import ordering, remove code I reverted that the resync added back --- .../cached_model/cached_model_only_full_load.py | 8 ++++---- invokeai/backend/quantization/gguf/ggml_tensor.py | 13 ------------- 2 files changed, 4 insertions(+), 17 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py index 095e8281958..e9cf889daad 100644 --- a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py +++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py @@ -1,8 +1,9 @@ -from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor from typing import Any import torch +from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor + class CachedModelOnlyFullLoad: """A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device. 
@@ -78,8 +79,7 @@ def full_load_to_vram(self) -> int: new_state_dict[k] = v.to(self._compute_device, copy=True) self._model.load_state_dict(new_state_dict, assign=True) - - check_for_gguf = hasattr(self._model, 'state_dict') and self._model.state_dict().get("img_in.weight") + check_for_gguf = hasattr(self._model, "state_dict") and self._model.state_dict().get("img_in.weight") if isinstance(check_for_gguf, GGMLTensor): old_value = torch.__future__.get_overwrite_module_params_on_conversion() torch.__future__.set_overwrite_module_params_on_conversion(True) @@ -103,7 +103,7 @@ def full_unload_from_vram(self) -> int: if self._cpu_state_dict is not None: self._model.load_state_dict(self._cpu_state_dict, assign=True) - check_for_gguf = hasattr(self._model, 'state_dict') and self._model.state_dict().get("img_in.weight") + check_for_gguf = hasattr(self._model, "state_dict") and self._model.state_dict().get("img_in.weight") if isinstance(check_for_gguf, GGMLTensor): old_value = torch.__future__.get_overwrite_module_params_on_conversion() torch.__future__.set_overwrite_module_params_on_conversion(True) diff --git a/invokeai/backend/quantization/gguf/ggml_tensor.py b/invokeai/backend/quantization/gguf/ggml_tensor.py index 405abbc008b..d48948dcfa9 100644 --- a/invokeai/backend/quantization/gguf/ggml_tensor.py +++ b/invokeai/backend/quantization/gguf/ggml_tensor.py @@ -119,19 +119,6 @@ def size(self, dim: int | None = None): return self.tensor_shape[dim] return self.tensor_shape - @overload - def to(self, *args, **kwargs) -> torch.Tensor: ... 
- - def to(self, *args, **kwargs): - for func_arg in args: - if isinstance(func_arg, torch.dtype) and func_arg != self.quantized_data.dtype: - raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.") - if "dtype" in kwargs.keys(): - if kwargs["dtype"] != self.quantized_data.dtype: - raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.") - self.quantized_data = self.quantized_data.to(*args, **kwargs) - return self - @property def shape(self) -> torch.Size: # pyright: ignore[reportIncompatibleVariableOverride] pyright doesn't understand this for some reason. """The shape of the tensor after dequantization. I.e. the shape that will be used in any math ops."""