
Commit 06a2166

incorporated new EMA into DIP
1 parent e5f6af5 commit 06a2166

File tree

src/pytti/image_models/deep_image_prior.py
src/pytti/image_models/ema.py

2 files changed: +67 -20 lines changed

src/pytti/image_models/deep_image_prior.py (+58 -14)
@@ -8,7 +8,9 @@
 import torch
 from torch import nn
 from torchvision.transforms import functional as TF
-from pytti.image_models import DifferentiableImage, EMAImage
+
+# from pytti.image_models import DifferentiableImage
+from pytti.image_models.ema import EMAImage, EMAParametersDict
 from PIL import Image
 from torch.nn import functional as F

@@ -44,8 +46,8 @@ def load_dip(input_depth, num_scales, offset_type, offset_groups, device):
     return dip_net


-# class DeepImagePrior(EMAImage):
-class DeepImagePrior(DifferentiableImage):
+class DeepImagePrior(EMAImage):
+# class DeepImagePrior(DifferentiableImage):
     """
     https://github.com/nousr/deep-image-prior/
     """
@@ -69,7 +71,14 @@ def __init__(
         device="cuda",
         **kwargs,
     ):
-        super().__init__(width * scale, height * scale)
+        # super(super(EMAImage)).__init__()
+        nn.Module.__init__(self)
+        super().__init__(
+            width=width * scale,
+            height=height * scale,
+            decay=ema_val,
+            device=device,
+        )
         net = load_dip(
             input_depth=input_depth,
             num_scales=num_scales,
@@ -85,20 +94,38 @@ def __init__(
         # z = torch.cat(get_non_offset_params(net), get_offset_params(net))
         # logger.debug(z.shape)
         # super().__init__(width * scale, height * scale, z, ema_val)
-        self.net = net
+        # self.net = net
         # self.tensor = self.net.params()
         self.output_axes = ("n", "s", "y", "x")
         self.scale = scale
         self.device = device

-        self._net_input = torch.randn([1, input_depth, width, height], device=device)
+        # self._net_input = torch.randn([1, input_depth, width, height], device=device)

         self.lr = lr
         self.offset_lr_fac = offset_lr_fac
         # self._params = [
         #     {'params': get_non_offset_params(net), 'lr': lr},
         #     {'params': get_offset_params(net), 'lr': lr * offset_lr_fac}
         # ]
+        # z = {
+        #     'non_offset':get_non_offset_params(net),
+        #     'offset':get_offset_params(net),
+        # }
+        self.net = net
+        self._net_input = torch.randn([1, input_depth, width, height], device=device)
+
+        self.image_representation_parameters = EMAParametersDict(
+            z=self.net, decay=ema_val, device=device
+        )
+
+        # super().__init__(
+        #     width = width * scale,
+        #     height = height * scale,
+        #     tensor = z,
+        #     decay = ema_val,
+        #     device=device,
+        # )

     # def get_image_tensor(self):
     def decode_tensor(self):
@@ -129,17 +156,34 @@ def get_latent_tensor(self, detach=False):
         return params

     def clone(self):
-        # dummy = super().__init__(*self.image_shape)
+        # dummy = VQGANImage(*self.image_shape)
         # with torch.no_grad():
-        #     #dummy.tensor.set_(self.tensor.clone())
-        #     dummy.net.copy_(self.net)
-        #     dummy.accum.set_(self.accum.clone())
-        #     dummy.biased.set_(self.biased.clone())
-        #     dummy.average.set_(self.average.clone())
-        #     dummy.decay = self.decay
-        dummy = deepcopy(self)
+        #     dummy.representation_parameters.set_(self.representation_parameters.clone())
+        #     dummy.accum.set_(self.accum.clone())
+        #     dummy.biased.set_(self.biased.clone())
+        #     dummy.average.set_(self.average.clone())
+        #     dummy.decay = self.decay
+        #     return dummy
+        dummy = DeepImagePrior(*self.image_shape)
+        with torch.no_grad():
+            # dummy.representation_parameters.set_(self.representation_parameters.clone())
+            dummy.image_representation_parameters.set_(
+                self.image_representation_parameters.clone()
+            )
         return dummy

+    # def clone(self):
+    #     # dummy = super().__init__(*self.image_shape)
+    #     # with torch.no_grad():
+    #     #     #dummy.tensor.set_(self.tensor.clone())
+    #     #     dummy.net.copy_(self.net)
+    #     #     dummy.accum.set_(self.accum.clone())
+    #     #     dummy.biased.set_(self.biased.clone())
+    #     #     dummy.average.set_(self.average.clone())
+    #     #     dummy.decay = self.decay
+    #     dummy = deepcopy(self)
+    #     return dummy
+
     def encode_random(self):
         pass

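For context on what this commit wires up: DeepImagePrior now subclasses EMAImage and hands its network to an EMAParametersDict, so an exponential moving average is kept over the DIP network's parameters rather than over a single image tensor. The sketch below shows the generic shadow-parameter EMA pattern this relies on; the names here (ParameterEMA, shadow, copy_to) are illustrative only and are not pytti's API.

import torch
from torch import nn


class ParameterEMA:
    """Illustrative shadow-parameter EMA; a sketch, not the pytti implementation."""

    def __init__(self, module: nn.Module, decay: float = 0.99):
        self.decay = decay
        # Keep a detached copy of every parameter as the smoothed "average".
        self.shadow = {
            name: p.detach().clone() for name, p in module.named_parameters()
        }

    @torch.no_grad()
    def update(self, module: nn.Module):
        # avg <- decay * avg + (1 - decay) * current
        for name, p in module.named_parameters():
            self.shadow[name].mul_(self.decay).add_(p.detach(), alpha=1 - self.decay)

    @torch.no_grad()
    def copy_to(self, module: nn.Module):
        # Overwrite the live weights with the smoothed copies (e.g. before rendering).
        for name, p in module.named_parameters():
            p.copy_(self.shadow[name])


# Usage sketch (dip_net stands in for the network returned by load_dip):
#   ema = ParameterEMA(dip_net, decay=0.99)
#   ... after each optimizer step:        ema.update(dip_net)
#   ... before decoding the output image: ema.copy_to(dip_net)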
src/pytti/image_models/ema.py (+9 -6)
@@ -58,13 +58,15 @@ def reset(self):
         self.update()


-class EMAParametersDict(ImageRepresentationalParameters):
+# class EMAParametersDict(ImageRepresentationalParameters):
+class EMAParametersDict(nn.Module):
     """
     LatentTensor with a singleton dimension for the EMAParameters
     """

     def __init__(self, z=None, decay=0.99, device=None):
-        super(ImageRepresentationalParameters).__init__()
+        # super(ImageRepresentationalParameters).__init__()
+        super().__init__()
         self.decay = decay
         if device is None:
             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -74,9 +76,10 @@ def __init__(self, z=None, decay=0.99, device=None):
     def _new(self, z=None):
         if z is None:
             # I think this can all go in the constructor and doesn't need to call .to()
-            z = torch.zeros(1, 3, self.height, self.width).to(
-                device=self.device, memory_format=torch.channels_last
-            )
+            return nn.Parameter()
+            # z = torch.zeros(1, 3, self.height, self.width).to(
+            #     device=self.device, memory_format=torch.channels_last
+            # )
         # d_ = z
         d_ = {}
         if isinstance(z, EMAParametersDict):
@@ -147,7 +150,7 @@ def reset(self):


 class EMAImage(DifferentiableImage):
-    def __init__(self, width, height, tensor, decay, device=None):
+    def __init__(self, width, height, tensor=None, decay=0.99, device=None):
         super().__init__(width=width, height=height, device=device)
         self.image_representation_parameters = EMAParametersDict(
             z=tensor, decay=decay, device=device
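The accum / biased / average buffers referenced in the commented-out clone() in the first file are the usual debiased-EMA bookkeeping. A small illustrative sketch of that arithmetic follows; the function and its usage are assumptions mirroring the buffer names, not code taken from the repo.

import torch


def debiased_ema_step(biased, accum, value, decay=0.99):
    """One step of a bias-corrected EMA (illustrative sketch, not pytti's code).

    biased: running decayed sum of observations
    accum:  running product of decay factors, used to undo startup bias
    value:  the new observation (e.g. a parameter tensor)
    """
    biased = decay * biased + (1 - decay) * value
    accum = accum * decay
    # Early on, `biased` underestimates the true mean because it started at zero;
    # dividing by (1 - accum) corrects that bias.
    average = biased / (1 - accum)
    return biased, accum, average


# Usage sketch: smooth a stream of tensors.
biased = torch.zeros(3)
accum = torch.tensor(1.0)
for step in range(5):
    value = torch.randn(3)
    biased, accum, average = debiased_ema_step(biased, accum, value)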
