-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdl_nst.py
191 lines (152 loc) · 56.8 KB
/
dl_nst.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
# -*- coding: utf-8 -*-
"""DL NST.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1RXj573MfUZK3QhUchxsPrGPhFSfXY1xd
## Deep Learning with PyTorch : Neural Style Transfer
## Task 1 : Content and Style Image collection

"""
import subprocess
import sys

# The notebook used IPython shell escapes (`!pip ...`, `!git ...`), which are
# a SyntaxError in a plain .py file. Run the same commands as subprocesses.
# As with `!`, a non-zero exit status is not turned into an exception (no
# `check=True`), so re-running with the repository already cloned does not
# abort the script.
subprocess.run([sys.executable, "-m", "pip", "install", "torch", "torchvision"])
# Clones into ./Project-NST. NOTE(review): the image paths used further down
# are absolute (/content/Project-NST/...), which presumably assumes the Colab
# working directory /content -- confirm when running elsewhere.
subprocess.run(["git", "clone", "https://github.com/parth1620/Project-NST.git"])
"""## Task 2 : Loading VGG Pretrained Model"""
import torch
from torchvision import models
# Load VGG-19 with ImageNet weights. `pretrained=True` is deprecated in newer
# torchvision releases in favour of the `weights=` enum, so use the enum when
# this torchvision version provides it and fall back to the old flag otherwise.
if hasattr(models, "VGG19_Weights"):
    vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
else:
    vgg = models.vgg19(pretrained=True)
print(vgg)

# Keep only the `features` sub-module; the name `vgg` no longer refers to the
# full model after this point.
vgg = vgg.features
print(vgg)

# Freeze every network parameter: only the target image is optimized later.
for param in vgg.parameters():
    param.requires_grad_(False)

# Use the GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
vgg.to(device)
"""## Task 3 : Preprocess image
Torchvision models page : https://pytorch.org/docs/stable/torchvision/models.html
"""
from PIL import Image
from torchvision import transforms as T
def preprocess(img_path, max_size=500):
    """Load an image file and convert it to a normalized, batched tensor.

    Args:
        img_path: Location of the image, passed straight to ``Image.open``.
        max_size: Upper bound on the integer handed to ``T.Resize``.

    Returns:
        The transformed image tensor with an extra leading dimension of
        size 1 (added by ``unsqueeze(0)``), normalized per channel with the
        mean/std constants below.

    Note:
        ``T.Resize`` receives a single int, ``min(max(image.size), max_size)``.
        NOTE(review): per the torchvision docs a single int is matched against
        the *smaller* image edge, so the longer edge of the result can exceed
        ``max_size`` and small images can be upscaled -- confirm this is the
        intended behaviour before relying on ``max_size`` as a hard cap.
    """
    # The context manager closes the underlying file handle; convert() returns
    # a separate image object that stays usable afterwards.
    with Image.open(img_path) as img:
        image = img.convert('RGB')

    # Same value the original if/else produced: the longest side, capped.
    size = min(max(image.size), max_size)

    img_transforms = T.Compose([
        T.Resize(size),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
    ])

    image = img_transforms(image)
    image = image.unsqueeze(0)  # (C, H, W) -> (1, C, H, W)
    return image
# Load the content and style images with the default size cap and place them
# on the selected device in the same expression.
content_p = preprocess('/content/Project-NST/content11.jpg').to(device)
style_p = preprocess('/content/Project-NST/style12.jpg').to(device)

