8
8
import torch
9
9
from torch import nn
10
10
from torchvision .transforms import functional as TF
11
- from pytti .image_models import DifferentiableImage , EMAImage
11
+
12
+ # from pytti.image_models import DifferentiableImage
13
+ from pytti .image_models .ema import EMAImage , EMAParametersDict
12
14
from PIL import Image
13
15
from torch .nn import functional as F
14
16
@@ -44,8 +46,8 @@ def load_dip(input_depth, num_scales, offset_type, offset_groups, device):
44
46
return dip_net
45
47
46
48
47
- # class DeepImagePrior(EMAImage):
48
- class DeepImagePrior (DifferentiableImage ):
49
+ class DeepImagePrior (EMAImage ):
50
+ # class DeepImagePrior(DifferentiableImage):
49
51
"""
50
52
https://github.com/nousr/deep-image-prior/
51
53
"""
@@ -69,7 +71,14 @@ def __init__(
69
71
device = "cuda" ,
70
72
** kwargs ,
71
73
):
72
- super ().__init__ (width * scale , height * scale )
74
+ # super(super(EMAImage)).__init__()
75
+ nn .Module .__init__ (self )
76
+ super ().__init__ (
77
+ width = width * scale ,
78
+ height = height * scale ,
79
+ decay = ema_val ,
80
+ device = device ,
81
+ )
73
82
net = load_dip (
74
83
input_depth = input_depth ,
75
84
num_scales = num_scales ,
@@ -85,20 +94,38 @@ def __init__(
85
94
# z = torch.cat(get_non_offset_params(net), get_offset_params(net))
86
95
# logger.debug(z.shape)
87
96
# super().__init__(width * scale, height * scale, z, ema_val)
88
- self .net = net
97
+ # self.net = net
89
98
# self.tensor = self.net.params()
90
99
self .output_axes = ("n" , "s" , "y" , "x" )
91
100
self .scale = scale
92
101
self .device = device
93
102
94
- self ._net_input = torch .randn ([1 , input_depth , width , height ], device = device )
103
+ # self._net_input = torch.randn([1, input_depth, width, height], device=device)
95
104
96
105
self .lr = lr
97
106
self .offset_lr_fac = offset_lr_fac
98
107
# self._params = [
99
108
# {'params': get_non_offset_params(net), 'lr': lr},
100
109
# {'params': get_offset_params(net), 'lr': lr * offset_lr_fac}
101
110
# ]
111
+ # z = {
112
+ # 'non_offset':get_non_offset_params(net),
113
+ # 'offset':get_offset_params(net),
114
+ # }
115
+ self .net = net
116
+ self ._net_input = torch .randn ([1 , input_depth , width , height ], device = device )
117
+
118
+ self .image_representation_parameters = EMAParametersDict (
119
+ z = self .net , decay = ema_val , device = device
120
+ )
121
+
122
+ # super().__init__(
123
+ # width = width * scale,
124
+ # height = height * scale,
125
+ # tensor = z,
126
+ # decay = ema_val,
127
+ # device=device,
128
+ # )
102
129
103
130
# def get_image_tensor(self):
104
131
def decode_tensor (self ):
@@ -129,17 +156,34 @@ def get_latent_tensor(self, detach=False):
129
156
return params
130
157
131
158
def clone (self ):
132
- # dummy = super().__init__ (*self.image_shape)
159
+ # dummy = VQGANImage (*self.image_shape)
133
160
# with torch.no_grad():
134
- # #dummy.tensor.set_(self.tensor.clone())
135
- # dummy.net.copy_(self.net)
136
- # dummy.accum.set_(self.accum.clone())
137
- # dummy.biased.set_(self.biased.clone())
138
- # dummy.average.set_(self.average.clone())
139
- # dummy.decay = self.decay
140
- dummy = deepcopy (self )
161
+ # dummy.representation_parameters.set_(self.representation_parameters.clone())
162
+ # dummy.accum.set_(self.accum.clone())
163
+ # dummy.biased.set_(self.biased.clone())
164
+ # dummy.average.set_(self.average.clone())
165
+ # dummy.decay = self.decay
166
+ # return dummy
167
+ dummy = DeepImagePrior (* self .image_shape )
168
+ with torch .no_grad ():
169
+ # dummy.representation_parameters.set_(self.representation_parameters.clone())
170
+ dummy .image_representation_parameters .set_ (
171
+ self .image_representation_parameters .clone ()
172
+ )
141
173
return dummy
142
174
175
+ # def clone(self):
176
+ # # dummy = super().__init__(*self.image_shape)
177
+ # # with torch.no_grad():
178
+ # # #dummy.tensor.set_(self.tensor.clone())
179
+ # # dummy.net.copy_(self.net)
180
+ # # dummy.accum.set_(self.accum.clone())
181
+ # # dummy.biased.set_(self.biased.clone())
182
+ # # dummy.average.set_(self.average.clone())
183
+ # # dummy.decay = self.decay
184
+ # dummy = deepcopy(self)
185
+ # return dummy
186
+
143
187
def encode_random (self ):
144
188
pass
145
189
0 commit comments