diff --git a/README.md b/README.md index bdcb422..69454f0 100755 --- a/README.md +++ b/README.md @@ -72,6 +72,11 @@ swap_face_single(img1_fn, img2_fn, app, swapper, enhance=True, enhancer='REAL-ES - REAL-ESRGAN 4x - REAL-ESRGAN 8x +## GPU Support + +- cuda + **_(set 'device=cuda' to run with gpu)_** + ## Acknowledgments This project uses the InsightFace library and ONNX model for face analysis and swapping. Thanks to the developers of these libraries for their contributions. diff --git a/faceswap.py b/faceswap.py index 5142af1..615a03b 100755 --- a/faceswap.py +++ b/faceswap.py @@ -14,8 +14,12 @@ def validate_image(img): if not img.lower().endswith(('.jpg', '.jpeg', '.png')): raise ValueError(f'Image {img} is not a valid image file') +def cpu_warning(device): + if device == "cpu": + print("Using CPU for face enhancer. If you have a GPU, you can set device='cuda' to speed up the process. You can also set enhance=False to skip the enhancement.") + def swap_n_show(img1_fn, img2_fn, app, swapper, - plot_before=False, plot_after=True, enhance=False, enhancer='REAL-ESRGAN 2x'): + plot_before=False, plot_after=True, enhance=False, enhancer='REAL-ESRGAN 2x',device="cpu"): validate_image(img1_fn) validate_image(img2_fn) @@ -41,7 +45,8 @@ def swap_n_show(img1_fn, img2_fn, app, swapper, img1_ = swapper.get(img1_, face1, face2, paste_back=True) img2_ = swapper.get(img2_, face2, face1, paste_back=True) if enhance: - model, model_runner = load_face_enhancer_model(enhancer) + cpu_warning(device) + model, model_runner = load_face_enhancer_model(enhancer,device) img1_ = model_runner(img1_, model) img2_ = model_runner(img2_, model) fig, axs = plt.subplots(1, 2, figsize=(10, 5)) @@ -55,7 +60,7 @@ def swap_n_show(img1_fn, img2_fn, app, swapper, def swap_n_show_same_img(img1_fn, app, swapper, plot_before=False, - plot_after=True, enhance=False, enhancer='REAL-ESRGAN 2x'): + plot_after=True, enhance=False, enhancer='REAL-ESRGAN 2x',device="cpu"): validate_image(img1_fn) img1 = cv2.imread(img1_fn) @@ -75,7 +80,8 @@ def swap_n_show_same_img(img1_fn, img1_ = swapper.get(img1_, face1, face2, paste_back=True) img1_ = swapper.get(img1_, face2, face1, paste_back=True) if enhance: - model, model_runner = load_face_enhancer_model(enhancer) + cpu_warning(device) + model, model_runner = load_face_enhancer_model(enhancer,device) img1_ = model_runner(img1_, model) fig, ax = plt.subplots(1, 1, figsize=(10, 5)) ax.imshow(img1_[:,:,::-1]) @@ -84,7 +90,7 @@ def swap_n_show_same_img(img1_fn, return img1_ def swap_face_single(img1_fn, img2_fn, app, swapper, - plot_before=False, plot_after=True, enhance=False, enhancer='REAL-ESRGAN 2x'): + plot_before=False, plot_after=True, enhance=False, enhancer='REAL-ESRGAN 2x',device="cpu"): validate_image(img1_fn) validate_image(img2_fn) @@ -108,7 +114,8 @@ def swap_face_single(img1_fn, img2_fn, app, swapper, if plot_after: img1_ = swapper.get(img1_, face1, face2, paste_back=True) if enhance: - model, model_runner = load_face_enhancer_model(enhancer) + cpu_warning(device) + model, model_runner = load_face_enhancer_model(enhancer,device) img1_ = model_runner(img1_, model) # Save the image output_fn = os.path.join('outputs', os.path.basename(img1_fn)) diff --git a/main.py b/main.py index 0470204..3585bb9 100755 --- a/main.py +++ b/main.py @@ -28,4 +28,4 @@ # swap_n_show_same_img(img1_fn, app, swapper) # Add face to an image -swap_face_single(img1_fn, img2_fn, app, swapper, enhance=True, enhancer='REAL-ESRGAN 2x') \ No newline at end of file +swap_face_single(img1_fn, img2_fn, app, swapper, enhance=True, enhancer='REAL-ESRGAN 2x',device="cpu") \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 7058aa0..3806509 100755 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,5 @@ opencv-python-headless>=4.7.0.72 onnx==1.14.0 onnxruntime==1.15.0 gfpgan==1.3.8 -timm==0.9.2 \ No newline at end of file +timm==0.9.2 +torch==2.0.1 \ No newline at end of file diff --git a/requirements_gpu.txt b/requirements_gpu.txt new file mode 100644 index 0000000..7b7a0f0 --- /dev/null +++ b/requirements_gpu.txt @@ -0,0 +1,14 @@ +cv +matplotlib +gdown +gradio>=3.33.1 +insightface==0.7.3 +moviepy>=1.0.3 +numpy +opencv-python>=4.7.0.72 +opencv-python-headless>=4.7.0.72 +onnx==1.14.0 +onnxruntime-gpu==1.15.0 +gfpgan==1.3.8 +timm==0.9.2 +torch==2.0.1 \ No newline at end of file