diff --git a/cellpose/contrib/directml.py b/cellpose/contrib/directml.py
new file mode 100644
index 00000000..6549953f
--- /dev/null
+++ b/cellpose/contrib/directml.py
@@ -0,0 +1,318 @@
+# Created by https://github.com/Teranis while working on https://github.com/SchmollerLab/Cell_ACDC
+# See below for working example
+
+# Limitations:
+# Officially supports only up to PyTorch 2.4.1 (should be fine with cellpose)
+# Not yet out for python 3.13
+# Probably not the fastest option, but works surprisingly fast and was easy to implement
+
+# Notes:
+# No additional drivers needed, but requires Windows 10/11 and a DirectX 12 compatible GPU
+# Install using "pip install torch-directml"
+
+# Links:
+# DirectML: https://microsoft.github.io/DirectML/
+# torch_directml: https://learn.microsoft.com/en-us/windows/ai/directml/pytorch-windows
+
+# Examples:
+# Entire working example with benchmark and save comparison is at the end of this file
+
+# Example usage:
+# from cellpose import models as models
+# model = models.CellposeModel(gpu=True)
+# out = model.eval(img)
+
+
+
+
+### This function has been made obsolete by updates to cellpose.models
+# def setup_custom_device(model, device):
+#     """
+#     Forces the model to use a custom device (e.g., DirectML) for inference.
+#     This is a workaround, and could be handled better in the future.
+#     (Ideally when all parameters are set initially)
+
+#     Args:
+#         model (cellpose.CellposeModel|cellpose.Cellpose): Cellpose model. Should work for v2, v3 and custom.
+#         torch.device (torch.device): Custom device.
+
+#     Returns:
+#         model (cellpose.CellposeModel|cellpose.Cellpose): Cellpose model with custom device set.
+#     """
+#     model.gpu = True
+#     model.device = device
+#     model.mkldnn = False
+#     if hasattr(model, 'net'):
+#         model.net.to(device)
+#         model.net.mkldnn = False
+#     if hasattr(model, 'cp'):
+#         model.cp.gpu = True
+#         model.cp.device = device
+#         model.cp.mkldnn = False
+#         if hasattr(model.cp, 'net'):
+#             model.cp.net.to(device)
+#             model.cp.net.mkldnn = False
+#     if hasattr(model, 'sz'):
+#         model.sz.device = device
+
+#     return model
+
+# def setup_directML(model):
+#     """
+#     Sets up the Cellpose model to use DirectML for inference.
+
+#     Args:
+#         model (cellpose.CellposeModel|cellpose.Cellpose): Cellpose model. Should work for v2, v3 and custom.
+
+#     Returns:
+#         model (cellpose.CellposeModel|cellpose.Cellpose): Cellpose model with DirectML set as the device.
+#     """
+#     print(
+#         'Using DirectML GPU for Cellpose model inference'
+#     )
+#     import torch_directml
+#     directml_device = torch_directml.device()
+#     model = setup_custom_device(model, directml_device)
+#     return model
+
+def fix_sparse_directML(verbose=True):
+    """Patch torch's sparse tensor constructors so they fall back to CPU.
+
+    DirectML does not support sparse tensors, so we need to fallback to CPU.
+    This function replaces `torch.sparse_coo_tensor`, `torch._C._sparse_coo_tensor_unsafe`,
+    `torch._C._sparse_coo_tensor_with_dims_and_tensors`, `torch.sparse.SparseTensor`
+    (the last three only when they exist) with a wrapper that falls back to CPU.
+
+    The patch is global: it rebinds the attributes on the torch modules. Constructors
+    that are already wrapped are left untouched, so calling this function repeatedly
+    does not stack wrappers (the `verbose` value of the first call then stays in effect).
+
+    In the end, this could be handled better in the future. It would probably run faster if we
+    just manually set the device to CPU, but my goal was to not modify the code too much,
+    and this runs surprisingly fast.
+
+    Args:
+        verbose (bool): If True, call `warnings.warn` whenever a call falls back to CPU.
+    """
+    import torch
+    import functools
+    import warnings
+
+    def fallback_to_cpu_on_sparse_error(func, verbose=True):
+        # fix_sparse_directML() can run more than once (e.g. once per model):
+        # hand back a constructor that is already wrapped instead of wrapping it again.
+        if getattr(func, "_directml_cpu_fallback", False):
+            return func
+
+        @functools.wraps(func) # wrapper shenanigans (thanks chatgpt)
+        def wrapper(*args, **kwargs):
+            device_arg = kwargs.get('device', None) # get desired device from kwargs
+
+            # Ensure indices are int64 if args[0] looks like indices,
+            # If errors start to occur that int64 conversion is needed, uncomment this
+            # (and also consider the block below).
+            # But be aware! It's probably better to just set the device to cpu in that
+            # particular case...
+            # for both performance and compatibility
+            # if len(args) >= 1 and isinstance(args[0], torch.Tensor):
+            #     if args[0].dtype != torch.int64:
+            #         args = (args[0].to(dtype=torch.int64),) + args[1:]
+
+            try: # try to perform the operation and move to dml if possible
+                result = func(*args, **kwargs) # run function with current args and kwargs
+                if device_arg is not None and str(device_arg).lower() == "dml":
+                    try: # try to move result to dml
+                        result.to("dml") # NOTE(review): the moved tensor is discarded, this only probes whether the move raises
+                    except RuntimeError as e: # moving failed, falling back to cpu
+                        if verbose:
+                            warnings.warn(f"Sparse op failed on DirectML, falling back to CPU: {e}")
+                        kwargs['device'] = torch.device("cpu")
+                        return func(*args, **kwargs) # try again, after setting device to cpu
+                return result # just return result if all worked well
+
+            except RuntimeError as e: # try and run on dml, if it fails, fallback to cpu
+                if "sparse" in str(e).lower() or "not implemented" in str(e).lower():
+                    if verbose:
+                        warnings.warn(f"Sparse op failed on DirectML, falling back to CPU: {e}")
+                    kwargs['device'] = torch.device("cpu") # if runtime error caused by sparse tensor, set device to cpu
+
+                    # See above comments
+                    # if len(args) >= 1 and isinstance(args[0], torch.Tensor):
+                    #     if args[0].dtype != torch.int64:
+                    #         args = (args[0].to(dtype=torch.int64),) + args[1:]
+                    try:
+                        return func(*args, **kwargs) # run function again with cpu device
+                    except RuntimeError as retry_err:
+                        if "int64" not in str(retry_err).lower():
+                            raise # unrelated failure on the CPU retry: propagate it unchanged
+                        if verbose:
+                            warnings.warn(f"need to convert to int64: {retry_err}")
+                        if len(args) >= 1 and isinstance(args[0], torch.Tensor):
+                            if args[0].dtype != torch.int64:
+                                args = (args[0].to(dtype=torch.int64),) + args[1:]
+                        return func(*args, **kwargs)
+                else:
+                    raise # re-raise any other runtime errors
+
+        wrapper._directml_cpu_fallback = True # marker read by the guard at the top of this helper
+        return wrapper
+
+    # --- Patch Sparse Tensor Constructors ---
+
+    # High-level API
+    torch.sparse_coo_tensor = fallback_to_cpu_on_sparse_error(torch.sparse_coo_tensor, verbose=verbose)
+
+    # Low-level API
+    if hasattr(torch._C, "_sparse_coo_tensor_unsafe"):
+        torch._C._sparse_coo_tensor_unsafe = fallback_to_cpu_on_sparse_error(torch._C._sparse_coo_tensor_unsafe, verbose=verbose)
+
+    if hasattr(torch._C, "_sparse_coo_tensor_with_dims_and_tensors"):
+        torch._C._sparse_coo_tensor_with_dims_and_tensors = fallback_to_cpu_on_sparse_error(
+            torch._C._sparse_coo_tensor_with_dims_and_tensors, verbose=verbose
+        )
+
+    if hasattr(torch.sparse, 'SparseTensor'):
+        torch.sparse.SparseTensor = fallback_to_cpu_on_sparse_error(torch.sparse.SparseTensor, verbose=verbose)
+
+    # suppress warnings
+    if not verbose:
+        import warnings
+        warnings.filterwarnings("once", message="Sparse op failed on DirectML*")
+
+if __name__ == "__main__":
+    import time
+    import numpy as np
+    import tifffile
+    import os
+
+    ### Working example with benchmark and save comparison
+    def _load_data(path, prepare):
+        """
+        Load and prepare data for Cellpose model.
+        Args:
+            path (str): Path to the image data.
+            prepare (bool): Whether to prepare the data for Cellpose model.
+        Returns:
+            imgs_list (list): List of images prepared for Cellpose model.
+        """
+
+        # load data
+        imgs = tifffile.imread(path) # read images using tifffile
+        print(imgs.shape)
+        if prepare:
+            imgs_list = []
+            for img in imgs: # convert to list of images
+                img_min = img.min()
+                img_max = img.max()
+                img = img.astype(np.float32)
+                img -= img_min
+                if img_max > img_min + 1e-3:
+                    img /= (img_max - img_min)
+                img *= 255
+
+                img = img.astype(np.float32)
+                imgs_list.append(img) # add image to list
+
+            return imgs_list
+        else:
+            return imgs
+
+    def _compare_data(savepaths):
+        """
+        Compare data from different save paths to check for consistency.
+        Args:
+            savepaths (list): List of paths to the saved data.
+ """ + outs = dict() + for savepath in savepaths: + if not os.path.exists(savepath): + continue + out = np.load(savepath) + out = out[out.files[0]] + outs[savepath] = out + + total_size = out.shape[1] * out.shape[2] + last_out = None + for savepath, out in outs.items(): + file_name = os.path.basename(savepath) + mismatch = False + if last_out is None: + last_out = out + last_file_name = file_name + continue + if out.shape != last_out.shape: + print(f"Shape mismatch for {file_name} vs {last_file_name}: {out.shape} vs {last_out.shape}") + continue + + for frame in range(out.shape[0]): + seg_difference = np.nonzero(out[frame] - last_out[frame]) + perc_diff = len(seg_difference[0]) / total_size + if perc_diff > 0.01: + print(f"Frame {frame} mismatch for {file_name} vs {last_file_name} with {perc_diff:.2%} difference") + mismatch = True + + if not mismatch: + print(f"All frames match for {file_name} vs {last_file_name}") + + + # you need two environment for benchmarking: One with DirectML and one with CUDA. + path = r'path\to\your\data.tif' # path to your data + # pretrained_model = r'path\to\your\model' # path to your pretrained model + pretrained_model = "cpsam" # "cyto3" # for pretrained models + gpu = True # set to True if you want to use GPU + # if False, CPU will be used + just_compare_data = False # set to True if you want to compare data and exit + + # load and prepare images + imgs = _load_data(path, prepare=True) + imgs = imgs[:10] # cut data so we can test it faster + + # save paths for different methods (Don't change order!) 
+ savepaths = [ + path.replace('.tif', '_segm_directml.npz'), + path.replace('.tif', '_segm_GPU.npz'), + path.replace('.tif', '_segm_CPU.npz') + ] + + # for data comparison + if just_compare_data: + _compare_data(savepaths) + exit() + + # init model + from cellpose import models, io + io.logger_setup() + model = models.CellposeModel( + pretrained_model=pretrained_model, gpu=gpu, + ) + + # run model, benchmark + print("Running model...") + start = time.perf_counter() + pref_count_last = time.perf_counter() + times = [] + out_list = [] + for img in imgs: # process each image + out = model.eval(img)[0] # here goes the eval + out_list.append(out) + time_taken = time.perf_counter() - pref_count_last + times.append(time_taken) + print(f'processed image in {time_taken:.2f} seconds') + pref_count_last = time.perf_counter() + end = time.perf_counter() + print(f"Time taken: {end - start:.2f} seconds") + print(f"Average time per image: {np.mean(times):.2f} seconds") + + uses_directml = model.device.type == 'privateuseone' + # save data + if uses_directml: + print("DirectML inference completed.") + savepath = savepaths[0] + elif gpu: + print("GPU inference completed.") + savepath = savepaths[1] + else: + print("CPU inference completed.") + savepath = savepaths[2] + + np.savez_compressed(savepath, out_list=out_list) \ No newline at end of file diff --git a/cellpose/core.py b/cellpose/core.py index 0505f464..ed58949e 100644 --- a/cellpose/core.py +++ b/cellpose/core.py @@ -36,7 +36,7 @@ def use_gpu(gpu_number=0, use_torch=True): def _use_gpu_torch(gpu_number=0): """ - Checks if CUDA or MPS is available and working with PyTorch. + Checks if CUDA or MPS or DirectML is available and working with PyTorch. Args: gpu_number (int): The GPU device number to use (default is 0). @@ -57,13 +57,22 @@ def _use_gpu_torch(gpu_number=0): core_logger.info('** TORCH MPS version installed and working. 
**') return True except: - core_logger.info('Neither TORCH CUDA nor MPS version not installed/working.') + pass + + try: + import torch_directml + device = torch_directml.device() + _ = torch.zeros((1,1)).to(device) + core_logger.info('** TORCH DIRECTML version installed and working. **') + return True + except: + core_logger.info('Neither TORCH CUDA, MPS nor DirectML version installed/working.') return False def assign_device(use_torch=True, gpu=False, device=0): """ - Assigns the device (CPU or GPU or mps) to be used for computation. + Assigns the device (CPU or GPU or mps or DirectML) to be used for computation. Args: use_torch (bool, optional): Whether to use torch for GPU detection. Defaults to True. @@ -78,6 +87,7 @@ def assign_device(use_torch=True, gpu=False, device=0): if device != "mps" or not(gpu and torch.backends.mps.is_available()): device = int(device) if gpu and use_gpu(use_torch=True): + gpu = False try: if torch.cuda.is_available(): device = torch.device(f'cuda:{device}') @@ -96,6 +106,18 @@ def assign_device(use_torch=True, gpu=False, device=0): except: gpu = False cpu = True + + if not gpu: # dont overwrite device if already set + try: + import torch_directml + if torch_directml.is_available(): + device = torch_directml.device(device) + core_logger.info(">>>> using GPU (DirectML)") + gpu = True + cpu = False + except: + gpu = False + cpu = True else: device = torch.device('cpu') core_logger.info('>>>> using CPU') diff --git a/cellpose/dynamics.py b/cellpose/dynamics.py index 39fad7dd..2fbddeb2 100644 --- a/cellpose/dynamics.py +++ b/cellpose/dynamics.py @@ -415,6 +415,9 @@ def remove_bad_flow_masks(masks, flows, threshold=0.4, device=torch.device("cpu" masks (int, 2D or 3D array): Masks with inconsistent flow masks removed, 0=NO masks; 1,2,...=mask labels, size [Ly x Lx] or [Lz x Ly x Lx]. 
""" + if device.type == "privateuseone": + device=torch.device("cpu") + device0 = device if masks.size > 10000 * 10000 and (device is not None and device.type == "cuda"): diff --git a/cellpose/models.py b/cellpose/models.py index 701d6380..29b7779a 100644 --- a/cellpose/models.py +++ b/cellpose/models.py @@ -94,7 +94,7 @@ def __init__(self, gpu=False, pretrained_model="cpsam", model_type=None, Initialize the CellposeModel. Parameters: - gpu (bool, optional): Whether or not to save model to GPU, will check if GPU available. + gpu (bool, optional): Whether or not to save model to GPU, will check if GPU available, first for CUDA, then for MPS, then for DirectML. pretrained_model (str or list of strings, optional): Full path to pretrained cellpose model(s), if None or False, no model loaded. model_type (str, optional): Any model that is available in the GUI, use name in GUI e.g. "livecell" (can be user-trained or model zoo). diam_mean (float, optional): Mean "diameter", 30. is built-in value for "cyto" model; 17. is built-in value for "nuclei" model; if saved in custom model file (cellpose>=2.0) then it will be loaded automatically and overwrite this value. 
@@ -119,9 +119,20 @@ def __init__(self, gpu=False, pretrained_model="cpsam", model_type=None, elif torch.backends.mps.is_available(): device_gpu = self.device.type == "mps" else: - device_gpu = False + try: + import torch_directml + if torch_directml.is_available(): + device_gpu = self.device.type == "privateuseone" + else: + device_gpu = False + except ImportError: + device_gpu = False self.gpu = device_gpu + if self.device.type == "privateuseone": # fix spare tensors for DirectML + from .contrib.directml import fix_sparse_directML + fix_sparse_directML() + if pretrained_model is None: raise ValueError("Must specify a pretrained model, training from scratch is not implemented") @@ -143,14 +154,21 @@ def __init__(self, gpu=False, pretrained_model="cpsam", model_type=None, dtype = torch.bfloat16 if use_bfloat16 else torch.float32 self.net = Transformer(dtype=dtype).to(self.device) + load_device = self.device + if self.device.type == "privateuseone": # for some reason, loading on privateuseone device does not work + load_device = torch.device("cpu") + if os.path.exists(self.pretrained_model): models_logger.info(f">>>> loading model {self.pretrained_model}") - self.net.load_model(self.pretrained_model, device=self.device) + self.net.load_model(self.pretrained_model, device=load_device) else: if os.path.split(self.pretrained_model)[-1] != 'cpsam': raise FileNotFoundError('model file not recognized') cache_CPSAM_model_path() - self.net.load_model(self.pretrained_model, device=self.device) + self.net.load_model(self.pretrained_model, device=load_device) + + if self.device.type == "privateuseone": + self.net.to(self.device) def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, diff --git a/docs/installation.rst b/docs/installation.rst index 066b09ee..7412ecb6 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -56,12 +56,34 @@ Be warned that the ROCm project is significantly less mature than CUDA, and you ROCm is significantly less 
mature than the CUDA acceleration, and you may run into issues. +DirectML installation +~~~~~~~~~~~~~~~~~~~~~ +DirectML is a cross-vendor, high-performance, hardware-accelerated machine learning library for Windows. +It uses the DirectX 12 API to provide a common interface for GPU-accelerated machine learning across +different hardware vendors (AMD, Intel, NVIDIA), although for NVIDIA GPUs, CUDA is still recommended. + +To install DirectML, first make sure you have the latest version of Windows 10 or 11 installed, and +update your GPU drivers. Also, ``torch-directml`` only supports Python 3.12 or earlier. +Then, install ``torch-directml`` using ``pip install torch-directml``. + +Code example: +:: + + from cellpose import models as models + model = models.CellposeModel(gpu=True) + mask = model.eval(image)[0] + + + +For more information on DirectML, see the `DirectML documentation <https://microsoft.github.io/DirectML/>`_. +Please contact `Teranis <https://github.com/Teranis>`_ for any issues with the DirectML implementation. + Common issues ~~~~~~~~~~~~~~~~~~~~~~~ If you receive an issue with Qt "xcb", you may need to install xcb libraries, e.g.: -:: +:: sudo apt install libxcb-cursor0 sudo apt install libxcb-xinerama0 @@ -91,7 +113,7 @@ If you are having other issues with the graphical interface and QT, see some adv If you have errors related to OpenMP and libiomp5, then try :: - + conda install nomkl If you receive an error associated with **matplotlib**, try upgrading diff --git a/tests/contrib/test_directml.py b/tests/contrib/test_directml.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_core.py b/tests/test_core.py new file mode 100644 index 00000000..80cb8582 --- /dev/null +++ b/tests/test_core.py @@ -0,0 +1,83 @@ +from cellpose.core import assign_device +import pytest +import torch +from unittest.mock import patch + + +is_cuda_available = torch.cuda.is_available() + +is_directml_available = False +try: + import torch_directml + is_directml_available = torch_directml.is_available() +except ImportError: + pass +
+is_mps_available = torch.backends.mps.is_available()
+
+
+
+@pytest.mark.parametrize(
+    "gpu,disable_cuda,disable_mps",
+    [
+        (True, False, False),
+        (True, True, False),
+        (True, False, True),
+        (True, True, True),
+        (False, False, False),
+        (False, True, False),
+        (False, False, True),
+        (False, True, True),
+    ]
+)
+def test_assign_device(gpu, disable_cuda, disable_mps):
+    if disable_cuda and disable_mps: # only DirectML (or CPU) can be picked
+        with patch('torch.cuda.is_available', return_value=False), \
+             patch('torch.backends.mps.is_available', return_value=False):
+            assigned_device, gpu = assign_device(gpu=gpu)
+
+        if is_directml_available:
+            expected_device = torch_directml.device()
+        else:
+            expected_device = torch.device('cpu')
+
+    elif disable_cuda and not disable_mps:
+        with patch('torch.cuda.is_available', return_value=False):
+            assigned_device, gpu = assign_device(gpu=gpu)
+
+        if is_mps_available:
+            expected_device = torch.device('mps')
+        elif is_directml_available:
+            expected_device = torch_directml.device()
+        else:
+            expected_device = torch.device('cpu')
+
+
+    elif not disable_cuda and disable_mps:
+        with patch('torch.backends.mps.is_available', return_value=False):
+            assigned_device, gpu = assign_device(gpu=gpu)
+
+        if is_cuda_available:
+            expected_device = torch.device('cuda:0') # assign_device builds f'cuda:{device}' with device=0; index-less 'cuda' would not compare equal
+        elif is_directml_available:
+            expected_device = torch_directml.device()
+        else:
+            expected_device = torch.device('cpu')
+
+    elif not disable_cuda and not disable_mps:
+        assigned_device, gpu = assign_device(gpu=gpu)
+
+        if is_cuda_available:
+            expected_device = torch.device('cuda:0') # assign_device builds f'cuda:{device}' with device=0; index-less 'cuda' would not compare equal
+        elif is_mps_available:
+            expected_device = torch.device('mps')
+        elif is_directml_available:
+            expected_device = torch_directml.device()
+        else:
+            expected_device = torch.device('cpu')
+
+    if not gpu: # gpu is the flag returned by assign_device: no accelerator was selected
+        expected_device = torch.device('cpu')
+
+    assert assigned_device == expected_device
+    
\ No newline at end of file