Skip to content

Commit f93c574

Browse files
committed
Add get_feature()
1 parent f403a11 commit f93c574

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

main.py

+10
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ def image_loader(img_path):
3232
return img_tensor
3333

3434

35+
def get_feature(vgg19, img_tensor, feature_id):
36+
feature_tensor = vgg19.features[:feature_id](img_tensor)
37+
feature = feature_tensor.data.squeeze().cpu().numpy().transpose(LEFT_SHIFT)
38+
39+
return feature
40+
41+
3542
def main(config):
3643
device = torch.device(('cuda:' + str(config.gpu)) if config.cuda else 'cpu')
3744

@@ -44,6 +51,9 @@ def main(config):
4451
vgg19 = torchvision.models.vgg19(pretrained=True)
4552
vgg19.to(device)
4653

54+
feat5S = get_feature(vgg19, imgS, FEATURE_IDS[4])
55+
feat5R = get_feature(vgg19, imgR, FEATURE_IDS[4])
56+
4757
# FastGuidedFilter
4858
# labOrigS = torch.from_numpy(color.rgb2lab(np.array(origS)).transpose(RIGHT_SHIFT)).float()
4959
rgbOrigS = transforms.ToTensor()(origS)

utils.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
import torch
32

43

54
class PatchMatch:

0 commit comments

Comments
 (0)