
Commit 7855ac5

[tests] make tests device-agnostic (part 4) (#10508)

Authored by faaany; co-authored by hlky. 1 parent: 30cef6b.

Squashed commit history:

* initial commit
* fix empty cache
* fix one more
* fix style
* update device functions
* update
* update
* apply review suggestions from hlky to src/diffusers/utils/testing_utils.py (five updates) and tests/pipelines/controlnet/test_controlnet.py (two updates)
* with gc.collect
* update
* make style
* check_torch_dependencies
* add mps empty cache
* add changes
* bug fix
* enable on xpu
* update more cases
* revert
* revert back
* Update test_stable_diffusion_xl.py
* apply review suggestions from hlky to tests/pipelines/stable_diffusion/test_stable_diffusion.py (two updates) and test_stable_diffusion_img2img.py (three updates)
* Apply suggestions from code review (co-authored by hlky)
* add test marker

Co-authored-by: hlky <[email protected]>
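Every file in this commit follows the same recipe: CUDA-specific calls are swapped for device-agnostic equivalents from diffusers.utils.testing_utils, so torch.cuda.empty_cache() becomes backend_empty_cache(torch_device), @require_torch_gpu becomes @require_torch_accelerator, and hard-coded "cuda" strings become torch_device. The result is that the same tests run unchanged on CUDA, XPU, or MPS hardware. Below is a minimal sketch of the dispatch idea behind a helper like backend_empty_cache; the names mirror the imports in the diffs that follow, but the body is illustrative, not the exact diffusers source.

import torch

# Per-backend callables; None means "nothing to free" for that backend.
BACKEND_EMPTY_CACHE = {
    "cuda": torch.cuda.empty_cache,
    "cpu": None,
    "default": None,
}
# These submodules are only usable on builds/machines that ship them, so guard the lookups.
if hasattr(torch, "mps"):
    BACKEND_EMPTY_CACHE["mps"] = torch.mps.empty_cache
if hasattr(torch, "xpu"):
    BACKEND_EMPTY_CACHE["xpu"] = torch.xpu.empty_cache

def backend_empty_cache(device: str) -> None:
    """Free the allocator cache for whichever backend `device` names; no-op otherwise."""
    fn = BACKEND_EMPTY_CACHE.get(device, BACKEND_EMPTY_CACHE["default"])
    if fn is not None:
        fn()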

File tree

66 files changed: +626, -498 lines


tests/lora/test_lora_layers_sd.py (+10 -9)

@@ -33,11 +33,12 @@
 )
 from diffusers.utils.import_utils import is_accelerate_available
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     load_image,
     nightly,
     numpy_cosine_similarity_distance,
     require_peft_backend,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
@@ -101,7 +102,7 @@ def tearDown(self):
     # Keeping this test here makes sense because it doesn't look any integration
     # (value assertions on logits).
     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_integration_move_lora_cpu(self):
         path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
         lora_id = "takuma104/lora-test-text-encoder-lora-target"
@@ -158,7 +159,7 @@ def test_integration_move_lora_cpu(self):
         self.assertTrue(m.weight.device != torch.device("cpu"))

     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_integration_move_lora_dora_cpu(self):
         from peft import LoraConfig

@@ -209,18 +210,18 @@ def test_integration_move_lora_dora_cpu(self):

 @slow
 @nightly
-@require_torch_gpu
+@require_torch_accelerator
 @require_peft_backend
 class LoraIntegrationTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_integration_logits_with_scale(self):
         path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
@@ -378,7 +379,7 @@ def test_a1111_with_model_cpu_offload(self):
         generator = torch.Generator().manual_seed(0)

         pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None)
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         lora_model_id = "hf-internal-testing/civitai-light-shadow-lora"
         lora_filename = "light_and_shadow.safetensors"
         pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
@@ -400,7 +401,7 @@ def test_a1111_with_sequential_cpu_offload(self):
         generator = torch.Generator().manual_seed(0)

         pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None)
-        pipe.enable_sequential_cpu_offload()
+        pipe.enable_sequential_cpu_offload(device=torch_device)
         lora_model_id = "hf-internal-testing/civitai-light-shadow-lora"
         lora_filename = "light_and_shadow.safetensors"
         pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
@@ -656,7 +657,7 @@ def test_sd_load_civitai_empty_network_alpha(self):
         See: https://github.com/huggingface/diffusers/issues/5606
         """
         pipeline = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
-        pipeline.enable_sequential_cpu_offload()
+        pipeline.enable_sequential_cpu_offload(device=torch_device)
         civitai_path = hf_hub_download("ybelkada/test-ahi-civitai", "ahi_lora_weights.safetensors")
         pipeline.load_lora_weights(civitai_path, adapter_name="ahri")
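require_torch_accelerator generalizes require_torch_gpu: rather than demanding CUDA specifically, it skips the test unless torch_device names some hardware accelerator. A hypothetical sketch of such a decorator (the real helper ships in diffusers.utils.testing_utils and may differ in detail):

import unittest

from diffusers.utils.testing_utils import torch_device  # e.g. "cuda", "xpu", "mps", or "cpu"

def require_torch_accelerator(test_case):
    """Skip the test unless the selected torch_device is a hardware accelerator."""
    return unittest.skipUnless(
        torch_device != "cpu", "test requires a hardware accelerator (CUDA, XPU, MPS, ...)"
    )(test_case)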

tests/lora/test_lora_layers_sd3.py (+6 -5)

@@ -30,12 +30,13 @@
 from diffusers.utils import load_image
 from diffusers.utils.import_utils import is_accelerate_available
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     is_flaky,
     nightly,
     numpy_cosine_similarity_distance,
     require_big_gpu_with_torch_cuda,
     require_peft_backend,
-    require_torch_gpu,
+    require_torch_accelerator,
     torch_device,
 )

@@ -93,7 +94,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     def output_shape(self):
         return (1, 32, 32, 3)

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_sd3_lora(self):
         """
         Test loading the loras that are saved with the diffusers and peft formats.
@@ -135,7 +136,7 @@ def test_multiple_wrong_adapter_name_raises_error(self):


 @nightly
-@require_torch_gpu
+@require_torch_accelerator
 @require_peft_backend
 @require_big_gpu_with_torch_cuda
 @pytest.mark.big_gpu_with_torch_cuda
@@ -146,12 +147,12 @@ class SD3LoraIntegrationTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_inputs(self, device, seed=0):
         init_image = load_image(
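This setUp/tearDown pair recurs across the integration suites touched by this commit. Factored into a base class once, the pattern looks like the sketch below; the imports are the real ones from the diffs above, while the class name is ours for illustration.

import gc
import unittest

from diffusers.utils.testing_utils import backend_empty_cache, torch_device

class DeviceAgnosticTestCase(unittest.TestCase):
    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)  # frees CUDA, XPU, or MPS caches alike

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)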

tests/models/unets/test_models_unet_2d_condition.py (+29 -26)

@@ -36,6 +36,9 @@
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
+    backend_max_memory_allocated,
+    backend_reset_max_memory_allocated,
+    backend_reset_peak_memory_stats,
     enable_full_determinism,
     floats_tensor,
     is_peft_available,
@@ -1002,7 +1005,7 @@ def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_from_hub_local(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
@@ -1013,7 +1016,7 @@ def test_load_sharded_checkpoint_from_hub_local(self):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
@@ -1024,7 +1027,7 @@ def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)

-    @require_torch_gpu
+    @require_torch_accelerator
     @parameterized.expand(
         [
             ("hf-internal-testing/unet2d-sharded-dummy", None),
@@ -1039,7 +1042,7 @@ def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)

-    @require_torch_gpu
+    @require_torch_accelerator
     @parameterized.expand(
         [
             ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
@@ -1054,7 +1057,7 @@ def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_device_map_from_hub_local(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
@@ -1064,7 +1067,7 @@ def test_load_sharded_checkpoint_device_map_from_hub_local(self):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
@@ -1164,11 +1167,11 @@ def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):

         return model

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_set_attention_slice_auto(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_max_memory_allocated(torch_device)
+        backend_reset_peak_memory_stats(torch_device)

         unet = self.get_unet_model()
         unet.set_attention_slice("auto")
@@ -1180,15 +1183,15 @@ def test_set_attention_slice_auto(self):
         with torch.no_grad():
             _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)

         assert mem_bytes < 5 * 10**9

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_set_attention_slice_max(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_max_memory_allocated(torch_device)
+        backend_reset_peak_memory_stats(torch_device)

         unet = self.get_unet_model()
         unet.set_attention_slice("max")
@@ -1200,15 +1203,15 @@ def test_set_attention_slice_max(self):
         with torch.no_grad():
             _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)

         assert mem_bytes < 5 * 10**9

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_set_attention_slice_int(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_max_memory_allocated(torch_device)
+        backend_reset_peak_memory_stats(torch_device)

         unet = self.get_unet_model()
         unet.set_attention_slice(2)
@@ -1220,15 +1223,15 @@ def test_set_attention_slice_int(self):
         with torch.no_grad():
             _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)

         assert mem_bytes < 5 * 10**9

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_set_attention_slice_list(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_max_memory_allocated(torch_device)
+        backend_reset_peak_memory_stats(torch_device)

         # there are 32 sliceable layers
         slice_list = 16 * [2, 3]
@@ -1242,7 +1245,7 @@ def test_set_attention_slice_list(self):
         with torch.no_grad():
             _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)

         assert mem_bytes < 5 * 10**9
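The backend_reset_* and backend_max_memory_allocated helpers extend the same dispatch idea to allocator statistics, which is what lets the 5 GB peak-memory assertions above run on non-CUDA hardware. A hedged sketch of what such wrappers can look like (CUDA and XPU shown; the exact diffusers implementation may differ, and the XPU branch assumes a PyTorch build that exposes these stats):

import torch

def backend_max_memory_allocated(device: str) -> int:
    # Peak bytes allocated on the given backend; 0 where stats are unavailable.
    if device == "cuda":
        return torch.cuda.max_memory_allocated()
    if device == "xpu" and hasattr(torch, "xpu"):
        return torch.xpu.max_memory_allocated()
    return 0

def backend_reset_peak_memory_stats(device: str) -> None:
    # Reset the peak markers so the next measurement starts from zero.
    if device == "cuda":
        torch.cuda.reset_peak_memory_stats()
    elif device == "xpu" and hasattr(torch, "xpu"):
        torch.xpu.reset_peak_memory_stats()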

tests/pipelines/controlnet/test_controlnet.py (+1 -1)

@@ -79,7 +79,7 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
     pipe = StableDiffusionControlNetPipeline.from_pretrained(
         "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
     )
-    pipe.to("cuda")
+    pipe.to(torch_device)
     pipe.set_progress_bar_config(disable=None)

     pipe.unet.to(memory_format=torch.channels_last)
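pipe.to(torch_device) works because torch_device is resolved once, at import time, to whatever accelerator the machine actually has. A sketch of how such a value might be chosen (illustrative; the diffusers helper can also honor an environment-variable override):

import torch

if torch.cuda.is_available():
    torch_device = "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    torch_device = "xpu"
elif torch.backends.mps.is_available():
    torch_device = "mps"
else:
    torch_device = "cpu"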

tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py (+4 -4)

@@ -40,7 +40,7 @@
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     floats_tensor,
-    require_torch_gpu,
+    require_torch_accelerator,
     torch_device,
 )

@@ -245,7 +245,7 @@ def test_xformers_attention_forwardGenerator_pass(self):
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=2e-3)

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_stable_diffusion_xl_offloads(self):
         pipes = []
         components = self.get_dummy_components()
@@ -254,12 +254,12 @@ def test_stable_diffusion_xl_offloads(self):

         components = self.get_dummy_components()
         sd_pipe = self.pipeline_class(**components)
-        sd_pipe.enable_model_cpu_offload()
+        sd_pipe.enable_model_cpu_offload(device=torch_device)
         pipes.append(sd_pipe)

         components = self.get_dummy_components()
         sd_pipe = self.pipeline_class(**components)
-        sd_pipe.enable_sequential_cpu_offload()
+        sd_pipe.enable_sequential_cpu_offload(device=torch_device)
         pipes.append(sd_pipe)

         image_slices = []
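Both offloading strategies accept the target device explicitly, which is what keeps these tests device-agnostic: weights rest on the CPU and move to torch_device only while needed. Used on a given pipeline one at a time, the calls look like this:

pipe.enable_model_cpu_offload(device=torch_device)       # moves whole sub-models on demand
# or, alternatively (slower but more memory-frugal):
pipe.enable_sequential_cpu_offload(device=torch_device)  # moves individual modules on demand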

tests/pipelines/controlnet/test_controlnet_sdxl.py (+2 -2)

@@ -223,12 +223,12 @@ def test_stable_diffusion_xl_offloads(self):

         components = self.get_dummy_components()
         sd_pipe = self.pipeline_class(**components)
-        sd_pipe.enable_model_cpu_offload()
+        sd_pipe.enable_model_cpu_offload(device=torch_device)
         pipes.append(sd_pipe)

         components = self.get_dummy_components()
         sd_pipe = self.pipeline_class(**components)
-        sd_pipe.enable_sequential_cpu_offload()
+        sd_pipe.enable_sequential_cpu_offload(device=torch_device)
         pipes.append(sd_pipe)

         image_slices = []

tests/pipelines/controlnet_flux/test_controlnet_flux.py (+3 -2)

@@ -31,6 +31,7 @@
 from diffusers.models import FluxControlNetModel
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     nightly,
     numpy_cosine_similarity_distance,
@@ -217,12 +218,12 @@ class FluxControlNetPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_canny(self):
         controlnet = FluxControlNetModel.from_pretrained(

tests/pipelines/controlnet_sd3/test_controlnet_sd3.py (+1 -1)

@@ -239,7 +239,7 @@ def test_canny(self):
         pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
             "stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)

tests/pipelines/flux/test_pipeline_flux.py (+3 -2)

@@ -9,6 +9,7 @@

 from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     nightly,
     numpy_cosine_similarity_distance,
     require_big_gpu_with_torch_cuda,
@@ -212,12 +213,12 @@ class FluxPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_inputs(self, device, seed=0):
         generator = torch.Generator(device="cpu").manual_seed(seed)
