@@ -173,11 +173,12 @@ def cosine_scheduler(base_value, final_value, epochs, niter_per_ep,
173
173
return schedule
174
174
175
175
176
- def get_student_backbone_and_num_features (pre_process = True , post_process = True ):
176
+ def get_student_backbone_and_num_features (config , pre_process = True , post_process = True ):
177
177
args = get_args ()
178
178
179
179
if args .vision_backbone_type == 'vit' :
180
- student = VitBackbone (pre_process = pre_process ,
180
+ student = VitBackbone (config ,
181
+ pre_process = pre_process ,
181
182
post_process = post_process ,
182
183
drop_path_rate = 0.1 ,
183
184
single_token_output = True )
@@ -194,11 +195,12 @@ def get_student_backbone_and_num_features(pre_process=True, post_process=True):
194
195
195
196
return student , num_features
196
197
197
- def get_teacher_backbone_and_num_features (pre_process = True , post_process = True ):
198
+ def get_teacher_backbone_and_num_features (config , pre_process = True , post_process = True ):
198
199
args = get_args ()
199
200
200
201
if args .vision_backbone_type == 'vit' :
201
- teacher = VitBackbone (pre_process = pre_process ,
202
+ teacher = VitBackbone (config ,
203
+ pre_process = pre_process ,
202
204
post_process = post_process ,
203
205
single_token_output = True )
204
206
num_features = args .hidden_size
@@ -215,7 +217,7 @@ def get_teacher_backbone_and_num_features(pre_process=True, post_process=True):
215
217
216
218
217
219
class DINOPretrainModel (MegatronModule ):
218
- def __init__ (self , pre_process = True , post_process = True ):
220
+ def __init__ (self , config , pre_process = True , post_process = True ):
219
221
super (DINOPretrainModel , self ).__init__ ()
220
222
args = get_args ()
221
223
self .out_dim = 65536
@@ -234,7 +236,7 @@ def __init__(self, pre_process=True, post_process=True):
234
236
self .momentum_teacher = 0.996
235
237
236
238
student_backbone , num_features = \
237
- get_student_backbone_and_num_features (pre_process , post_process )
239
+ get_student_backbone_and_num_features (config , pre_process , post_process )
238
240
239
241
self .student = MultiCropWrapper (
240
242
student_backbone ,
@@ -249,7 +251,7 @@ def __init__(self, pre_process=True, post_process=True):
249
251
)
250
252
251
253
teacher_backbone , num_features = \
252
- get_teacher_backbone_and_num_features (pre_process , post_process )
254
+ get_teacher_backbone_and_num_features (config , pre_process , post_process )
253
255
self .teacher = MultiCropWrapper (
254
256
teacher_backbone ,
255
257
DINOHead (num_features , self .out_dim )
0 commit comments