diff --git a/src/controlnet_aux/processor.py b/src/controlnet_aux/processor.py index 12cb6b0..38218a7 100644 --- a/src/controlnet_aux/processor.py +++ b/src/controlnet_aux/processor.py @@ -82,7 +82,7 @@ class Processor: - def __init__(self, processor_id: str, params: Optional[Dict] = None) -> None: + def __init__(self, processor_id: str, params: Optional[Dict] = None, checkpoints_dir: Optional[str] = None) -> None: """Processor that can be used to process images with controlnet aux processors Args: @@ -97,7 +97,7 @@ def __init__(self, processor_id: str, params: Optional[Dict] = None) -> None: raise ValueError(f"{processor_id} is not a valid processor id. Please make sure to choose one of {', '.join(MODELS.keys())}") self.processor_id = processor_id - self.processor = self.load_processor(self.processor_id) + self.processor = self.load_processor(self.processor_id, checkpoints_dir) # load default params self.params = MODEL_PARAMS[self.processor_id] @@ -105,7 +105,7 @@ def __init__(self, processor_id: str, params: Optional[Dict] = None) -> None: if params: self.params.update(params) - def load_processor(self, processor_id: str) -> 'Processor': + def load_processor(self, processor_id: str, checkpoints_dir: Optional[str] = None) -> 'Processor': """Load controlnet aux processors Args: @@ -118,7 +118,9 @@ def load_processor(self, processor_id: str) -> 'Processor': # check if the proecssor is a checkpoint model if MODELS[processor_id]['checkpoint']: - processor = processor.from_pretrained("lllyasviel/Annotators") + checkpoints_dir = "lllyasviel/Annotators" if checkpoints_dir is None else checkpoints_dir + LOGGER.info(f"Loading {processor_id} from {checkpoints_dir}") + processor = processor.from_pretrained(checkpoints_dir) else: processor = processor() return processor