Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
mashb1t committed Jul 7, 2024
2 parents 4752166 + 58559bd commit 236766b
Show file tree
Hide file tree
Showing 23 changed files with 365 additions and 53 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build_container.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
type=edge,branch=main
- name: Build and push Docker image
uses: docker/build-push-action@v5
uses: docker/build-push-action@v6
with:
context: .
file: ./Dockerfile
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ __pycache__
*.partial
*.onnx
sorted_styles.json
hash_cache.txt
/input
/cache
/language/default.json
Expand Down
5 changes: 4 additions & 1 deletion args_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@
help="Enables automatic description of uov and enhance image when prompt is empty", default=False)

args_parser.parser.add_argument("--always-download-new-model", action='store_true',
help="Always download newer models ", default=False)
help="Always download newer models", default=False)

args_parser.parser.add_argument("--rebuild-hash-cache", action='store_true',
help="Generates missing model and LoRA hashes.", default=False)

args_parser.parser.set_defaults(
disable_cuda_malloc=True,
Expand Down
2 changes: 1 addition & 1 deletion fooocus_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = '2.5.2 (mashb1t)'
version = '2.6.0-rc1 (mashb1t)'
7 changes: 5 additions & 2 deletions launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ def ini_args():
else:
print(f"[Cleanup] Failed to delete content of temp dir.")

def download_models(default_model, previous_default_models, checkpoint_downloads, embeddings_downloads, lora_downloads):

def download_models(default_model, previous_default_models, checkpoint_downloads, embeddings_downloads, lora_downloads, vae_downloads):
for file_name, url in vae_approx_filenames:
load_file_from_url(url=url, model_dir=config.path_vae_approx, file_name=file_name)

Expand Down Expand Up @@ -130,12 +131,14 @@ def download_models(default_model, previous_default_models, checkpoint_downloads
load_file_from_url(url=url, model_dir=config.path_embeddings, file_name=file_name)
for file_name, url in lora_downloads.items():
load_file_from_url(url=url, model_dir=config.paths_loras[0], file_name=file_name)
for file_name, url in vae_downloads.items():
load_file_from_url(url=url, model_dir=config.path_vae, file_name=file_name)

return default_model, checkpoint_downloads


config.default_base_model_name, config.checkpoint_downloads = download_models(
config.default_base_model_name, config.previous_default_models, config.checkpoint_downloads,
config.embeddings_downloads, config.lora_downloads)
config.embeddings_downloads, config.lora_downloads, config.vae_downloads)

from webui import *
70 changes: 70 additions & 0 deletions ldm_patched/k_diffusion/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,4 +835,74 @@ def sample_tcd(model, x, sigmas, extra_args=None, callback=None, disable=None, n
else:
x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2)

return x


@torch.no_grad()
def sample_restart(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list=None):
"""Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)
Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}
If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list
"""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
step_id = 0

def heun_step(x, old_sigma, new_sigma, second_order=True):
nonlocal step_id
denoised = model(x, old_sigma * s_in, **extra_args)
d = to_d(x, old_sigma, denoised)
if callback is not None:
callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised})
dt = new_sigma - old_sigma
if new_sigma == 0 or not second_order:
# Euler method
x = x + d * dt
else:
# Heun's method
x_2 = x + d * dt
denoised_2 = model(x_2, new_sigma * s_in, **extra_args)
d_2 = to_d(x_2, new_sigma, denoised_2)
d_prime = (d + d_2) / 2
x = x + d_prime * dt
step_id += 1
return x

steps = sigmas.shape[0] - 1
if restart_list is None:
if steps >= 20:
restart_steps = 9
restart_times = 1
if steps >= 36:
restart_steps = steps // 4
restart_times = 2
sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device)
restart_list = {0.1: [restart_steps + 1, restart_times, 2]}
else:
restart_list = {}

restart_list = {int(torch.argmin(abs(sigmas - key), dim=0)): value for key, value in restart_list.items()}

step_list = []
for i in range(len(sigmas) - 1):
step_list.append((sigmas[i], sigmas[i + 1]))
if i + 1 in restart_list:
restart_steps, restart_times, restart_max = restart_list[i + 1]
min_idx = i + 1
max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
if max_idx < min_idx:
sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
while restart_times > 0:
restart_times -= 1
step_list.extend(zip(sigma_restart[:-1], sigma_restart[1:]))

last_sigma = None
for old_sigma, new_sigma in tqdm(step_list, disable=disable):
if last_sigma is None:
last_sigma = old_sigma
elif last_sigma < old_sigma:
x = x + torch.randn_like(x) * s_noise * (old_sigma ** 2 - last_sigma ** 2) ** 0.5
x = heun_step(x, old_sigma, new_sigma)
last_sigma = new_sigma

return x
2 changes: 1 addition & 1 deletion ldm_patched/modules/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=N

KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "tcd", "edm_playground_v2.5"]
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "tcd", "edm_playground_v2.5", "restart"]

class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
Expand Down
38 changes: 36 additions & 2 deletions modules/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import tempfile
import modules.flags
import modules.sdxl_styles
from modules.hash_cache import load_cache_from_file, save_cache_to_file

from modules.model_loader import load_file_from_url
from modules.extra_utils import makedirs_with_log, get_files_from_folder, try_eval_env_var
Expand Down Expand Up @@ -445,6 +446,12 @@ def init_temp_path(path: str | None, default_path: str) -> str:
validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()),
expected_type=dict
)
vae_downloads = get_config_item_or_set_default(
key='vae_downloads',
default_value={},
validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()),
expected_type=dict
)
available_aspect_ratios = get_config_item_or_set_default(
key='available_aspect_ratios',
default_value=modules.flags.sdxl_aspect_ratios,
Expand All @@ -463,6 +470,12 @@ def init_temp_path(path: str | None, default_path: str) -> str:
validator=lambda x: x in modules.flags.inpaint_engine_versions,
expected_type=str
)
default_inpaint_method = get_config_item_or_set_default(
key='default_inpaint_method',
default_value=modules.flags.inpaint_option_default,
validator=lambda x: x in modules.flags.inpaint_options,
expected_type=str
)
default_cfg_tsnr = get_config_item_or_set_default(
key='default_cfg_tsnr',
default_value=7.0,
Expand Down Expand Up @@ -602,7 +615,7 @@ def init_temp_path(path: str | None, default_path: str) -> str:

config_dict["default_loras"] = default_loras = default_loras[:default_max_lora_number] + [[True, 'None', 1.0] for _ in range(default_max_lora_number - len(default_loras))]

# mapping config to meta parameter
# mapping config to meta parameter
possible_preset_keys = {
"default_model": "base_model",
"default_refiner": "refiner_model",
Expand All @@ -618,6 +631,7 @@ def init_temp_path(path: str | None, default_path: str) -> str:
"default_sampler": "sampler",
"default_scheduler": "scheduler",
"default_overwrite_step": "steps",
"default_overwrite_switch": "overwrite_switch",
"default_performance": "performance",
"default_image_number": "image_number",
"default_prompt": "prompt",
Expand All @@ -628,7 +642,10 @@ def init_temp_path(path: str | None, default_path: str) -> str:
"checkpoint_downloads": "checkpoint_downloads",
"embeddings_downloads": "embeddings_downloads",
"lora_downloads": "lora_downloads",
"default_vae": "vae"
"vae_downloads": "vae_downloads",
"default_vae": "vae",
# "default_inpaint_method": "inpaint_method", # disabled so inpaint mode doesn't refresh after every preset change
"default_inpaint_engine_version": "inpaint_engine_version",
}

REWRITE_PRESET = False
Expand Down Expand Up @@ -875,3 +892,20 @@ def downloading_sam_vit_h():


update_files()
load_cache_from_file()

if args_manager.args.rebuild_hash_cache:
from modules.hash_cache import sha256_from_cache
from modules.util import get_file_from_folder_list

print('[Cache] Rebuilding hash cache')
for filename in model_filenames:
filepath = get_file_from_folder_list(filename, paths_checkpoints)
sha256_from_cache(filepath)
for filename in lora_filenames:
filepath = get_file_from_folder_list(filename, paths_loras)
sha256_from_cache(filepath)
print('[Cache] Done')

# write cache to file again for sorting and cleanup of invalid cache entries
save_cache_to_file()
3 changes: 2 additions & 1 deletion modules/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
"dpmpp_3m_sde_gpu": "",
"ddpm": "",
"lcm": "LCM",
"tcd": "TCD"
"tcd": "TCD",
"restart": "Restart"
}

