From f338d5fc1611ec973657fe6a73763ffee79f6763 Mon Sep 17 00:00:00 2001
From: Manuel Schmid <dev@mash1t.de>
Date: Fri, 17 May 2024 23:56:02 +0200
Subject: [PATCH 1/4] feat: extract safety checker, remove dependency on
 diffusers

---
 .../stable_diffusion/safety_checker.py        | 126 ++++++++++++++++++
 modules/censor.py                             |  11 +-
 requirements_versions.txt                     |   1 -
 3 files changed, 129 insertions(+), 9 deletions(-)
 create mode 100644 extras/diffusers/pipelines/stable_diffusion/safety_checker.py

diff --git a/extras/diffusers/pipelines/stable_diffusion/safety_checker.py b/extras/diffusers/pipelines/stable_diffusion/safety_checker.py
new file mode 100644
index 0000000000..ea38bf038e
--- /dev/null
+++ b/extras/diffusers/pipelines/stable_diffusion/safety_checker.py
@@ -0,0 +1,126 @@
+# from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py
+
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import torch
+import torch.nn as nn
+from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+def cosine_distance(image_embeds, text_embeds):
+    normalized_image_embeds = nn.functional.normalize(image_embeds)
+    normalized_text_embeds = nn.functional.normalize(text_embeds)
+    return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
+
+
+class StableDiffusionSafetyChecker(PreTrainedModel):
+    config_class = CLIPConfig
+    main_input_name = "clip_input"
+
+    _no_split_modules = ["CLIPEncoderLayer"]
+
+    def __init__(self, config: CLIPConfig):
+        super().__init__(config)
+
+        self.vision_model = CLIPVisionModel(config.vision_config)
+        self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
+
+        self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
+        self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
+
+        self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
+        self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
+
+    @torch.no_grad()
+    def forward(self, clip_input, images):
+        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
+        image_embeds = self.visual_projection(pooled_output)
+
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
+        cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
+
+        result = []
+        batch_size = image_embeds.shape[0]
+        for i in range(batch_size):
+            result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
+
+            # increase this value to create a stronger `nsfw` filter
+            # at the cost of increasing the possibility of filtering benign images
+            adjustment = 0.0
+
+            for concept_idx in range(len(special_cos_dist[0])):
+                concept_cos = special_cos_dist[i][concept_idx]
+                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
+                result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+                if result_img["special_scores"][concept_idx] > 0:
+                    result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
+                    adjustment = 0.01
+
+            for concept_idx in range(len(cos_dist[0])):
+                concept_cos = cos_dist[i][concept_idx]
+                concept_threshold = self.concept_embeds_weights[concept_idx].item()
+                result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+                if result_img["concept_scores"][concept_idx] > 0:
+                    result_img["bad_concepts"].append(concept_idx)
+
+            result.append(result_img)
+
+        has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
+
+        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
+            if has_nsfw_concept:
+                if torch.is_tensor(images) or torch.is_tensor(images[0]):
+                    images[idx] = torch.zeros_like(images[idx])  # black image
+                else:
+                    images[idx] = np.zeros(images[idx].shape)  # black image
+
+        if any(has_nsfw_concepts):
+            logger.warning(
+                "Potential NSFW content was detected in one or more images. A black image will be returned instead."
+                " Try again with a different prompt and/or seed."
+            )
+
+        return images, has_nsfw_concepts
+
+    @torch.no_grad()
+    def forward_onnx(self, clip_input: torch.Tensor, images: torch.Tensor):
+        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
+        image_embeds = self.visual_projection(pooled_output)
+
+        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
+        cos_dist = cosine_distance(image_embeds, self.concept_embeds)
+
+        # increase this value to create a stronger `nsfw` filter
+        # at the cost of increasing the possibility of filtering benign images
+        adjustment = 0.0
+
+        special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
+        # special_scores = special_scores.round(decimals=3)
+        special_care = torch.any(special_scores > 0, dim=1)
+        special_adjustment = special_care * 0.01
+        special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])
+
+        concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
+        # concept_scores = concept_scores.round(decimals=3)
+        has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
+
+        images[has_nsfw_concepts] = 0.0  # black image
+
+        return images, has_nsfw_concepts
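
The extracted checker keeps the diffusers call convention: a CLIP feature extractor produces the pixel_values passed as clip_input, and forward returns the images (blacked out where flagged) plus a per-image NSFW flag. A minimal sketch of a direct call, assuming a batch of HWC float images in [0, 1] and the same CompVis/stable-diffusion-safety-checker weights used by modules/censor.py below:

    import numpy as np
    from PIL import Image
    from transformers import AutoFeatureExtractor

    from extras.diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

    model_id = "CompVis/stable-diffusion-safety-checker"
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
    checker = StableDiffusionSafetyChecker.from_pretrained(model_id)

    # batch of two HWC RGB images with values in [0, 1] (noise, for illustration only)
    images = np.random.rand(2, 512, 512, 3).astype(np.float32)
    pil_images = [Image.fromarray((img * 255).round().astype("uint8")) for img in images]

    clip_input = feature_extractor(pil_images, return_tensors="pt").pixel_values
    checked_images, has_nsfw = checker(images=images, clip_input=clip_input)
    # entries flagged in has_nsfw come back as all-black arrays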
diff --git a/modules/censor.py b/modules/censor.py
index e2352218c1..ca47693ac6 100644
--- a/modules/censor.py
+++ b/modules/censor.py
@@ -1,10 +1,7 @@
 # modified version of https://github.com/AUTOMATIC1111/stable-diffusion-webui-nsfw-censor/blob/master/scripts/censor.py
-
 import numpy as np
-import torch
-import modules.core as core
 
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from extras.diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from transformers import AutoFeatureExtractor
 from PIL import Image
 import modules.config
@@ -16,8 +13,6 @@
 
 def numpy_to_pil(image):
     image = (image * 255).round().astype("uint8")
-
-    #pil_image = Image.fromarray(image, 'RGB')
     pil_image = Image.fromarray(image)
 
     return pil_image
@@ -27,7 +22,7 @@ def numpy_to_pil(image):
 def check_safety(x_image):
     global safety_feature_extractor, safety_checker
 
-    if safety_feature_extractor is None:
+    if safety_feature_extractor is None or safety_checker is None:
         safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models)
         safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models)
 
@@ -52,4 +47,4 @@ def censor_single(x):
 def censor_batch(images):
     images = [censor_single(image) for image in images]
 
-    return images
+    return images
\ No newline at end of file
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 5e9e85d6e6..b2111c1f5d 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -16,4 +16,3 @@ opencv-contrib-python==4.8.0.74
 httpx==0.24.1
 onnxruntime==1.16.3
 timm==0.9.2
-diffusers==0.25.1

From 0f78f8d8cc8800a910225bda9f2898c4b580ff77 Mon Sep 17 00:00:00 2001
From: Manuel Schmid <dev@mash1t.de>
Date: Fri, 17 May 2024 23:56:55 +0200
Subject: [PATCH 2/4] feat: restore compatibility after merging with main

---
 language/en.json        |  2 ++
 modules/async_worker.py | 32 +++++++++++++++++++++-----------
 modules/config.py       | 10 +++++-----
 webui.py                | 18 ++++++++++--------
 4 files changed, 38 insertions(+), 24 deletions(-)

diff --git a/language/en.json b/language/en.json
index fefc79c477..d420a6ab46 100644
--- a/language/en.json
+++ b/language/en.json
@@ -54,6 +54,8 @@
     "Disable seed increment": "Disable seed increment",
     "Disable automatic seed increment when image number is > 1.": "Disable automatic seed increment when image number is > 1.",
     "Read wildcards in order": "Read wildcards in order",
+    "Black Out NSFW": "Black Out NSFW",
+    "Use black image if NSFW is detected.": "Use black image if NSFW is detected.",
     "\ud83d\udcda History Log": "\uD83D\uDCDA History Log",
     "Image Style": "Image Style",
     "Fooocus V2": "Fooocus V2",
diff --git a/modules/async_worker.py b/modules/async_worker.py
index 73cceadef7..0d95725c23 100644
--- a/modules/async_worker.py
+++ b/modules/async_worker.py
@@ -43,7 +43,7 @@ def worker():
     import fooocus_version
     import args_manager
 
-    from modules.censor import censor_batch
+    from modules.censor import censor_batch, censor_single
     from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion, apply_arrays
     from modules.private_logger import log
     from extras.expansion import safe_str
@@ -69,11 +69,11 @@ def progressbar(async_task, number, text):
         print(f'[Fooocus] {text}')
         async_task.yields.append(['preview', (number, text, None)])
 
-    def yield_result(async_task, imgs, do_not_show_finished_images=False, progressbar_index=13):
+    def yield_result(async_task, imgs, black_out_nsfw, censor=True, do_not_show_finished_images=False, progressbar_index=13):
         if not isinstance(imgs, list):
             imgs = [imgs]
 
-        if modules.config.default_black_out_nsfw or advanced_parameters.black_out_nsfw:
+        if censor and (modules.config.default_black_out_nsfw or black_out_nsfw):
             progressbar(async_task, progressbar_index, 'Checking for NSFW content ...')
             imgs = censor_batch(imgs)
 
@@ -165,6 +165,7 @@ def handler(async_task):
         disable_preview = args.pop()
         disable_intermediate_results = args.pop()
         disable_seed_increment = args.pop()
+        black_out_nsfw = args.pop()
         adm_scaler_positive = args.pop()
         adm_scaler_negative = args.pop()
         adm_scaler_end = args.pop()
@@ -577,8 +578,11 @@ def handler(async_task):
 
             if direct_return:
                 d = [('Upscale (Fast)', 'upscale_fast', '2x')]
+                if modules.config.default_black_out_nsfw or black_out_nsfw:
+                    progressbar(async_task, 100, 'Checking for NSFW content ...')
+                    uov_input_image = censor_single(uov_input_image)
                 uov_input_image_path = log(uov_input_image, d, output_format=output_format)
-                yield_result(async_task, uov_input_image_path, do_not_show_finished_images=True)
+                yield_result(async_task, uov_input_image_path, black_out_nsfw, False, do_not_show_finished_images=True)
                 return
 
             tiled = True
@@ -642,8 +646,7 @@ def handler(async_task):
             )
 
             if debugging_inpaint_preprocessor:
-                yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(),
-                             do_not_show_finished_images=True)
+                yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), black_out_nsfw, do_not_show_finished_images=True)
                 return
 
             progressbar(async_task, 13, 'VAE Inpaint encoding ...')
@@ -706,7 +709,7 @@ def handler(async_task):
                 cn_img = HWC3(cn_img)
                 task[0] = core.numpy_to_pytorch(cn_img)
                 if debugging_cn_preprocessor:
-                    yield_result(async_task, cn_img, do_not_show_finished_images=True)
+                    yield_result(async_task, cn_img, black_out_nsfw, do_not_show_finished_images=True)
                     return
             for task in cn_tasks[flags.cn_cpds]:
                 cn_img, cn_stop, cn_weight = task
@@ -718,7 +721,7 @@ def handler(async_task):
                 cn_img = HWC3(cn_img)
                 task[0] = core.numpy_to_pytorch(cn_img)
                 if debugging_cn_preprocessor:
-                    yield_result(async_task, cn_img, do_not_show_finished_images=True)
+                    yield_result(async_task, cn_img, black_out_nsfw, do_not_show_finished_images=True)
                     return
             for task in cn_tasks[flags.cn_ip]:
                 cn_img, cn_stop, cn_weight = task
@@ -729,7 +732,7 @@ def handler(async_task):
 
                 task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_path)
                 if debugging_cn_preprocessor:
-                    yield_result(async_task, cn_img, do_not_show_finished_images=True)
+                    yield_result(async_task, cn_img, black_out_nsfw, do_not_show_finished_images=True)
                     return
             for task in cn_tasks[flags.cn_ip_face]:
                 cn_img, cn_stop, cn_weight = task
@@ -743,7 +746,7 @@ def handler(async_task):
 
                 task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_face_path)
                 if debugging_cn_preprocessor:
-                    yield_result(async_task, cn_img, do_not_show_finished_images=True)
+                    yield_result(async_task, cn_img, black_out_nsfw, do_not_show_finished_images=True)
                     return
 
             all_ip_tasks = cn_tasks[flags.cn_ip] + cn_tasks[flags.cn_ip_face]
@@ -843,6 +846,12 @@ def callback(step, x0, x, total_steps, y):
                     imgs = [inpaint_worker.current_task.post_process(x) for x in imgs]
 
                 img_paths = []
+
+                if modules.config.default_black_out_nsfw or black_out_nsfw:
+                    progressbar(async_task, int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps)),
+                                'Checking for NSFW content ...')
+                    imgs = censor_batch(imgs)
+
                 for x in imgs:
                     d = [('Prompt', 'prompt', task['log_positive_prompt']),
                          ('Negative Prompt', 'negative_prompt', task['log_negative_prompt']),
@@ -892,7 +901,8 @@ def callback(step, x0, x, total_steps, y):
                     d.append(('Metadata Scheme', 'metadata_scheme', metadata_scheme.value if save_metadata_to_images else save_metadata_to_images))
                     d.append(('Version', 'version', 'Fooocus v' + fooocus_version.version))
                     img_paths.append(log(x, d, metadata_parser, output_format))
-                yield_result(async_task, img_paths, do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results, progressbar_index=int(15.0 + 85.0 * float((current_task_id + 1) * steps) / float(all_steps)))
+
+                yield_result(async_task, img_paths, black_out_nsfw, False, do_not_show_finished_images=len(tasks) == 1 or disable_intermediate_results)
             except ldm_patched.modules.model_management.InterruptProcessingException as e:
                 if async_task.last_stop == 'skip':
                     print('User skipped')
diff --git a/modules/config.py b/modules/config.py
index 2db23dbfda..5a18e96358 100644
--- a/modules/config.py
+++ b/modules/config.py
@@ -451,6 +451,11 @@ def init_temp_path(path: str | None, default_path: str) -> str:
     ],
     validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x)
 )
+default_black_out_nsfw = get_config_item_or_set_default(
+    key='default_black_out_nsfw',
+    default_value=False,
+    validator=lambda x: isinstance(x, bool)
+)
 default_save_metadata_to_images = get_config_item_or_set_default(
     key='default_save_metadata_to_images',
     default_value=False,
@@ -466,11 +471,6 @@ def init_temp_path(path: str | None, default_path: str) -> str:
     default_value='',
     validator=lambda x: isinstance(x, str)
 )
-default_black_out_nsfw = get_config_item_or_set_default(
-    key='default_black_out_nsfw',
-    default_value=False,
-    validator=lambda x: isinstance(x, bool)
-)
 
 example_inpaint_prompts = [[x] for x in example_inpaint_prompts]
 
diff --git a/webui.py b/webui.py
index 29eed60687..ab6ad09130 100644
--- a/webui.py
+++ b/webui.py
@@ -445,6 +445,15 @@ def update_history_link():
                                                              value=False)
                         read_wildcards_in_order = gr.Checkbox(label="Read wildcards in order", value=False)
 
+                        black_out_nsfw = gr.Checkbox(label='Black Out NSFW',
+                                                     value=modules.config.default_black_out_nsfw,
+                                                     interactive=not modules.config.default_black_out_nsfw,
+                                                     info='Use black image if NSFW is detected.')
+
+                        black_out_nsfw.change(lambda x: gr.update(value=x, interactive=not x),
+                                              inputs=black_out_nsfw, outputs=disable_preview, queue=False,
+                                              show_progress=False)
+
                         if not args_manager.args.disable_metadata:
                             save_metadata_to_images = gr.Checkbox(label='Save Metadata to Images', value=modules.config.default_save_metadata_to_images,
                                                                   info='Adds parameters to generated images allowing manual regeneration.')
@@ -455,13 +464,6 @@ def update_history_link():
                             save_metadata_to_images.change(lambda x: gr.update(visible=x), inputs=[save_metadata_to_images], outputs=[metadata_scheme], 
                                                            queue=False, show_progress=False)
 
-                        black_out_nsfw = gr.Checkbox(label='Black Out NSFW', value=modules.config.default_black_out_nsfw,
-                                                     interactive=not modules.config.default_black_out_nsfw,
-                                                     info='Use black image if NSFW is detected.')
-
-                        black_out_nsfw.change(lambda x: gr.update(value=x, interactive=not x),
-                                     inputs=black_out_nsfw, outputs=disable_preview, queue=False, show_progress=False)
-
                     with gr.Tab(label='Control'):
                         debugging_cn_preprocessor = gr.Checkbox(label='Debug Preprocessors', value=False,
                                                                 info='See the results from preprocessors.')
@@ -640,7 +642,7 @@ def inpaint_mode_change(mode):
         ctrls += [input_image_checkbox, current_tab]
         ctrls += [uov_method, uov_input_image]
         ctrls += [outpaint_selections, inpaint_input_image, inpaint_additional_prompt, inpaint_mask_image]
-        ctrls += [disable_preview, disable_intermediate_results, disable_seed_increment]
+        ctrls += [disable_preview, disable_intermediate_results, disable_seed_increment, black_out_nsfw]
         ctrls += [adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg]
         ctrls += [sampler_name, scheduler_name]
         ctrls += [overwrite_step, overwrite_switch, overwrite_width, overwrite_height, overwrite_vary_strength]

From 7568b72d9b8dd285238bb0c680b84d6a80fe7932 Mon Sep 17 00:00:00 2001
From: Manuel Schmid <dev@mash1t.de>
Date: Sat, 18 May 2024 01:59:15 +0200
Subject: [PATCH 3/4] feat: move censor to extras, optimize safety checker file
 handling

---
 .gitignore                                    |   1 -
 {modules => extras}/censor.py                 |  16 +-
 extras/safety_checker/configs/config.json     | 171 ++++++++++++++++++
 .../configs/preprocessor_config.json          |  20 ++
 .../models}/safety_checker.py                 |   0
 modules/async_worker.py                       |   2 +-
 modules/config.py                             |   8 +
 7 files changed, 211 insertions(+), 7 deletions(-)
 rename {modules => extras}/censor.py (65%)
 create mode 100644 extras/safety_checker/configs/config.json
 create mode 100644 extras/safety_checker/configs/preprocessor_config.json
 rename extras/{diffusers/pipelines/stable_diffusion => safety_checker/models}/safety_checker.py (100%)

diff --git a/.gitignore b/.gitignore
index e423ef81a9..859149866a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -18,7 +18,6 @@ config.txt
 config_modification_tutorial.txt
 user_path_config.txt
 user_path_config-deprecated.txt
-/models/safety_checker_models
 /modules/*.png
 /repositories
 /fooocus_env
diff --git a/modules/censor.py b/extras/censor.py
similarity index 65%
rename from modules/censor.py
rename to extras/censor.py
index ca47693ac6..2047db2461 100644
--- a/modules/censor.py
+++ b/extras/censor.py
@@ -1,12 +1,16 @@
 # modified version of https://github.com/AUTOMATIC1111/stable-diffusion-webui-nsfw-censor/blob/master/scripts/censor.py
 import numpy as np
+import os
 
-from extras.diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from transformers import AutoFeatureExtractor
+from extras.safety_checker.models.safety_checker import StableDiffusionSafetyChecker
+from transformers import CLIPFeatureExtractor, CLIPConfig
 from PIL import Image
 import modules.config
 
-safety_model_id = "CompVis/stable-diffusion-safety-checker"
+safety_checker_repo_root = os.path.join(os.path.dirname(__file__), 'safety_checker')
+config_path = os.path.join(safety_checker_repo_root, "configs", "config.json")
+preprocessor_config_path = os.path.join(safety_checker_repo_root, "configs", "preprocessor_config.json")
+
 safety_feature_extractor = None
 safety_checker = None
 
@@ -23,8 +27,10 @@ def check_safety(x_image):
     global safety_feature_extractor, safety_checker
 
     if safety_feature_extractor is None or safety_checker is None:
-        safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models)
-        safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, cache_dir=modules.config.path_safety_checker_models)
+        safety_checker_model = modules.config.downloading_safety_checker_model()
+        safety_feature_extractor = CLIPFeatureExtractor.from_json_file(preprocessor_config_path)
+        clip_config = CLIPConfig.from_json_file(config_path)
+        safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_checker_model, config=clip_config)
 
     safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
     x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
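
Since the CLIP config and the preprocessor config are now bundled as JSON files and the weights are fetched separately, the checker can be built entirely from local files, without touching the Hugging Face Hub. A standalone sketch of that loading path, assuming it is run from the repository root (paths are illustrative):

    import os
    from transformers import CLIPConfig, CLIPFeatureExtractor

    from extras.safety_checker.models.safety_checker import StableDiffusionSafetyChecker
    import modules.config

    configs_dir = os.path.join("extras", "safety_checker", "configs")
    clip_config = CLIPConfig.from_json_file(os.path.join(configs_dir, "config.json"))
    feature_extractor = CLIPFeatureExtractor.from_json_file(os.path.join(configs_dir, "preprocessor_config.json"))

    # downloads stable-diffusion-safety-checker.bin on first call and returns its path
    weights_path = modules.config.downloading_safety_checker_model()
    checker = StableDiffusionSafetyChecker.from_pretrained(weights_path, config=clip_config)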
diff --git a/extras/safety_checker/configs/config.json b/extras/safety_checker/configs/config.json
new file mode 100644
index 0000000000..aa454d2225
--- /dev/null
+++ b/extras/safety_checker/configs/config.json
@@ -0,0 +1,171 @@
+{
+  "_name_or_path": "clip-vit-large-patch14/",
+  "architectures": [
+    "SafetyChecker"
+  ],
+  "initializer_factor": 1.0,
+  "logit_scale_init_value": 2.6592,
+  "model_type": "clip",
+  "projection_dim": 768,
+  "text_config": {
+    "_name_or_path": "",
+    "add_cross_attention": false,
+    "architectures": null,
+    "attention_dropout": 0.0,
+    "bad_words_ids": null,
+    "bos_token_id": 0,
+    "chunk_size_feed_forward": 0,
+    "cross_attention_hidden_size": null,
+    "decoder_start_token_id": null,
+    "diversity_penalty": 0.0,
+    "do_sample": false,
+    "dropout": 0.0,
+    "early_stopping": false,
+    "encoder_no_repeat_ngram_size": 0,
+    "eos_token_id": 2,
+    "exponential_decay_length_penalty": null,
+    "finetuning_task": null,
+    "forced_bos_token_id": null,
+    "forced_eos_token_id": null,
+    "hidden_act": "quick_gelu",
+    "hidden_size": 768,
+    "id2label": {
+      "0": "LABEL_0",
+      "1": "LABEL_1"
+    },
+    "initializer_factor": 1.0,
+    "initializer_range": 0.02,
+    "intermediate_size": 3072,
+    "is_decoder": false,
+    "is_encoder_decoder": false,
+    "label2id": {
+      "LABEL_0": 0,
+      "LABEL_1": 1
+    },
+    "layer_norm_eps": 1e-05,
+    "length_penalty": 1.0,
+    "max_length": 20,
+    "max_position_embeddings": 77,
+    "min_length": 0,
+    "model_type": "clip_text_model",
+    "no_repeat_ngram_size": 0,
+    "num_attention_heads": 12,
+    "num_beam_groups": 1,
+    "num_beams": 1,
+    "num_hidden_layers": 12,
+    "num_return_sequences": 1,
+    "output_attentions": false,
+    "output_hidden_states": false,
+    "output_scores": false,
+    "pad_token_id": 1,
+    "prefix": null,
+    "problem_type": null,
+    "pruned_heads": {},
+    "remove_invalid_values": false,
+    "repetition_penalty": 1.0,
+    "return_dict": true,
+    "return_dict_in_generate": false,
+    "sep_token_id": null,
+    "task_specific_params": null,
+    "temperature": 1.0,
+    "tie_encoder_decoder": false,
+    "tie_word_embeddings": true,
+    "tokenizer_class": null,
+    "top_k": 50,
+    "top_p": 1.0,
+    "torch_dtype": null,
+    "torchscript": false,
+    "transformers_version": "4.21.0.dev0",
+    "typical_p": 1.0,
+    "use_bfloat16": false,
+    "vocab_size": 49408
+  },
+  "text_config_dict": {
+    "hidden_size": 768,
+    "intermediate_size": 3072,
+    "num_attention_heads": 12,
+    "num_hidden_layers": 12
+  },
+  "torch_dtype": "float32",
+  "transformers_version": null,
+  "vision_config": {
+    "_name_or_path": "",
+    "add_cross_attention": false,
+    "architectures": null,
+    "attention_dropout": 0.0,
+    "bad_words_ids": null,
+    "bos_token_id": null,
+    "chunk_size_feed_forward": 0,
+    "cross_attention_hidden_size": null,
+    "decoder_start_token_id": null,
+    "diversity_penalty": 0.0,
+    "do_sample": false,
+    "dropout": 0.0,
+    "early_stopping": false,
+    "encoder_no_repeat_ngram_size": 0,
+    "eos_token_id": null,
+    "exponential_decay_length_penalty": null,
+    "finetuning_task": null,
+    "forced_bos_token_id": null,
+    "forced_eos_token_id": null,
+    "hidden_act": "quick_gelu",
+    "hidden_size": 1024,
+    "id2label": {
+      "0": "LABEL_0",
+      "1": "LABEL_1"
+    },
+    "image_size": 224,
+    "initializer_factor": 1.0,
+    "initializer_range": 0.02,
+    "intermediate_size": 4096,
+    "is_decoder": false,
+    "is_encoder_decoder": false,
+    "label2id": {
+      "LABEL_0": 0,
+      "LABEL_1": 1
+    },
+    "layer_norm_eps": 1e-05,
+    "length_penalty": 1.0,
+    "max_length": 20,
+    "min_length": 0,
+    "model_type": "clip_vision_model",
+    "no_repeat_ngram_size": 0,
+    "num_attention_heads": 16,
+    "num_beam_groups": 1,
+    "num_beams": 1,
+    "num_hidden_layers": 24,
+    "num_return_sequences": 1,
+    "output_attentions": false,
+    "output_hidden_states": false,
+    "output_scores": false,
+    "pad_token_id": null,
+    "patch_size": 14,
+    "prefix": null,
+    "problem_type": null,
+    "pruned_heads": {},
+    "remove_invalid_values": false,
+    "repetition_penalty": 1.0,
+    "return_dict": true,
+    "return_dict_in_generate": false,
+    "sep_token_id": null,
+    "task_specific_params": null,
+    "temperature": 1.0,
+    "tie_encoder_decoder": false,
+    "tie_word_embeddings": true,
+    "tokenizer_class": null,
+    "top_k": 50,
+    "top_p": 1.0,
+    "torch_dtype": null,
+    "torchscript": false,
+    "transformers_version": "4.21.0.dev0",
+    "typical_p": 1.0,
+    "use_bfloat16": false
+  },
+  "vision_config_dict": {
+    "hidden_size": 1024,
+    "intermediate_size": 4096,
+    "num_attention_heads": 16,
+    "num_hidden_layers": 24,
+    "patch_size": 14
+  }
+}
diff --git a/extras/safety_checker/configs/preprocessor_config.json b/extras/safety_checker/configs/preprocessor_config.json
new file mode 100644
index 0000000000..5294955ff7
--- /dev/null
+++ b/extras/safety_checker/configs/preprocessor_config.json
@@ -0,0 +1,20 @@
+{
+  "crop_size": 224,
+  "do_center_crop": true,
+  "do_convert_rgb": true,
+  "do_normalize": true,
+  "do_resize": true,
+  "feature_extractor_type": "CLIPFeatureExtractor",
+  "image_mean": [
+    0.48145466,
+    0.4578275,
+    0.40821073
+  ],
+  "image_std": [
+    0.26862954,
+    0.26130258,
+    0.27577711
+  ],
+  "resample": 3,
+  "size": 224
+}
diff --git a/extras/diffusers/pipelines/stable_diffusion/safety_checker.py b/extras/safety_checker/models/safety_checker.py
similarity index 100%
rename from extras/diffusers/pipelines/stable_diffusion/safety_checker.py
rename to extras/safety_checker/models/safety_checker.py
diff --git a/modules/async_worker.py b/modules/async_worker.py
index 0d95725c23..fa10ff8ad8 100644
--- a/modules/async_worker.py
+++ b/modules/async_worker.py
@@ -43,7 +43,7 @@ def worker():
     import fooocus_version
     import args_manager
 
-    from modules.censor import censor_batch, censor_single
+    from extras.censor import censor_batch, censor_single
     from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion, apply_arrays
     from modules.private_logger import log
     from extras.expansion import safe_str
diff --git a/modules/config.py b/modules/config.py
index 5a18e96358..8b27724273 100644
--- a/modules/config.py
+++ b/modules/config.py
@@ -685,5 +685,13 @@ def downloading_upscale_model():
     )
     return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin')
 
+def downloading_safety_checker_model():
+    load_file_from_url(
+        url='https://huggingface.co/mashb1t/misc/resolve/main/stable-diffusion-safety-checker.bin',
+        model_dir=path_safety_checker_models,
+        file_name='stable-diffusion-safety-checker.bin'
+    )
+    return os.path.join(path_safety_checker_models, 'stable-diffusion-safety-checker.bin')
+
 
 update_files()

From 49795fe0306149106c43cb75bd1eb8afc45badce Mon Sep 17 00:00:00 2001
From: Manuel Schmid <dev@mash1t.de>
Date: Sat, 18 May 2024 15:37:58 +0200
Subject: [PATCH 4/4] refactor: rename folder safety_checker_models to
 safety_checker

---
 .../put_safety_checker_models_here                          | 0
 modules/config.py                                           | 6 +++---
 2 files changed, 3 insertions(+), 3 deletions(-)
 rename models/{safety_checker_models => safety_checker}/put_safety_checker_models_here (100%)

diff --git a/models/safety_checker_models/put_safety_checker_models_here b/models/safety_checker/put_safety_checker_models_here
similarity index 100%
rename from models/safety_checker_models/put_safety_checker_models_here
rename to models/safety_checker/put_safety_checker_models_here
diff --git a/modules/config.py b/modules/config.py
index 8b27724273..73e33e4a03 100644
--- a/modules/config.py
+++ b/modules/config.py
@@ -195,7 +195,7 @@ def get_dir_or_set_default(key, default_value, as_array=False, make_directory=Fa
 path_clip_vision = get_dir_or_set_default('path_clip_vision', '../models/clip_vision/')
 path_fooocus_expansion = get_dir_or_set_default('path_fooocus_expansion', '../models/prompt_expansion/fooocus_expansion')
 path_wildcards = get_dir_or_set_default('path_wildcards', '../wildcards/')
-path_safety_checker_models = get_dir_or_set_default('path_safety_checker_models', '../models/safety_checker_models/')
+path_safety_checker = get_dir_or_set_default('path_safety_checker', '../models/safety_checker/')
 path_outputs = get_path_output()
 
 
@@ -688,10 +688,10 @@ def downloading_upscale_model():
 def downloading_safety_checker_model():
     load_file_from_url(
         url='https://huggingface.co/mashb1t/misc/resolve/main/stable-diffusion-safety-checker.bin',
-        model_dir=path_safety_checker_models,
+        model_dir=path_safety_checker,
         file_name='stable-diffusion-safety-checker.bin'
     )
-    return os.path.join(path_safety_checker_models, 'stable-diffusion-safety-checker.bin')
+    return os.path.join(path_safety_checker, 'stable-diffusion-safety-checker.bin')
 
 
 update_files()