Commit

Merge branch 'develop' into feature/add-nsfw-filter
mashb1t authored May 18, 2024
2 parents 49795fe + 33fa175 commit 2d327bb
Showing 23 changed files with 237 additions and 81 deletions.
3 changes: 3 additions & 0 deletions args_manager.py
@@ -31,6 +31,9 @@
args_parser.parser.add_argument("--disable-preset-download", action='store_true',
help="Disables downloading models for presets", default=False)

args_parser.parser.add_argument("--enable-describe-uov-image", action='store_true',
help="Disables automatic description of uov images when prompt is empty", default=False)

args_parser.parser.add_argument("--always-download-new-model", action='store_true',
help="Always download newer models ", default=False)

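A minimal sketch of how the new `store_true` flag behaves (argparse maps dashes to underscores; the gating logic here is illustrative, not the PR's actual call site):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--enable-describe-uov-image", action='store_true',
                    help="Enables automatic description of uov images when prompt is empty",
                    default=False)
args = parser.parse_args(["--enable-describe-uov-image"])

prompt = ""
if args.enable_describe_uov_image and prompt == "":
    # Illustrative gating only; in Fooocus this would kick off the
    # describe pipeline for the upscale/variation input image.
    print("Would auto-describe the uov input image.")
```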
2 changes: 1 addition & 1 deletion css/style.css
@@ -391,6 +391,6 @@ progress::after {
background-color: #fff8;
font-family: monospace;
text-align: center;
-border-radius-top: 5px;
+border-radius: 5px 5px 0px 0px;
display: none; /* remove this to enable tooltip in preview image */
}
1 change: 1 addition & 0 deletions docker.md
@@ -54,6 +54,7 @@ Docker specified environments are there. They are used by 'entrypoint.sh'
|CMDARGS|Arguments for [entry_with_update.py](entry_with_update.py) which is called by [entrypoint.sh](entrypoint.sh)|
|config_path|'config.txt' location|
|config_example_path|'config_modification_tutorial.txt' location|
|HF_MIRROR|Huggingface mirror site domain|

You can also use the same JSON key names and values explained in 'config_modification_tutorial.txt' as environment variables.
See examples in [docker-compose.yml](docker-compose.yml).
92 changes: 54 additions & 38 deletions extras/vae_interpose.py
@@ -1,69 +1,85 @@
 # https://github.com/city96/SD-Latent-Interposer/blob/main/interposer.py

 import os
-import torch
+
 import safetensors.torch as sf
+import torch
 import torch.nn as nn
-import ldm_patched.modules.model_management
+
+import ldm_patched.modules.model_management
 from ldm_patched.modules.model_patcher import ModelPatcher
 from modules.config import path_vae_approx


-class Block(nn.Module):
-    def __init__(self, size):
+class ResBlock(nn.Module):
+    """Block with residuals"""
+
+    def __init__(self, ch):
         super().__init__()
         self.join = nn.ReLU()
+        self.norm = nn.BatchNorm2d(ch)
         self.long = nn.Sequential(
-            nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1),
-            nn.LeakyReLU(0.1),
-            nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1),
-            nn.LeakyReLU(0.1),
-            nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1),
+            nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1),
+            nn.SiLU(),
+            nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1),
+            nn.SiLU(),
+            nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1),
+            nn.Dropout(0.1)
         )

     def forward(self, x):
-        y = self.long(x)
-        z = self.join(y + x)
-        return z
+        x = self.norm(x)
+        return self.join(self.long(x) + x)


-class Interposer(nn.Module):
-    def __init__(self):
+class ExtractBlock(nn.Module):
+    """Increase no. of channels by [out/in]"""
+
+    def __init__(self, ch_in, ch_out):
         super().__init__()
-        self.chan = 4
-        self.hid = 128
-
-        self.head_join = nn.ReLU()
-        self.head_short = nn.Conv2d(self.chan, self.hid, kernel_size=3, stride=1, padding=1)
-        self.head_long = nn.Sequential(
-            nn.Conv2d(self.chan, self.hid, kernel_size=3, stride=1, padding=1),
-            nn.LeakyReLU(0.1),
-            nn.Conv2d(self.hid, self.hid, kernel_size=3, stride=1, padding=1),
-            nn.LeakyReLU(0.1),
-            nn.Conv2d(self.hid, self.hid, kernel_size=3, stride=1, padding=1),
+        self.join = nn.ReLU()
+        self.short = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
+        self.long = nn.Sequential(
+            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1),
+            nn.SiLU(),
+            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1),
+            nn.SiLU(),
+            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1),
+            nn.Dropout(0.1)
         )
+
+    def forward(self, x):
+        return self.join(self.long(x) + self.short(x))
+
+
+class InterposerModel(nn.Module):
+    """Main neural network"""
+
+    def __init__(self, ch_in=4, ch_out=4, ch_mid=64, scale=1.0, blocks=12):
+        super().__init__()
+        self.ch_in = ch_in
+        self.ch_out = ch_out
+        self.ch_mid = ch_mid
+        self.blocks = blocks
+        self.scale = scale
+
+        self.head = ExtractBlock(self.ch_in, self.ch_mid)
         self.core = nn.Sequential(
-            Block(self.hid),
-            Block(self.hid),
-            Block(self.hid),
-        )
-        self.tail = nn.Sequential(
-            nn.ReLU(),
-            nn.Conv2d(self.hid, self.chan, kernel_size=3, stride=1, padding=1)
+            nn.Upsample(scale_factor=self.scale, mode="nearest"),
+            *[ResBlock(self.ch_mid) for _ in range(blocks)],
+            nn.BatchNorm2d(self.ch_mid),
+            nn.SiLU(),
         )
+        self.tail = nn.Conv2d(self.ch_mid, self.ch_out, kernel_size=3, stride=1, padding=1)

     def forward(self, x):
-        y = self.head_join(
-            self.head_long(x) +
-            self.head_short(x)
-        )
+        y = self.head(x)
         z = self.core(y)
         return self.tail(z)


 vae_approx_model = None
-vae_approx_filename = os.path.join(path_vae_approx, 'xl-to-v1_interposer-v3.1.safetensors')
+vae_approx_filename = os.path.join(path_vae_approx, 'xl-to-v1_interposer-v4.0.safetensors')


 def parse(x):
@@ -72,7 +88,7 @@ def parse(x):
     x_origin = x.clone()

     if vae_approx_model is None:
-        model = Interposer()
+        model = InterposerModel()
         model.eval()
         sd = sf.load_file(vae_approx_filename)
         model.load_state_dict(sd)
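A quick shape check for the new `InterposerModel` (a sketch: it assumes the repo layout above and downloaded v4.0 weights are not required just to verify shapes; the dummy latent stands in for a real SDXL latent):

```python
import torch

from extras.vae_interpose import InterposerModel

model = InterposerModel(ch_in=4, ch_out=4, ch_mid=64, scale=1.0, blocks=12)
model.eval()

# A 128x128 latent corresponds to a 1024x1024 SDXL image (8x downscale).
latent = torch.randn(1, 4, 128, 128)
with torch.no_grad():
    out = model(latent)
print(out.shape)  # torch.Size([1, 4, 128, 128]); scale=1.0 keeps spatial size
```

Note the architectural shift: v3.1 summed three fixed `Block`s at 128 hidden channels, while v4.0 stacks twelve `ResBlock`s at 64 channels with BatchNorm, SiLU, and dropout.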
37 changes: 37 additions & 0 deletions javascript/script.js
@@ -122,6 +122,43 @@ document.addEventListener("DOMContentLoaded", function() {
initStylePreviewOverlay();
});

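/* Call f with the batch of newly added nodes whenever children are appended to elem. */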
var onAppend = function(elem, f) {
var observer = new MutationObserver(function(mutations) {
mutations.forEach(function(m) {
if (m.addedNodes.length) {
f(m.addedNodes);
}
});
});
observer.observe(elem, {childList: true});
};

function addObserverIfDesiredNodeAvailable(querySelector, callback) {
var elem = document.querySelector(querySelector);
if (!elem) {
window.setTimeout(() => addObserverIfDesiredNodeAvailable(querySelector, callback), 1000);
return;
}

onAppend(elem, callback);
}

/**
* Show reset button on toast "Connection errored out."
*/
addObserverIfDesiredNodeAvailable(".toast-wrap", function(added) {
added.forEach(function(element) {
if (element.innerText.includes("Connection errored out.")) {
window.setTimeout(function() {
document.getElementById("reset_button").classList.remove("hidden");
document.getElementById("generate_button").classList.add("hidden");
document.getElementById("skip_button").classList.add("hidden");
document.getElementById("stop_button").classList.add("hidden");
});
}
});
});

/**
* Add a ctrl+enter as a shortcut to start a generation
*/
4 changes: 4 additions & 0 deletions language/en.json
@@ -4,6 +4,7 @@
"Generate": "Generate",
"Skip": "Skip",
"Stop": "Stop",
"Reconnect": "Reconnect",
"Input Image": "Input Image",
"Advanced": "Advanced",
"Upscale or Variation": "Upscale or Variation",
@@ -59,6 +60,7 @@
"\ud83d\udcda History Log": "\uD83D\uDCDA History Log",
"Image Style": "Image Style",
"Fooocus V2": "Fooocus V2",
"Random Style": "Random Style",
"Default (Slightly Cinematic)": "Default (Slightly Cinematic)",
"Fooocus Masterpiece": "Fooocus Masterpiece",
"Fooocus Photograph": "Fooocus Photograph",
@@ -341,6 +343,8 @@
"sgm_uniform": "sgm_uniform",
"simple": "simple",
"ddim_uniform": "ddim_uniform",
"VAE": "VAE",
"Default (model)": "Default (model)",
"Forced Overwrite of Sampling Step": "Forced Overwrite of Sampling Step",
"Set as -1 to disable. For developer debugging.": "Set as -1 to disable. For developer debugging.",
"Forced Overwrite of Refiner Switch Step": "Forced Overwrite of Refiner Switch Step",
8 changes: 6 additions & 2 deletions launch.py
@@ -62,8 +62,8 @@ def prepare_environment():
vae_approx_filenames = [
('xlvaeapp.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/xlvaeapp.pth'),
('vaeapp_sd15.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/vaeapp_sd15.pt'),
-    ('xl-to-v1_interposer-v3.1.safetensors',
-     'https://huggingface.co/lllyasviel/misc/resolve/main/xl-to-v1_interposer-v3.1.safetensors')
+    ('xl-to-v1_interposer-v4.0.safetensors',
+     'https://huggingface.co/mashb1t/misc/resolve/main/xl-to-v1_interposer-v4.0.safetensors')
]


@@ -80,6 +80,10 @@ def ini_args():
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_device_id)
print("Set device to:", args.gpu_device_id)

if args.hf_mirror is not None:
os.environ['HF_MIRROR'] = str(args.hf_mirror)
print("Set hf_mirror to:", args.hf_mirror)

from modules import config

os.environ['GRADIO_TEMP_DIR'] = config.temp_path
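launch.py only exports `HF_MIRROR` into the environment; the substitution presumably happens at download time in a part of the diff not shown here. A hedged sketch of the usual pattern:

```python
import os

def mirror_url(url: str) -> str:
    # Assumption: the downloader swaps the huggingface.co domain for the
    # mirror; the PR's actual call site is in the files not loaded above.
    domain = os.environ.get('HF_MIRROR', 'https://huggingface.co').rstrip('/')
    return url.replace('https://huggingface.co', domain, 1)

print(mirror_url('https://huggingface.co/lllyasviel/misc/resolve/main/xlvaeapp.pth'))
```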
1 change: 1 addition & 0 deletions ldm_patched/modules/args_parser.py
@@ -37,6 +37,7 @@ def __call__(self, parser, namespace, values, option_string=None):
parser.add_argument("--port", type=int, default=8188)
parser.add_argument("--disable-header-check", type=str, default=None, metavar="ORIGIN", nargs="?", const="*")
parser.add_argument("--web-upload-size", type=float, default=100)
parser.add_argument("--hf-mirror", type=str, default=None)

parser.add_argument("--external-working-path", type=str, default=None, metavar="PATH", nargs='+', action='append')
parser.add_argument("--output-path", type=str, default=None)
13 changes: 9 additions & 4 deletions ldm_patched/modules/sd.py
@@ -427,12 +427,13 @@ class EmptyClass:

return (ldm_patched.modules.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)

-def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
+def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, vae_filename_param=None):
sd = ldm_patched.modules.utils.load_torch_file(ckpt_path)
sd_keys = sd.keys()
clip = None
clipvision = None
vae = None
vae_filename = None
model = None
model_patcher = None
clip_target = None
@@ -462,8 +463,12 @@ class WeightsLoader(torch.nn.Module):
model.load_model_weights(sd, "model.diffusion_model.")

 if output_vae:
-    vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
-    vae_sd = model_config.process_vae_state_dict(vae_sd)
+    if vae_filename_param is None:
+        vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
+        vae_sd = model_config.process_vae_state_dict(vae_sd)
+    else:
+        vae_sd = ldm_patched.modules.utils.load_torch_file(vae_filename_param)
+        vae_filename = vae_filename_param
     vae = VAE(sd=vae_sd)

if output_clip:
@@ -485,7 +490,7 @@ class WeightsLoader(torch.nn.Module):
print("loaded straight to GPU")
model_management.load_model_gpu(model_patcher)

-    return (model_patcher, clip, vae, clipvision)
+    return model_patcher, clip, vae, vae_filename, clipvision


def load_unet_state_dict(sd): #load unet in diffusers format
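The return value grows from a 4-tuple to a 5-tuple, so every caller has to be updated. A sketch of the new contract (paths are illustrative):

```python
from ldm_patched.modules.sd import load_checkpoint_guess_config

# Unpack five values now; passing vae_filename_param overrides the VAE
# baked into the checkpoint.
model_patcher, clip, vae, vae_filename, clipvision = load_checkpoint_guess_config(
    'models/checkpoints/model.safetensors',                 # illustrative path
    vae_filename_param='models/vae/sdxl_vae.safetensors',   # or None for the baked-in VAE
)
print(vae_filename)  # the override path, or None when the checkpoint VAE is used
```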
19 changes: 13 additions & 6 deletions modules/async_worker.py
@@ -44,7 +44,7 @@ def worker():
import args_manager

from extras.censor import censor_batch, censor_single
-    from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion, apply_arrays
+    from modules.sdxl_styles import apply_style, get_random_style, apply_wildcards, fooocus_expansion, apply_arrays, random_style_name
from modules.private_logger import log
from extras.expansion import safe_str
from modules.util import remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, \
@@ -172,6 +172,7 @@ def handler(async_task):
adaptive_cfg = args.pop()
sampler_name = args.pop()
scheduler_name = args.pop()
vae_name = args.pop()
overwrite_step = args.pop()
overwrite_switch = args.pop()
overwrite_width = args.pop()
@@ -434,7 +435,7 @@ def handler(async_task):
progressbar(async_task, 3, 'Loading models ...')
pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
loras=loras, base_model_additional_loras=base_model_additional_loras,
-                                    use_synthetic_refiner=use_synthetic_refiner)
+                                    use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name)

progressbar(async_task, 3, 'Processing prompts ...')
tasks = []
Expand All @@ -455,8 +456,12 @@ def handler(async_task):
positive_basic_workloads = []
negative_basic_workloads = []

+task_styles = style_selections.copy()
 if use_style:
-    for s in style_selections:
+    for i, s in enumerate(task_styles):
+        if s == random_style_name:
+            s = get_random_style(task_rng)
+            task_styles[i] = s
         p, n = apply_style(s, positive=task_prompt)
         positive_basic_workloads = positive_basic_workloads + p
         negative_basic_workloads = negative_basic_workloads + n
@@ -484,6 +489,7 @@
negative_top_k=len(negative_basic_workloads),
log_positive_prompt='\n'.join([task_prompt] + task_extra_positive_prompts),
log_negative_prompt='\n'.join([task_negative_prompt] + task_extra_negative_prompts),
styles=task_styles
))

if use_expansion:
@@ -856,7 +862,7 @@ def callback(step, x0, x, total_steps, y):
d = [('Prompt', 'prompt', task['log_positive_prompt']),
('Negative Prompt', 'negative_prompt', task['log_negative_prompt']),
('Fooocus V2 Expansion', 'prompt_expansion', task['expansion']),
-('Styles', 'styles', str(raw_style_selections)),
+('Styles', 'styles', str(task['styles'] if not use_expansion else [fooocus_expansion] + task['styles'])),
('Performance', 'performance', performance_selection.value)]

if performance_selection.steps() != steps:
@@ -883,6 +889,7 @@ def callback(step, x0, x, total_steps, y):

d.append(('Sampler', 'sampler', sampler_name))
d.append(('Scheduler', 'scheduler', scheduler_name))
d.append(('VAE', 'vae', vae_name))
d.append(('Seed', 'seed', str(task['task_seed'])))

if freeu_enabled:
@@ -897,10 +904,10 @@ def callback(step, x0, x, total_steps, y):
metadata_parser = modules.meta_parser.get_metadata_parser(metadata_scheme)
metadata_parser.set_data(task['log_positive_prompt'], task['positive'],
task['log_negative_prompt'], task['negative'],
-steps, base_model_name, refiner_model_name, loras)
+steps, base_model_name, refiner_model_name, loras, vae_name)
d.append(('Metadata Scheme', 'metadata_scheme', metadata_scheme.value if save_metadata_to_images else save_metadata_to_images))
d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version))
-img_paths.append(log(x, d, metadata_parser, output_format))
+img_paths.append(log(x, d, metadata_parser, output_format, task))

yield_result(async_task, img_paths, black_out_nsfw, False, do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results)
except ldm_patched.modules.model_management.InterruptProcessingException as e:
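How the 'Random Style' sentinel resolves, as a self-contained sketch (`get_random_style`, `random_style_name`, and the style pool are simplified stand-ins for modules.sdxl_styles):

```python
import random

random_style_name = 'Random Style'  # sentinel value, as in the diff
style_pool = ['Fooocus V2', 'Fooocus Masterpiece', 'Fooocus Photograph']

def get_random_style(rng: random.Random) -> str:
    # Assumption: a uniform choice over the loaded styles.
    return rng.choice(style_pool)

task_rng = random.Random(42)  # per-task seed keeps the pick reproducible
task_styles = ['Random Style', 'Fooocus V2']
for i, s in enumerate(task_styles):
    if s == random_style_name:
        task_styles[i] = get_random_style(task_rng)
print(task_styles)  # resolved names also reach the log via styles=task_styles
```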
