Skip to content

Commit f403a11

Browse files
committed
Add image_loader()
1 parent eb1c82b commit f403a11

File tree

2 files changed

+53
-21
lines changed

2 files changed

+53
-21
lines changed

main.py

+52-21
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,48 @@
1+
import argparse
2+
import os
3+
4+
import numpy as np
5+
import torch
6+
from torch import nn, optim
7+
import torchvision.models
8+
from torchvision import transforms, utils
9+
from guided_filter_pytorch.guided_filter import FastGuidedFilter
110
from PIL import Image
211
from skimage import color
12+
from sklearn.cluster import KMeans
313
from sklearn.neighbors import NearestNeighbors
4-
from torch.autograd import Variable
5-
from torch import nn, optim
6-
from torchvision import models, transforms, utils
7-
from guided_filter_pytorch.guided_filter import FastGuidedFilter
8-
import numpy as np
9-
import argparse
10-
import os
1114

12-
import utils
15+
from utils import *
1316

1417

15-
parser = argparse.ArgumentParser(description="Neural Color Transfer between Images PyTorch")
16-
parser.add_argument('--source_image',type = str, default='image/3_Source1', help= "Source Image that has Content")
17-
parser.add_arguement('--reference_image', type=str, default='image/3_Reference', help= "Reference Image to Get Style")
18-
parser.add_arguement('--results_path',type=str, default='/results')
19-
parser.add_arguement('--processing_path', type=str, default='/processimage')
20-
parser.add_arguement('--gpu', type=int, default=0)
21-
parser.add_argument("--cuda", dest='feature', action='store_true')
22-
parser.set_defaults(cuda=False)
18+
FEATURE_IDS = [1, 6, 11, 20, 29]
19+
LEFT_SHIFT = (1, 2, 0)
20+
RIGHT_SHIFT = (2, 0, 1)
2321

24-
## need more arguements?
2522

23+
def image_loader(img_path):
24+
img = Image.open(img_path).convert("RGB")
25+
transform = transforms.Compose([
26+
transforms.ToTensor(),
27+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
28+
std=[0.229, 0.224, 0.225]),
29+
])
30+
img_tensor = transform(img).unsqueeze(0)
2631

27-
def main():
28-
args = parser.parse_args()
29-
torch.cuda.set_device(args.gpu)
30-
device = torch.device('cuda:{}'.format(args.gpu))
32+
return img_tensor
33+
34+
35+
def main(config):
36+
device = torch.device(('cuda:' + str(config.gpu)) if config.cuda else 'cpu')
37+
38+
imgS = image_loader(config.source).to(device)
39+
imgR = image_loader(config.reference).to(device)
40+
41+
imgS_np = imgS.squeeze().numpy().transpose(LEFT_SHIFT)
42+
imgR_np = imgR.squeeze().numpy().transpose(LEFT_SHIFT)
43+
44+
vgg19 = torchvision.models.vgg19(pretrained=True)
45+
vgg19.to(device)
3146

3247
# FastGuidedFilter
3348
# labOrigS = torch.from_numpy(color.rgb2lab(np.array(origS)).transpose(RIGHT_SHIFT)).float()
@@ -39,3 +54,19 @@ def main():
3954
lct.paramB.permute(RIGHT_SHIFT).unsqueeze(0).cpu(),
4055
rgbOrigS.unsqueeze(0)).squeeze()
4156

57+
58+
if __name__ == '__main__':
59+
parser = argparse.ArgumentParser(description="Neural Color Transfer between Images PyTorch")
60+
61+
parser.add_argument('--source', type=str, default='./image/3_Source1', help="Source Image that has Content")
62+
parser.add_argument('--reference', type=str, default='./image/3_Reference', help="Reference Image to Get Style")
63+
parser.add_argument('--results_dir', type=str, default='./results')
64+
parser.add_argument('--processing_dir', type=str, default='./processImage')
65+
parser.add_argument('--cuda', dest='feature', action='store_true')
66+
parser.add_argument('--gpu', type=int, default=0)
67+
parser.set_defaults(cuda=False)
68+
# need more arguments?
69+
70+
args = parser.parse_args()
71+
print(args)
72+
main(args)

utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import torch
33

4+
45
class PatchMatch:
56
def __init__(self, a, b, patch_size=3):
67
self.a = a

0 commit comments

Comments
 (0)