-
Notifications
You must be signed in to change notification settings - Fork 92
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c492079
commit 6a7e9ad
Showing
5 changed files
with
291 additions
and
0 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,291 @@ | ||
""" | ||
This project is developed by Haofan Wang to support face swap in single frame. Multi-frame will be supported soon! | ||
It is highly built on the top of insightface, sd-webui-roop and CodeFormer. | ||
""" | ||
|
||
import sys | ||
sys.path.append('./CodeFormer/CodeFormer') | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
from torchvision.transforms.functional import normalize | ||
|
||
from basicsr.utils import imwrite, img2tensor, tensor2img | ||
from basicsr.utils.download_util import load_file_from_url | ||
from facelib.utils.face_restoration_helper import FaceRestoreHelper | ||
from facelib.utils.misc import is_gray | ||
from basicsr.archs.rrdbnet_arch import RRDBNet | ||
from basicsr.utils.realesrgan_utils import RealESRGANer | ||
from basicsr.utils.registry import ARCH_REGISTRY | ||
|
||
import os | ||
import cv2 | ||
import copy | ||
import insightface | ||
import numpy as np | ||
from PIL import Image | ||
from typing import List, Union, Dict, Set, Tuple | ||
|
||
|
||
# URLs of the pretrained weights for CodeFormer and its helper models.
pretrain_model_url = {
    'codeformer': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
    'detection': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
    'parsing': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth',
    'realesrgan': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth'
}
# download weights: (url key, destination directory, file name) per model;
# fetch only the files that are not already on disk.
_weight_files = [
    ('codeformer', 'CodeFormer/CodeFormer/weights/CodeFormer', 'codeformer.pth'),
    ('detection', 'CodeFormer/CodeFormer/weights/facelib', 'detection_Resnet50_Final.pth'),
    ('parsing', 'CodeFormer/CodeFormer/weights/facelib', 'parsing_parsenet.pth'),
    ('realesrgan', 'CodeFormer/CodeFormer/weights/realesrgan', 'RealESRGAN_x2plus.pth'),
]
for _key, _model_dir, _file_name in _weight_files:
    if not os.path.exists(os.path.join(_model_dir, _file_name)):
        load_file_from_url(url=pretrain_model_url[_key], model_dir=_model_dir, progress=True, file_name=None)
