refactor: only use LoRA activate on handover to async worker, extract method
mashb1t committed Mar 11, 2024
1 parent 532401d commit 57a0186
Showing 5 changed files with 15 additions and 17 deletions.
modules/async_worker.py (14 changes: 3 additions & 11 deletions)
@@ -46,8 +46,8 @@ def worker():
     from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion, apply_arrays
     from modules.private_logger import log
     from extras.expansion import safe_str
-    from modules.util import remove_empty_str, HWC3, resize_image, \
-        get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix
+    from modules.util import remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, \
+        get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras
     from modules.upscaler import perform_upscale
     from modules.flags import Performance
     from modules.meta_parser import get_metadata_parser, MetadataScheme
@@ -124,14 +124,6 @@ def build_image_wall(async_task):
         async_task.results = async_task.results + [wall]
         return
 
-    def apply_enabled_loras(loras):
-        enabled_loras = []
-        for lora_enabled, lora_model, lora_weight in loras:
-            if lora_enabled:
-                enabled_loras.append([lora_model, lora_weight])
-
-        return enabled_loras
-
     @torch.no_grad()
     @torch.inference_mode()
     def handler(async_task):
@@ -155,7 +147,7 @@ def handler(async_task):
         base_model_name = args.pop()
         refiner_model_name = args.pop()
         refiner_switch = args.pop()
-        loras = apply_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop()), ] for _ in range(modules.config.default_max_lora_number)])
+        loras = get_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop())] for _ in range(modules.config.default_max_lora_number)])
         input_image_checkbox = args.pop()
         current_tab = args.pop()
         uov_method = args.pop()
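Note on the handover pattern above: the handler consumes the flat Gradio argument list with sequential pop() calls, reading one [enabled, model, weight] triple per LoRA slot and filtering it through get_enabled_loras before the task runs. A minimal runnable sketch with hypothetical values in place of real UI state (assuming, as the sequential pops suggest, that the list was reversed beforehand so pop() yields arguments in declared order):

# Sketch only: placeholder values, not real Gradio state.
# Each comprehension pass reads enabled, then model, then weight, left to right.
args = [0.5, 'None', False, 0.25, 'example-lora.safetensors', True]  # pre-reversed
max_lora_number = 2  # stands in for modules.config.default_max_lora_number
raw_loras = [[bool(args.pop()), str(args.pop()), float(args.pop())] for _ in range(max_lora_number)]
print(raw_loras)  # [[True, 'example-lora.safetensors', 0.25], [False, 'None', 0.5]]
# raw_loras then goes through get_enabled_loras, added in modules/util.py below.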
modules/core.py (5 changes: 1 addition & 4 deletions)
@@ -73,10 +73,7 @@ def refresh_loras(self, loras):
 
         loras_to_load = []
 
-        for enabled, filename, weight in loras:
-            if not enabled:
-                continue
-
+        for filename, weight in loras:
             if filename == 'None':
                 continue
 
modules/default_pipeline.py (4 changes: 2 additions & 2 deletions)
@@ -11,7 +11,7 @@
 
 from ldm_patched.modules.model_base import SDXL, SDXLRefiner
 from modules.sample_hijack import clip_separate
-from modules.util import get_file_from_folder_list
+from modules.util import get_file_from_folder_list, get_enabled_loras
 
 
 model_base = core.StableDiffusionModel()
@@ -254,7 +254,7 @@ def refresh_everything(refiner_model_name, base_model_name, loras,
 refresh_everything(
     refiner_model_name=modules.config.default_refiner_model_name,
     base_model_name=modules.config.default_base_model_name,
-    loras=modules.config.default_loras
+    loras=get_enabled_loras(modules.config.default_loras)
 )
 
 
modules/util.py (4 changes: 4 additions & 0 deletions)
@@ -360,3 +360,7 @@ def makedirs_with_log(path):
         os.makedirs(path, exist_ok=True)
     except OSError as error:
         print(f'Directory {path} could not be created, reason: {error}')
+
+
+def get_enabled_loras(loras: list) -> list:
+    return [[lora[1], lora[2]] for lora in loras if lora[0]]
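The extracted helper keeps only the [model, weight] pairs whose enabled flag is set; 'None' placeholders still pass through and are skipped later by refresh_loras in modules/core.py. A quick usage sketch with made-up slot values:

# Copy of the new helper plus a hypothetical call illustrating the filtering.
def get_enabled_loras(loras: list) -> list:
    return [[lora[1], lora[2]] for lora in loras if lora[0]]

print(get_enabled_loras([
    [True, 'example-lora.safetensors', 0.8],
    [False, 'unused-lora.safetensors', 1.0],
    [True, 'None', 1.0],
]))
# -> [['example-lora.safetensors', 0.8], ['None', 1.0]]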
presets/lightning.json (5 changes: 5 additions & 0 deletions)
@@ -4,22 +4,27 @@
     "default_refiner_switch": 0.5,
     "default_loras": [
         [
+            true,
             "None",
             1.0
         ],
         [
+            true,
             "None",
             1.0
         ],
         [
+            true,
             "None",
             1.0
         ],
         [
+            true,
             "None",
             1.0
         ],
         [
+            true,
             "None",
             1.0
         ]
