Skip to content

Commit 6935ff6

Browse files
committed
Remove direct references to torch.cuda with a layer of indirection (to enable other device types like DirectML and intel xpu); Update tests to allow testing different device types and the non-diffusers backend more easily
1 parent 25f9ce4 commit 6935ff6

31 files changed

+579
-341
lines changed

sdkit/__init__.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,22 @@
1+
import sys
12
from threading import local
23

4+
if sys.version_info < (3, 9):
5+
# polyfill for callable static methods. required for pytorch-directml
6+
class CallableStaticMethod(staticmethod):
7+
def __call__(self, *args, **kwargs):
8+
return self.__func__(*args, **kwargs)
9+
10+
# Patch the built-in staticmethod with CallableStaticMethod
11+
import builtins
12+
13+
builtins.staticmethod = CallableStaticMethod
14+
315

416
class Context(local):
517
def __init__(self) -> None:
6-
self._device: str = "cuda:0"
18+
self._device: str = ""
19+
self._torch_device = None
720
self._half_precision: bool = True
821
self._vram_usage_level = None
922

@@ -45,6 +58,10 @@ def __init__(self) -> None:
4558
https://github.com/sczhou/CodeFormer/blob/master/LICENSE
4659
"""
4760

61+
from sdkit.utils import get_torch_platform
62+
63+
self.device = get_torch_platform()[0]
64+
4865
# hacky approach, but we need to enforce full precision for some devices
4966
# we also need to force full precision for these devices (haven't implemented this yet):
5067
# (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name)
@@ -55,12 +72,21 @@ def device(self):
5572
@device.setter
5673
def device(self, d):
5774
self._device = d
58-
if "cuda" not in d:
75+
76+
from sdkit.utils import get_device
77+
78+
if d.split(":")[0] in ("cpu", "mps"):
5979
from sdkit.utils import log
6080

6181
log.info(f"forcing full precision for device: {d}")
6282
self._half_precision = False
6383

84+
self._torch_device = get_device(d)
85+
86+
@property
87+
def torch_device(self):
88+
return self._torch_device
89+
6490
@property
6591
def half_precision(self):
6692
return self._half_precision

sdkit/filter/codeformer/__init__.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from sdkit import Context
55
from sdkit.models import load_model, unload_model
6+
from sdkit.utils import empty_cache
67

78
from torchvision.transforms.functional import normalize
89
from threading import Lock
@@ -16,7 +17,7 @@
1617

1718

1819
def inference(context: Context, image, upscale_bg, upscale_faces, upscale_factor, codeformer_fidelity, codeformer_net):
19-
device = torch.device(context.device)
20+
device = context.torch_device
2021
face_helper = FaceRestoreHelper(upscale_factor=upscale_factor, use_parse=True, device=device)
2122
face_helper.clean_all()
2223
face_helper.read_image(image)
@@ -37,7 +38,7 @@ def inference(context: Context, image, upscale_bg, upscale_faces, upscale_factor
3738
output = codeformer_net(cropped_face_t, w=codeformer_fidelity, adain=True)[0]
3839
restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
3940
del output
40-
torch.cuda.empty_cache()
41+
empty_cache()
4142
except RuntimeError as error:
4243
print(f"Failed inference for CodeFormer: {error}")
4344
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
@@ -72,7 +73,7 @@ def apply(
7273
if (upscale_background or upscale_faces) and "realesrgan" not in context.models:
7374
raise Exception("realesrgan not loaded in context.models! Required for upscaling in CodeFormer.")
7475

75-
device = torch.device(context.device)
76+
device = context.torch_device
7677
codeformer_net = context.models["codeformer"]
7778

7879
# Convert PIL Image to numpy array and ensure it's in BGR format for OpenCV
@@ -84,7 +85,7 @@ def apply(
8485
# hack for a bug in facexlib: https://github.com/xinntao/facexlib/pull/19/files
8586
from facexlib.detection import retinaface
8687

87-
retinaface.device = torch.device(context.device)
88+
retinaface.device = context.torch_device
8889

8990
result = inference(
9091
context, input_img, upscale_background, upscale_faces, upscale_factor, codeformer_fidelity, codeformer_net

sdkit/filter/gfpgan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def apply(context: Context, image, **kwargs):
1515
# hack for a bug in facexlib: https://github.com/xinntao/facexlib/pull/19/files
1616
from facexlib.detection import retinaface
1717

18-
retinaface.device = torch.device(context.device)
18+
retinaface.device = context.torch_device
1919

2020
image = image.convert("RGB")
2121
image = np.array(image, dtype=np.uint8)[..., ::-1]

sdkit/generate/image_generator.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,6 @@ def generate_images(
5757
try:
5858
images = []
5959

60-
seed_everything(seed)
61-
precision_scope = torch.autocast if context.half_precision else nullcontext
62-
6360
if "stable-diffusion" not in context.models:
6461
raise RuntimeError(
6562
"The model for Stable Diffusion has not been loaded yet! If you've tried to load it, please check the logs above this message for errors (while loading the model)."
@@ -96,7 +93,10 @@ def generate_images(
9693
if "hypernetwork" in context.models:
9794
context.models["hypernetwork"]["hypernetwork_strength"] = hypernetwork_strength
9895

99-
with precision_scope("cuda"):
96+
seed_everything(seed)
97+
precision_scope = torch.autocast if context.half_precision else nullcontext
98+
99+
with precision_scope(context.torch_device.type):
100100
cond, uncond = get_cond_and_uncond(prompt, negative_prompt, num_outputs, model)
101101

102102
generate_fn = txt2img if init_image is None else img2img
@@ -113,7 +113,7 @@ def generate_images(
113113
"callback": callback,
114114
}
115115

116-
with torch.no_grad(), precision_scope("cuda"):
116+
with torch.no_grad(), precision_scope(context.torch_device.type):
117117
for _ in trange(1, desc="Sampling"):
118118
images += generate_fn(common_sampler_params.copy(), **req_args)
119119
gc(context)
@@ -229,10 +229,7 @@ def make_with_diffusers(
229229

230230
model = context.models["stable-diffusion"]
231231
default_pipe = model["default"]
232-
if context.device == "mps" and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
233-
generator = torch.Generator().manual_seed(seed)
234-
else:
235-
generator = torch.Generator(context.device).manual_seed(seed)
232+
generator = torch.Generator(context.torch_device).manual_seed(seed)
236233

237234
is_sd_xl = isinstance(
238235
default_pipe,
@@ -462,7 +459,7 @@ def lora_conv_forward(self, hidden_states, scale=1.0):
462459
if hasattr(operation_to_apply.unet, "_allocate_trt_buffers"):
463460
dtype = torch.float16 if context.half_precision else torch.float32
464461
operation_to_apply.unet._allocate_trt_buffers(
465-
operation_to_apply, context.device, dtype, num_outputs, width, height
462+
operation_to_apply, context.torch_device, dtype, num_outputs, width, height
466463
)
467464

468465
# apply

sdkit/generate/sampler/sampler_main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def make_samples(
5050
if sampler_module is None:
5151
raise RuntimeError(f'Unknown sampler "{sampler_name}"!')
5252

53-
noise = make_some_noise(seed, batch_size, shape, context.device)
53+
noise = make_some_noise(seed, batch_size, shape, context.torch_device)
5454

5555
return sampler_module.sample(
5656
context, sampler_name, noise, batch_size, shape, steps, cond, uncond, guidance_scale, callback, **kwargs

sdkit/models/model_loader/codeformer/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def load_model(context: Context, **kwargs):
2626
sd = sd["params_ema"]
2727

2828
model = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=["32", "64", "128", "256"])
29-
model = model.to(context.device)
29+
model = model.to(context.torch_device)
3030

3131
model.load_state_dict(sd)
3232
model.eval()

sdkit/models/model_loader/controlnet.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def load_controlnet(context, controlnet_path):
1818
import torch
1919
from sdkit.models import get_model_info_from_db
2020
from sdkit.models import models_db
21-
from sdkit.utils import load_tensor_file
21+
from sdkit.utils import load_tensor_file, is_cpu_device
2222

2323
from accelerate import cpu_offload
2424

@@ -76,13 +76,13 @@ def load_controlnet(context, controlnet_path):
7676

7777
# memory optimizations
7878

79-
if context.vram_usage_level == "low" and "cuda" in context.device:
79+
if context.vram_usage_level == "low" and not is_cpu_device(context.torch_device):
8080
controlnet = controlnet.to("cpu", torch.float16 if context.half_precision else torch.float32)
8181

8282
offload_buffers = len(controlnet._parameters) > 0
83-
cpu_offload(controlnet, context.device, offload_buffers=offload_buffers)
83+
cpu_offload(controlnet, context.torch_device, offload_buffers=offload_buffers)
8484
else:
85-
controlnet = controlnet.to(context.device, torch.float16 if context.half_precision else torch.float32)
85+
controlnet = controlnet.to(context.torch_device, torch.float16 if context.half_precision else torch.float32)
8686

8787
controlnet.set_attention_slice(1)
8888

sdkit/models/model_loader/controlnet_filters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def load_model(context: Context, **kwargs):
4141
model = Processor(model_type)
4242

4343
if hasattr(model.processor, "to"):
44-
model.processor = model.processor.to(context.device)
44+
model.processor = model.processor.to(context.torch_device)
4545

4646
return model
4747

sdkit/models/model_loader/embeddings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def attach_hooks(context, components):
105105
from accelerate import cpu_offload
106106

107107
for _, te in components:
108-
cpu_offload(te, context.device, offload_buffers=len(te._parameters) > 0)
108+
cpu_offload(te, context.torch_device, offload_buffers=len(te._parameters) > 0)
109109

110110

111111
def get_embedding(embedding):

sdkit/models/model_loader/gfpgan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ def load_model(context: Context, **kwargs):
3030
# hack for a bug in facexlib: https://github.com/xinntao/facexlib/pull/19/files
3131
from facexlib.detection import retinaface
3232

33-
retinaface.device = torch.device(context.device)
33+
retinaface.device = context.torch_device
3434

3535
return GFPGANer(
36-
device=torch.device(context.device),
36+
device=context.torch_device,
3737
model_path=model_path,
3838
upscale=1,
3939
arch="clean",

0 commit comments

Comments
 (0)