Skip to content

Commit

Permalink
Add GPU support
Browse files Browse the repository at this point in the history
Signed-off-by: kiranpranay <[email protected]>
  • Loading branch information
KiranPranay committed Jul 30, 2023
1 parent e89d99f commit 609419f
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 8 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ swap_face_single(img1_fn, img2_fn, app, swapper, enhance=True, enhancer='REAL-ES
- REAL-ESRGAN 4x
- REAL-ESRGAN 8x

## GPU Support

- cuda
**_(set 'device=cuda' to run with gpu)_**

## Acknowledgments

This project uses the InsightFace library and ONNX model for face analysis and swapping. Thanks to the developers of these libraries for their contributions.
Expand Down
19 changes: 13 additions & 6 deletions faceswap.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@ def validate_image(img):
if not img.lower().endswith(('.jpg', '.jpeg', '.png')):
raise ValueError(f'Image {img} is not a valid image file')

def cpu_warning(device):
if device == "cpu":
print("Using CPU for face enhancer. If you have a GPU, you can set device='cuda' to speed up the process. You can also set enhance=False to skip the enhancement.")

def swap_n_show(img1_fn, img2_fn, app, swapper,
plot_before=False, plot_after=True, enhance=False, enhancer='REAL-ESRGAN 2x'):
plot_before=False, plot_after=True, enhance=False, enhancer='REAL-ESRGAN 2x',device="cpu"):

validate_image(img1_fn)
validate_image(img2_fn)
Expand All @@ -41,7 +45,8 @@ def swap_n_show(img1_fn, img2_fn, app, swapper,
img1_ = swapper.get(img1_, face1, face2, paste_back=True)
img2_ = swapper.get(img2_, face2, face1, paste_back=True)
if enhance:
model, model_runner = load_face_enhancer_model(enhancer)
cpu_warning(device)
model, model_runner = load_face_enhancer_model(enhancer,device)
img1_ = model_runner(img1_, model)
img2_ = model_runner(img2_, model)
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
Expand All @@ -55,7 +60,7 @@ def swap_n_show(img1_fn, img2_fn, app, swapper,
def swap_n_show_same_img(img1_fn,
app, swapper,
plot_before=False,
plot_after=True, enhance=False, enhancer='REAL-ESRGAN 2x'):
plot_after=True, enhance=False, enhancer='REAL-ESRGAN 2x',device="cpu"):

validate_image(img1_fn)
img1 = cv2.imread(img1_fn)
Expand All @@ -75,7 +80,8 @@ def swap_n_show_same_img(img1_fn,
img1_ = swapper.get(img1_, face1, face2, paste_back=True)
img1_ = swapper.get(img1_, face2, face1, paste_back=True)
if enhance:
model, model_runner = load_face_enhancer_model(enhancer)
cpu_warning(device)
model, model_runner = load_face_enhancer_model(enhancer,device)
img1_ = model_runner(img1_, model)
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.imshow(img1_[:,:,::-1])
Expand All @@ -84,7 +90,7 @@ def swap_n_show_same_img(img1_fn,
return img1_

def swap_face_single(img1_fn, img2_fn, app, swapper,
plot_before=False, plot_after=True, enhance=False, enhancer='REAL-ESRGAN 2x'):
plot_before=False, plot_after=True, enhance=False, enhancer='REAL-ESRGAN 2x',device="cpu"):

validate_image(img1_fn)
validate_image(img2_fn)
Expand All @@ -108,7 +114,8 @@ def swap_face_single(img1_fn, img2_fn, app, swapper,
if plot_after:
img1_ = swapper.get(img1_, face1, face2, paste_back=True)
if enhance:
model, model_runner = load_face_enhancer_model(enhancer)
cpu_warning(device)
model, model_runner = load_face_enhancer_model(enhancer,device)
img1_ = model_runner(img1_, model)
# Save the image
output_fn = os.path.join('outputs', os.path.basename(img1_fn))
Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@
# swap_n_show_same_img(img1_fn, app, swapper)

# Add face to an image
swap_face_single(img1_fn, img2_fn, app, swapper, enhance=True, enhancer='REAL-ESRGAN 2x')
swap_face_single(img1_fn, img2_fn, app, swapper, enhance=True, enhancer='REAL-ESRGAN 2x',device="cpu")
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ opencv-python-headless>=4.7.0.72
onnx==1.14.0
onnxruntime==1.15.0
gfpgan==1.3.8
timm==0.9.2
timm==0.9.2
torch==2.0.1
14 changes: 14 additions & 0 deletions requirements_gpu.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
cv
matplotlib
gdown
gradio>=3.33.1
insightface==0.7.3
moviepy>=1.0.3
numpy
opencv-python>=4.7.0.72
opencv-python-headless>=4.7.0.72
onnx==1.14.0
onnxruntime-gpu==1.15.0
gfpgan==1.3.8
timm==0.9.2
torch==2.0.1

0 comments on commit 609419f

Please sign in to comment.