 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
+    backend_max_memory_allocated,
+    backend_reset_max_memory_allocated,
+    backend_reset_peak_memory_stats,
     enable_full_determinism,
     floats_tensor,
     is_peft_available,
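Note: the new imports pull backend-agnostic memory helpers from diffusers.utils.testing_utils so the tests below no longer hard-code torch.cuda calls. A minimal sketch of how such helpers might dispatch on the active device string follows; the *_sketch names are hypothetical and the real diffusers implementations may differ.

# Minimal sketch (not the actual diffusers implementation) of backend-agnostic
# memory helpers that dispatch on the device string used by the test suite.
import torch

def backend_empty_cache_sketch(device: str) -> None:
    # Release cached allocator blocks on whichever accelerator is active.
    if device == "cuda":
        torch.cuda.empty_cache()
    elif device == "xpu":
        torch.xpu.empty_cache()  # assumes a recent PyTorch build with XPU support
    elif device == "mps":
        torch.mps.empty_cache()

def backend_max_memory_allocated_sketch(device: str) -> int:
    # Peak memory allocated since the last reset, in bytes (0 for unsupported backends).
    if device == "cuda":
        return torch.cuda.max_memory_allocated()
    if device == "xpu":
        return torch.xpu.max_memory_allocated()  # assumes XPU support
    return 0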
@@ -1002,7 +1005,7 @@ def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_from_hub_local(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
@@ -1013,7 +1016,7 @@ def test_load_sharded_checkpoint_from_hub_local(self):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
@@ -1024,7 +1027,7 @@ def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)

-    @require_torch_gpu
+    @require_torch_accelerator
     @parameterized.expand(
         [
             ("hf-internal-testing/unet2d-sharded-dummy", None),
@@ -1039,7 +1042,7 @@ def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)

-    @require_torch_gpu
+    @require_torch_accelerator
     @parameterized.expand(
         [
             ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
@@ -1054,7 +1057,7 @@ def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, va
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_device_map_from_hub_local(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
@@ -1064,7 +1067,7 @@ def test_load_sharded_checkpoint_device_map_from_hub_local(self):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
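All of the decorator swaps above replace the CUDA-only @require_torch_gpu gate with @require_torch_accelerator so the sharded-checkpoint tests can run on any supported accelerator. A rough sketch of how such a gate could be built is below; the helper names are hypothetical and this is not the actual diffusers implementation.

# Sketch only: one way an accelerator gate like require_torch_accelerator could be
# written, skipping a test unless some non-CPU torch device is available.
import unittest
import torch

def get_torch_device_sketch() -> str:
    # Hypothetical helper: pick the first available accelerator, else fall back to CPU.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

def require_accelerator_sketch(test_case):
    # Decorate a test so it only runs when an accelerator backend is present.
    device = get_torch_device_sketch()
    return unittest.skipUnless(device != "cpu", "test requires an accelerator")(test_case)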
@@ -1164,11 +1167,11 @@ def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):

         return model

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_set_attention_slice_auto(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_max_memory_allocated(torch_device)
+        backend_reset_peak_memory_stats(torch_device)

         unet = self.get_unet_model()
         unet.set_attention_slice("auto")
@@ -1180,15 +1183,15 @@ def test_set_attention_slice_auto(self):
         with torch.no_grad():
             _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)

         assert mem_bytes < 5 * 10**9

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_set_attention_slice_max(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_max_memory_allocated(torch_device)
+        backend_reset_peak_memory_stats(torch_device)

         unet = self.get_unet_model()
         unet.set_attention_slice("max")
@@ -1200,15 +1203,15 @@ def test_set_attention_slice_max(self):
         with torch.no_grad():
             _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)

         assert mem_bytes < 5 * 10**9

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_set_attention_slice_int(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_max_memory_allocated(torch_device)
+        backend_reset_peak_memory_stats(torch_device)

         unet = self.get_unet_model()
         unet.set_attention_slice(2)
@@ -1220,15 +1223,15 @@ def test_set_attention_slice_int(self):
         with torch.no_grad():
             _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)

         assert mem_bytes < 5 * 10**9

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_set_attention_slice_list(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_max_memory_allocated(torch_device)
+        backend_reset_peak_memory_stats(torch_device)

         # there are 32 sliceable layers
         slice_list = 16 * [2, 3]
@@ -1242,7 +1245,7 @@ def test_set_attention_slice_list(self):
         with torch.no_grad():
             _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)

         assert mem_bytes < 5 * 10**9
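For reference, this is the measurement pattern the updated attention-slicing tests follow, written with the backend-agnostic helpers instead of torch.cuda calls. The workload below is a stand-in allocation rather than the real UNet forward pass, so the snippet stays self-contained; the import names are the ones added in the diff above.

# Illustrative only: reset peak-memory counters, run a workload, then check the peak
# against the same ~5 GB budget the tests assert.
import torch
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    backend_max_memory_allocated,
    backend_reset_max_memory_allocated,
    backend_reset_peak_memory_stats,
    torch_device,
)

backend_empty_cache(torch_device)
backend_reset_max_memory_allocated(torch_device)
backend_reset_peak_memory_stats(torch_device)

with torch.no_grad():
    # Stand-in workload; the actual tests run the UNet with sliced attention here.
    x = torch.randn(64, 4, 64, 64, device=torch_device)
    _ = x * 2

mem_bytes = backend_max_memory_allocated(torch_device)
assert mem_bytes < 5 * 10**9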