SAMPLER_EXTRA = {
Expand Down
53 changes: 53 additions & 0 deletions modules/hash_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import json
import os

from modules.util import sha256, HASH_SHA256_LENGTH

hash_cache_filename = 'hash_cache.txt'
hash_cache = {}


def sha256_from_cache(filepath):
global hash_cache
if filepath not in hash_cache:
hash_value = sha256(filepath)
hash_cache[filepath] = hash_value
save_cache_to_file(filepath, hash_value)

return hash_cache[filepath]


def load_cache_from_file():
global hash_cache

try:
if os.path.exists(hash_cache_filename):
with open(hash_cache_filename, 'rt', encoding='utf-8') as fp:
for line in fp:
entry = json.loads(line)
for filepath, hash_value in entry.items():
if not os.path.exists(filepath) or not isinstance(hash_value, str) and len(hash_value) != HASH_SHA256_LENGTH:
print(f'[Cache] Skipping invalid cache entry: {filepath}')
continue
hash_cache[filepath] = hash_value
except Exception as e:
print(f'[Cache] Loading failed: {e}')


def save_cache_to_file(filename=None, hash_value=None):
global hash_cache

if filename is not None and hash_value is not None:
items = [(filename, hash_value)]
mode = 'at'
else:
items = sorted(hash_cache.items())
mode = 'wt'

try:
with open(hash_cache_filename, mode, encoding='utf-8') as fp:
for filepath, hash_value in items:
json.dump({filepath: hash_value}, fp)
fp.write('\n')
except Exception as e:
print(f'[Cache] Saving failed: {e}')
54 changes: 39 additions & 15 deletions modules/meta_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,18 @@
import fooocus_version
import modules.config
import modules.sdxl_styles
from modules import hash_cache
from modules.flags import MetadataScheme, Performance, Steps
from modules.flags import SAMPLERS, CIVITAI_NO_KARRAS
from modules.util import quote, unquote, extract_styles_from_prompt, is_json, get_file_from_folder_list, sha256
from modules.hash_cache import sha256_from_cache
from modules.util import quote, unquote, extract_styles_from_prompt, is_json, get_file_from_folder_list

re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_imagesize = re.compile(r"^(\d+)x(\d+)$")

hash_cache = {}


def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool, inpaint_mode: str):
loaded_parameter_dict = raw_metadata
if isinstance(raw_metadata, str):
loaded_parameter_dict = json.loads(raw_metadata)
Expand Down Expand Up @@ -49,6 +49,8 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
get_str('scheduler', 'Scheduler', loaded_parameter_dict, results)
get_str('vae', 'VAE', loaded_parameter_dict, results)
get_seed('seed', 'Seed', loaded_parameter_dict, results)
get_inpaint_engine_version('inpaint_engine_version', 'Inpaint Engine Version', loaded_parameter_dict, results, inpaint_mode)
get_inpaint_method('inpaint_method', 'Inpaint Mode', loaded_parameter_dict, results)

if is_generating:
results.append(gr.update())
Expand Down Expand Up @@ -160,6 +162,36 @@ def get_seed(key: str, fallback: str | None, source_dict: dict, results: list, d
results.append(gr.update())


def get_inpaint_engine_version(key: str, fallback: str | None, source_dict: dict, results: list, inpaint_mode: str, default=None) -> str | None:
try:
h = source_dict.get(key, source_dict.get(fallback, default))
assert isinstance(h, str) and h in modules.flags.inpaint_engine_versions
if inpaint_mode != modules.flags.inpaint_option_detail:
results.append(h)
else:
results.append(gr.update())
results.append(h)
return h
except:
results.append(gr.update())
results.append('empty')
return None


def get_inpaint_method(key: str, fallback: str | None, source_dict: dict, results: list, default=None) -> str | None:
try:
h = source_dict.get(key, source_dict.get(fallback, default))
assert isinstance(h, str) and h in modules.flags.inpaint_options
results.append(h)
for i in range(modules.config.default_enhance_tabs):
results.append(h)
return h
except:
results.append(gr.update())
for i in range(modules.config.default_enhance_tabs):
results.append(gr.update())


def get_adm_guidance(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
try:
h = source_dict.get(key, source_dict.get(fallback, default))
Expand Down Expand Up @@ -215,14 +247,6 @@ def get_lora(key: str, fallback: str | None, source_dict: dict, results: list, p
results.append(1)


def get_sha256(filepath):
global hash_cache
if filepath not in hash_cache:
hash_cache[filepath] = sha256(filepath)

return hash_cache[filepath]


def parse_meta_from_preset(preset_content):
assert isinstance(preset_content, dict)
preset_prepared = {}
Expand Down Expand Up @@ -289,18 +313,18 @@ def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_p
self.base_model_name = Path(base_model_name).stem

base_model_path = get_file_from_folder_list(base_model_name, modules.config.paths_checkpoints)
self.base_model_hash = get_sha256(base_model_path)
self.base_model_hash = sha256_from_cache(base_model_path)

if refiner_model_name not in ['', 'None']:
self.refiner_model_name = Path(refiner_model_name).stem
refiner_model_path = get_file_from_folder_list(refiner_model_name, modules.config.paths_checkpoints)
self.refiner_model_hash = get_sha256(refiner_model_path)
self.refiner_model_hash = sha256_from_cache(refiner_model_path)

self.loras = []
for (lora_name, lora_weight) in loras:
if lora_name != 'None':
lora_path = get_file_from_folder_list(lora_name, modules.config.paths_loras)
lora_hash = get_sha256(lora_path)
lora_hash = sha256_from_cache(lora_path)
self.loras.append((Path(lora_name).stem, lora_weight, lora_hash))
self.vae_name = Path(vae_name).stem

Expand Down
Loading

0 comments on commit 236766b

Please sign in to comment.