From ae1b33caeee2a9e2b78d391ef136622a09ff87df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erce=20G=C3=BCder?= Date: Fri, 20 Oct 2023 14:09:50 +0300 Subject: [PATCH] add device argument to Processor class --- src/controlnet_aux/processor.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/controlnet_aux/processor.py b/src/controlnet_aux/processor.py index 12cb6b0..e3f35d2 100644 --- a/src/controlnet_aux/processor.py +++ b/src/controlnet_aux/processor.py @@ -82,7 +82,7 @@ class Processor: - def __init__(self, processor_id: str, params: Optional[Dict] = None) -> None: + def __init__(self, processor_id: str, params: Optional[Dict] = None, device: str = 'cpu') -> None: """Processor that can be used to process images with controlnet aux processors Args: @@ -97,7 +97,7 @@ def __init__(self, processor_id: str, params: Optional[Dict] = None) -> None: raise ValueError(f"{processor_id} is not a valid processor id. Please make sure to choose one of {', '.join(MODELS.keys())}") self.processor_id = processor_id - self.processor = self.load_processor(self.processor_id) + self.processor = self.load_processor(self.processor_id, device) # load default params self.params = MODEL_PARAMS[self.processor_id] @@ -105,7 +105,7 @@ def __init__(self, processor_id: str, params: Optional[Dict] = None) -> None: if params: self.params.update(params) - def load_processor(self, processor_id: str) -> 'Processor': + def load_processor(self, processor_id: str, device: str = 'cpu') -> 'Processor': """Load controlnet aux processors Args: @@ -119,6 +119,7 @@ def load_processor(self, processor_id: str) -> 'Processor': # check if the proecssor is a checkpoint model if MODELS[processor_id]['checkpoint']: processor = processor.from_pretrained("lllyasviel/Annotators") + processor.to(device) else: processor = processor() return processor