# Report the resulting tensor shapes for a quick sanity check.
print("Content shape", content_p.shape)
print("Style shape", style_p.shape)
"""## Task 4 : Deprocess image"""
import numpy as np
import matplotlib.pyplot as plt
def deprocess(tensor):
    """Turn a normalized image tensor back into a displayable NumPy array.

    Reverses the per-channel normalization applied with mean
    ``[0.485, 0.456, 0.406]`` and std ``[0.229, 0.224, 0.225]``.

    Args:
        tensor: Image tensor whose first dimension has size 1 and whose
            remaining three dimensions are ordered (channel, height, width);
            the channel dimension must broadcast against the 3-element
            mean/std arrays.

    Returns:
        A NumPy array ordered (height, width, channel) with every value
        clipped to the range [0, 1].
    """
    # detach() first: .numpy() raises on tensors that require grad, so without
    # it callers had to remember to detach the optimized image themselves.
    # clone() keeps the conversion from sharing memory with the input.
    image = tensor.detach().to('cpu').clone().numpy()
    # Drop the batch dimension, then move channels last for plotting.
    image = image.squeeze(0).transpose(1, 2, 0)
    # Undo Normalize: x * std + mean, broadcast over the trailing channel axis.
    image = image * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    return image.clip(0, 1)
# Convert both preprocessed tensors back into displayable arrays.
content_d, style_d = deprocess(content_p), deprocess(style_p)
print("deprocess content", content_d.shape)
print("deprocess style", style_d.shape)

# Preview: content image in the left panel, style image in the right panel.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(20, 10))
ax1.imshow(content_d)
ax2.imshow(style_d)
"""## Task 5 : Get content,style features and create gram matrix"""
def get_features(image, model, layers=None):
    """Run ``image`` through ``model`` and collect selected intermediate outputs.

    The direct sub-modules of ``model`` are applied one after another in
    registration order; whenever a sub-module's name is a key of ``layers``
    its output is stored under the mapped label.

    Args:
        image: Input passed to the first sub-module.
        model: Module whose ``_modules`` mapping is iterated sequentially.
        layers: Optional mapping from sub-module name to output label.
            Defaults to the six layers used by this script (five style
            layers plus ``conv4_2`` as the content layer).

    Returns:
        Dict mapping each label to the output captured at that sub-module.

    Note:
        The stored value is the very tensor object handed to the next
        sub-module. NOTE(review): if that next sub-module works in place
        (torchvision's VGG appears to use ``ReLU(inplace=True)``), the stored
        feature reflects the in-place result. All sub-modules are therefore
        still executed, even after the last requested layer, so that results
        stay identical to the original implementation.
    """
    if layers is None:
        layers = {
            '0': 'conv1_1',
            '5': 'conv2_1',
            '10': 'conv3_1',
            '19': 'conv4_1',
            '21': 'conv4_2',  # content feature
            '28': 'conv5_1',
        }

    x = image
    features = {}
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
    return features
# Extract the feature maps of the content image and of the style image with
# the same network; each image is passed through get_features exactly once.
content_f, style_f = (get_features(img, vgg) for img in (content_p, style_p))
"""

"""
def gram_matrix(tensor):
    """Compute the (unnormalized) Gram matrix of a single feature map.

    Args:
        tensor: 4-D tensor of size (batch, channels, height, width). The
            flattening below only succeeds when batch is 1; larger batches
            raise a ``RuntimeError`` from the reshape.

    Returns:
        A (channels, channels) tensor holding the inner products between
        all pairs of flattened channel maps.
    """
    _, c, h, w = tensor.size()
    # reshape() rather than view(): view() raises on non-contiguous inputs
    # (e.g. permuted or channels_last tensors); reshape() copies only then.
    flat = tensor.reshape(c, h * w)
    return torch.mm(flat, flat.t())
# One Gram matrix per extracted style-image layer, keyed by layer label.
style_grams = {name: gram_matrix(feature) for name, feature in style_f.items()}
"""## Task 6 : Creating Style and Content loss function"""
def content_loss(target_conv4_2, content_conv4_2):
    """Return the mean squared difference between two feature tensors.

    Args:
        target_conv4_2: Features of the image being optimized.
        content_conv4_2: Features of the content image; must broadcast
            against ``target_conv4_2``.

    Returns:
        A scalar tensor: the mean of the element-wise squared differences.
    """
    loss = torch.mean((target_conv4_2 - content_conv4_2) ** 2)
    return loss
