-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodnet.py
More file actions
81 lines (68 loc) · 2.42 KB
/
modnet.py
File metadata and controls
81 lines (68 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import sys
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import cv2
from PIL import Image
import numpy as np
sys.path.append("./BackgroundReplace/")
from src.models.modnet import MODNet
# Input preprocessing for MODNet: ToTensor turns a PIL image into a float
# tensor of shape (C, H, W) scaled to [0, 1]; Normalize with mean = std = 0.5
# then maps every channel to [-1, 1].
torch_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Directory containing this file; the checkpoint path is resolved relative to it.
# NOTE(review): the sys.path entry used to import MODNet ("./BackgroundReplace/")
# is relative to the current working directory instead, so that import presumably
# only resolves when the process is started from this file's directory.
dir_path = os.path.dirname(os.path.realpath(__file__))

print('Load pre-trained MODNet...')
pretrained_ckpt = os.path.join(
    dir_path, 'BackgroundReplace', 'model', 'modnet_webcam_portrait_matting.ckpt'
)

# The weights are loaded into the DataParallel wrapper rather than the bare
# model; presumably the checkpoint was saved from a DataParallel model, so its
# state-dict keys carry a "module." prefix — confirm before unwrapping.
modnet = nn.DataParallel(MODNet(backbone_pretrained=False))

# True when at least one CUDA device is visible to this process.
GPU = torch.cuda.device_count() > 0
if GPU:
    print('Use GPU...')
    modnet = modnet.cuda()
else:
    print('Use CPU...')

# Map checkpoint tensors directly onto the device in use, so loading does not
# depend on which device the checkpoint was originally saved from.
modnet.load_state_dict(
    torch.load(pretrained_ckpt, map_location='cuda' if GPU else 'cpu')
)
# Inference only: switch layers such as dropout/batch-norm to eval behaviour.
modnet.eval()
def _predict_matte(frame_rgb, width, height):
    """Run MODNet on ``frame_rgb`` and return its matte resized to ``(width, height)``.

    Args:
        frame_rgb: RGB image of shape ``(H, W, 3)``. Presumably ``uint8``, since
            it is handed to ``PIL.Image.fromarray`` — confirm with callers.
        width: Width in pixels of the returned matte.
        height: Height in pixels of the returned matte.

    Returns:
        A 2-D float array of shape ``(height, width)``. It holds the first (and
        presumably only) matte channel of the single batch item.
    """
    # cv2.resize's third positional parameter is ``dst``, not ``interpolation``.
    # Pass the flag by keyword so that area interpolation is actually applied.
    model_input = cv2.resize(frame_rgb, (512, 512), interpolation=cv2.INTER_AREA)
    # Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
    frame_tensor = torch_transforms(Image.fromarray(model_input))[None, :, :, :]
    if GPU:
        frame_tensor = frame_tensor.cuda()
    with torch.no_grad():
        # Presumably the second argument is MODNet's inference flag — TODO confirm.
        _, _, matte_tensor = modnet(frame_tensor, True)
    matte = matte_tensor[0, 0].cpu().numpy()
    return cv2.resize(matte, (width, height))


def replaceBackground(frame_np, orgw, orgh, background_image):
    """Composite the subject of an OpenCV frame over ``background_image``.

    Args:
        frame_np: BGR frame whose shape must broadcast with ``(orgh, orgw, 3)``.
        orgw: Frame width in pixels; the matte is resized to this width.
        orgh: Frame height in pixels; the matte is resized to this height.
        background_image: Array broadcastable against ``(orgh, orgw, 3)``.

    Returns:
        The ``uint8`` composite ``matte * frame + (1 - matte) * background_image``.
        The frame term is in RGB order: it is converted from BGR and never
        converted back.
        NOTE(review): ``removeBackground`` below returns BGR while this returns
        RGB. Confirm which channel order callers expect for ``background_image``
        and for the result.
    """
    frame_rgb = cv2.cvtColor(frame_np, cv2.COLOR_BGR2RGB)
    # The trailing axis lets the single-channel matte broadcast over the colour channels.
    alpha = _predict_matte(frame_rgb, orgw, orgh)[..., None]
    # Blend in float and convert once, so truncation error is not doubled.
    # The clip keeps the uint8 cast well-defined for out-of-range inputs.
    composite = alpha * frame_rgb + (1 - alpha) * background_image
    return np.clip(composite, 0, 255).astype(np.uint8)


def removeBackground(frame_np):
    """Weight an OpenCV frame by the inverse of its MODNet matte.

    Args:
        frame_np: BGR frame of shape ``(H, W, 3)``.

    Returns:
        A pair ``(masked, mask)``.
        ``masked`` is the ``uint8`` BGR frame scaled per pixel by ``1 - matte``.
        ``mask`` is the ``(H, W)`` float matte multiplied by 255.
    """
    height, width = frame_np.shape[:2]
    matte = _predict_matte(cv2.cvtColor(frame_np, cv2.COLOR_BGR2RGB), width, height)
    # NOTE(review): replaceBackground weights the frame by ``matte``, i.e. treats
    # the matte as the subject's weight. Scaling by ``1 - matte`` here therefore
    # keeps the background and blanks the subject, which seems at odds with this
    # function's name. Preserved as-is; confirm intent.
    # Scaling every channel by the same per-pixel weight commutes with a channel
    # swap, so the BGR frame is used directly (no RGB round trip needed).
    masked = ((1 - matte)[..., None] * frame_np).astype(np.uint8)
    return masked, matte * 255