|
||
def imread(img_path):
    """Read the image at *img_path* and return it as an RGB ndarray.

    Args:
        img_path: path to an image file readable by OpenCV.

    Returns:
        The image as an RGB numpy array.

    Raises:
        FileNotFoundError: if the file is missing or cannot be decoded.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread silently returns None on a missing/unreadable file;
        # fail loudly here instead of crashing later inside cvtColor.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    # OpenCV loads BGR; convert to the RGB order used elsewhere (PIL).
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img
|
||
# set enhancer with RealESRGAN
def set_realesrgan():
    """Build and return a RealESRGAN x2 upsampler for image enhancement."""
    # Half precision is only worthwhile when a CUDA device is available.
    use_half = torch.cuda.is_available()
    backbone = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=2,
    )
    return RealESRGANer(
        scale=2,
        model_path="CodeFormer/CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth",
        model=backbone,
        tile=400,
        tile_pad=40,
        pre_pad=0,
        half=use_half,
    )
|
||
def getFaceSwapModel(model_path: str):
    """Load and return the insightface face-swap model at *model_path*."""
    return insightface.model_zoo.get_model(model_path)
|
||
def getFaceAnalyser(model_path: str,
                    det_size=(320, 320)):
    """Create a prepared insightface face analyser.

    NOTE(review): *model_path* is currently unused — the analyser always
    loads the "buffalo_l" pack from ./checkpoints. Parameter kept for
    interface compatibility with existing callers.

    Args:
        model_path: unused (see note above).
        det_size: detector input size as (width, height).

    Returns:
        A FaceAnalysis instance ready for .get() calls.
    """
    analyser = insightface.app.FaceAnalysis(name="buffalo_l", root="./checkpoints")
    analyser.prepare(ctx_id=0, det_size=det_size)
    return analyser
|
||
def get_one_face(face_analyser,
                 frame: np.ndarray):
    """Return the left-most face detected in *frame*, or None when no face is found."""
    detections = face_analyser.get(frame)
    try:
        # Left-most face == smallest x coordinate of the bounding box.
        return min(detections, key=lambda f: f.bbox[0])
    except ValueError:
        # min() raises ValueError on an empty detection list.
        return None
|
||
def get_many_faces(face_analyser,
                   frame: np.ndarray):
    """
    get faces from left to right by order
    """
    try:
        detections = face_analyser.get(frame)
        # Order faces by the left edge of their bounding boxes.
        return sorted(detections, key=lambda f: f.bbox[0])
    except IndexError:
        return None
|
||
def swap_face(face_swapper,
              source_face,
              target_face,
              temp_frame):
    """
    paste source_face on target image
    """
    swapped_frame = face_swapper.get(temp_frame, target_face, source_face, paste_back=True)
    return swapped_frame
|
||
def process(source_img: Union[Image.Image, List],
            target_img: Image.Image,
            model: str):
    """Swap faces from *source_img* into *target_img*.

    Args:
        source_img: a single PIL image, or a list of PIL images — one per
            target face, applied left to right.
        target_img: the PIL image whose faces are replaced.
        model: path to the inswapper ONNX model.

    Returns:
        The resulting PIL image.

    Raises:
        Exception: if no source face or no target faces are detected.
    """
    # load face_analyser
    face_analyser = getFaceAnalyser(model)

    # load face_swapper
    model_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), model)
    face_swapper = getFaceSwapModel(model_path)

    # read target image (PIL RGB -> OpenCV BGR)
    target_img = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR)

    # detect faces that will be replaced in target_img
    target_faces = get_many_faces(face_analyser, target_img)
    if target_faces is None:
        # BUG FIX: the original printed a message here and then crashed with
        # an unbound-variable NameError on `result`; raise explicitly instead.
        raise Exception("No target faces found!")

    temp_frame = copy.deepcopy(target_img)
    if isinstance(source_img, list) and len(source_img) == len(target_faces):
        # replace faces in target image from the left to the right by order
        for target_face, src in zip(target_faces, source_img):
            source_face = get_one_face(face_analyser, cv2.cvtColor(np.array(src), cv2.COLOR_RGB2BGR))
            if source_face is None:
                raise Exception("No source face found!")
            temp_frame = swap_face(face_swapper, source_face, target_face, temp_frame)
    else:
        # replace all faces in target image with the same source_face
        source_img = cv2.cvtColor(np.array(source_img), cv2.COLOR_RGB2BGR)
        source_face = get_one_face(face_analyser, source_img)
        if source_face is None:
            raise Exception("No source face found!")
        for target_face in target_faces:
            temp_frame = swap_face(face_swapper, source_face, target_face, temp_frame)
    result = temp_frame

    result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
    return result_image
|
||
def face_restoration(img, background_enhance, face_upsample, upscale, codeformer_fidelity):
    """Run a single prediction on the model.

    Restores faces in a BGR image with CodeFormer and optionally enhances
    the background with RealESRGAN.

    NOTE(review): depends on module-level globals set up by the caller
    (see __main__): `device`, `upsampler` and `codeformer_net`.

    Args:
        img: input image as a BGR ndarray.
        background_enhance: enhance the background with RealESRGAN
            (defaults to True when None).
        face_upsample: upsample restored faces (defaults to True when None).
        upscale: upscale factor, clamped to [1, 4] (defaults to 2 when
            None or non-positive).
        codeformer_fidelity: CodeFormer `w` weight balancing quality vs. fidelity.

    Returns:
        The restored image as an RGB ndarray, or None if restoration failed.
    """
    try:  # global try
        # take the default setting for the demo
        has_aligned = False
        only_center_face = False
        draw_box = False
        detection_model = "retinaface_resnet50"

        # Fill in defaults for missing/invalid arguments.
        background_enhance = background_enhance if background_enhance is not None else True
        face_upsample = face_upsample if face_upsample is not None else True
        upscale = upscale if (upscale is not None and upscale > 0) else 2

        upscale = int(upscale)  # convert type to int
        if upscale > 4:  # avoid memory exceeded due to too large upscale
            upscale = 4
        if upscale > 2 and max(img.shape[:2]) > 1000:  # avoid memory exceeded due to too large img resolution
            upscale = 2
        if max(img.shape[:2]) > 1500:  # avoid memory exceeded due to too large img resolution
            upscale = 1
            background_enhance = False
            face_upsample = False

        face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model=detection_model,
            save_ext="png",
            use_parse=True,
            device=device,
        )
        bg_upsampler = upsampler if background_enhance else None
        face_upsampler = upsampler if face_upsample else None

        if has_aligned:
            # the input faces are already cropped and aligned
            img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
            face_helper.is_gray = is_gray(img, threshold=5)
            face_helper.cropped_faces = [img]
        else:
            face_helper.read_image(img)
            # get face landmarks for each face
            num_det_faces = face_helper.get_face_landmarks_5(
                only_center_face=only_center_face, resize=640, eye_dist_threshold=5
            )
            # align and warp each face
            face_helper.align_warp_face()

        # face restoration for each cropped face
        for idx, cropped_face in enumerate(face_helper.cropped_faces):
            # prepare data: scale to [0,1], convert to RGB tensor, normalize to [-1,1]
            cropped_face_t = img2tensor(
                cropped_face / 255.0, bgr2rgb=True, float32=True
            )
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(device)

            try:
                with torch.no_grad():
                    output = codeformer_net(
                        cropped_face_t, w=codeformer_fidelity, adain=True
                    )[0]
                restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
                del output
                torch.cuda.empty_cache()
            except RuntimeError as error:
                # Fall back to the unrestored crop so the pipeline can continue.
                print(f"Failed inference for CodeFormer: {error}")
                restored_face = tensor2img(
                    cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
                )

            restored_face = restored_face.astype("uint8")
            face_helper.add_restored_face(restored_face)

        # paste_back
        if not has_aligned:
            # upsample the background
            if bg_upsampler is not None:
                # Now only support RealESRGAN for upsampling background
                bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
            else:
                bg_img = None
            face_helper.get_inverse_affine(None)
            # paste each restored face to the input image
            if face_upsample and face_upsampler is not None:
                restored_img = face_helper.paste_faces_to_input_image(
                    upsample_img=bg_img,
                    draw_box=draw_box,
                    face_upsampler=face_upsampler,
                )
            else:
                restored_img = face_helper.paste_faces_to_input_image(
                    upsample_img=bg_img, draw_box=draw_box
                )

        restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
        return restored_img
    except Exception as error:
        print('Global exception', error)
        # BUG FIX: the original returned the 2-tuple (None, None) here while
        # the success path returns a single array; callers use the result as
        # one value, so return a single None.
        return None
|
||
|
||
if __name__ == "__main__":

    # Demo inputs: two source faces applied to a two-person target image.
    source_img = [Image.open("./data/man1.jpeg"),Image.open("./data/man2.jpeg")]
    target_img = Image.open("./data/mans1.jpeg")

    # download from https://huggingface.co/deepinsight/inswapper/tree/main
    model = "./checkpoints/inswapper_128.onnx"
    result_image = process(source_img,target_img, model)

    face_restore = True
    if face_restore:
        # https://huggingface.co/spaces/sczhou/CodeFormer
        # NOTE: `upsampler`, `device` and `codeformer_net` are module-level
        # globals that face_restoration() reads — they must be set before
        # calling it.
        upsampler = set_realesrgan()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Instantiate the CodeFormer architecture and load the pretrained
        # EMA weights downloaded at import time.
        codeformer_net = ARCH_REGISTRY.get("CodeFormer")(dim_embd=512,
                                                         codebook_size=1024,
                                                         n_head=8,
                                                         n_layers=9,
                                                         connect_list=["32", "64", "128", "256"],
                                                         ).to(device)
        ckpt_path = "CodeFormer/CodeFormer/weights/CodeFormer/codeformer.pth"
        checkpoint = torch.load(ckpt_path)["params_ema"]
        codeformer_net.load_state_dict(checkpoint)
        codeformer_net.eval()

        # Restoration settings for the demo.
        background_enhance = True
        face_upsample = True
        upscale = 1 # up to 4
        codeformer_fidelity = 0.5

        # face_restoration expects BGR input and returns an RGB ndarray.
        result_image = cv2.cvtColor(np.array(result_image), cv2.COLOR_RGB2BGR)
        result_image = face_restoration(result_image, background_enhance, face_upsample, upscale, codeformer_fidelity)
        result_image = Image.fromarray(result_image)

    result_image.save("result.png")