# Relative contribution of each style layer to the style loss. The keys must
# match the labels produced by get_features; 'conv4_2' is absent because it
# is used for the content loss instead.
style_weights = dict(
    conv1_1=1.0,
    conv2_1=0.75,
    conv3_1=0.2,
    conv4_1=0.2,
    conv5_1=0.2,
)
def style_loss(style_weights, target_features, style_grams):
    """Accumulate the weighted Gram-matrix loss over the style layers.

    For every layer in ``style_weights`` the Gram matrix of the target
    features is compared with the precomputed style Gram matrix; the mean
    squared difference is scaled by the layer weight and divided by the
    number of elements (channels * height * width) of the target feature map.

    Args:
        style_weights: Mapping from layer label to its weight.
        target_features: Mapping from layer label to the 4-D target feature
            tensor for that layer.
        style_grams: Mapping from layer label to the style Gram matrix.

    Returns:
        The summed loss; the plain integer ``0`` when ``style_weights`` is
        empty, otherwise a scalar tensor.
    """
    loss = 0
    for layer, weight in style_weights.items():
        target_f = target_features[layer]
        target_gram = gram_matrix(target_f)
        style_gram = style_grams[layer]
        _, c, h, w = target_f.shape
        layer_loss = weight * torch.mean((target_gram - style_gram) ** 2)
        # Normalize by the feature-map size so layers of different
        # resolution contribute on a comparable scale.
        loss += layer_loss / (c * h * w)
    return loss
# The optimized image starts as a copy of the content image. Move it to the
# device *before* enabling gradients: if .to(device) has to copy after
# requires_grad_(True), the result is a non-leaf tensor, which the optimizer
# refuses to optimize.
target = content_p.clone().to(device).requires_grad_(True)
target_f = get_features(target, vgg)

# Sanity check of both loss terms on the initial image.
print("content loss:", content_loss(target_f['conv4_2'], content_f['conv4_2']))
print("style loss:", style_loss(style_weights, target_f, style_grams))
"""## Task 7 : Training loop"""
from torch import optim
# Adam updates the target image itself; it is the only tensor handed to the
# optimizer.
optimizer = optim.Adam(params=[target], lr=0.003)

# Loss weighting: alpha scales the content term, beta the style term.
alpha, beta = 1, 1e5

# Number of optimization steps and how often progress is reported.
epochs, show_every = 3000, 500
def total_loss(c_loss, s_loss, alpha, beta):
    """Combine the content and style losses into one weighted sum.

    Args:
        c_loss: Content loss term.
        s_loss: Style loss term.
        alpha: Multiplier applied to ``c_loss``.
        beta: Multiplier applied to ``s_loss``.

    Returns:
        ``alpha * c_loss + beta * s_loss``.
    """
    loss = alpha * c_loss + beta * s_loss
    return loss
# Snapshots of the target image, one per reporting step.
results = []
for i in range(epochs):
    # Forward pass: features of the current target image and both loss terms.
    target_f = get_features(target, vgg)
    c_loss = content_loss(target_f['conv4_2'], content_f['conv4_2'])
    s_loss = style_loss(style_weights, target_f, style_grams)
    t_loss = total_loss(c_loss, s_loss, alpha, beta)

    # Backward pass and update of the target image.
    optimizer.zero_grad()
    t_loss.backward()
    optimizer.step()

    if i % show_every == 0:
        # .item() prints the plain number instead of the tensor repr. The
        # value is the loss computed before this iteration's update, while
        # the snapshot below is taken after it.
        print("Total loss at each epoch{}:{}".format(i, t_loss.item()))
        results.append(deprocess(target.detach()))
# Show every stored snapshot in a two-column grid. The row count grows with
# the number of snapshots (never below the original 3), so changing `epochs`
# or `show_every` cannot push a subplot index past the end of the grid.
plt.figure(figsize=(10, 8))
n_cols = 2
n_rows = max(3, (len(results) + n_cols - 1) // n_cols)
for position, snapshot in enumerate(results, start=1):
    plt.subplot(n_rows, n_cols, position)
    plt.imshow(snapshot)
plt.show()
# Final comparison: the optimized image (left panel) next to the original
# content image (right panel).
target_copy, content_copy = deprocess(target.detach()), deprocess(content_p)
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
ax1.imshow(target_copy)
ax2.imshow(content_copy)