Skip to content

Commit de930cb

Browse files
committed
Utility function for half precision bug detection
1 parent 6935ff6 commit de930cb

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

sdkit/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,5 +54,6 @@
5454
empty_cache,
5555
ipc_collect,
5656
is_cpu_device,
57+
has_half_precision_bug,
5758
)
5859
from .misc_utils import make_sd_context, get_nested_attr

sdkit/utils/device_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1+
import re
12
import platform
23
import subprocess
34
from typing import Union, Tuple, Dict
45

56

7+
NVIDIA_PATTERN = re.compile(r"\b(?:nvidia|geforce|quadro|tesla)\b", re.IGNORECASE)
8+
NVIDIA_HALF_PRECISION_BUG_PATTERN = re.compile(r"\b(?:tesla k40m|16\d\d|t\d{2,})\b", re.IGNORECASE)
9+
AMD_HALF_PRECISION_BUG_PATTERN = re.compile(r"\b(?:navi 1\d)\b", re.IGNORECASE)
10+
11+
612
def has_amd_gpu():
713
os_name = platform.system()
814
try:
@@ -146,3 +152,10 @@ def is_cpu_device(device) -> bool: # used for cpu offloading etc
146152
"Expects a torch.device as the argument"
147153

148154
return device.type in ("cpu", "mps")
155+
156+
157+
def has_half_precision_bug(device_name) -> bool:
158+
"Check whether the given device requires full precision for generating images due to a firmware bug"
159+
if NVIDIA_PATTERN.search(device_name):
160+
return NVIDIA_HALF_PRECISION_BUG_PATTERN.search(device_name) is not None
161+
return AMD_HALF_PRECISION_BUG_PATTERN.search(device_name) is not None

0 commit comments

Comments
 (0)