diff --git a/src/controlnet_aux/processor.py b/src/controlnet_aux/processor.py index 12cb6b0..e3f35d2 100644 --- a/src/controlnet_aux/processor.py +++ b/src/controlnet_aux/processor.py @@ -82,7 +82,7 @@ class Processor: - def __init__(self, processor_id: str, params: Optional[Dict] = None) -> None: + def __init__(self, processor_id: str, params: Optional[Dict] = None, device: str = 'cpu') -> None: """Processor that can be used to process images with controlnet aux processors Args: @@ -97,7 +97,7 @@ def __init__(self, processor_id: str, params: Optional[Dict] = None) -> None: raise ValueError(f"{processor_id} is not a valid processor id. Please make sure to choose one of {', '.join(MODELS.keys())}") self.processor_id = processor_id - self.processor = self.load_processor(self.processor_id) + self.processor = self.load_processor(self.processor_id, device) # load default params self.params = MODEL_PARAMS[self.processor_id] @@ -105,7 +105,7 @@ def __init__(self, processor_id: str, params: Optional[Dict] = None) -> None: if params: self.params.update(params) - def load_processor(self, processor_id: str) -> 'Processor': + def load_processor(self, processor_id: str, device: str = 'cpu') -> 'Processor': """Load controlnet aux processors Args: @@ -119,6 +119,7 @@ def load_processor(self, processor_id: str) -> 'Processor': # check if the proecssor is a checkpoint model if MODELS[processor_id]['checkpoint']: processor = processor.from_pretrained("lllyasviel/Annotators") + processor.to(device) else: processor = processor() return processor