File tree 1 file changed +14
-1
lines changed
1 file changed +14
-1
lines changed Original file line number Diff line number Diff line change @@ -524,6 +524,15 @@ def parse_args(input_args=None):
524
524
default = 4 ,
525
525
help = ("The dimension of the LoRA update matrices." ),
526
526
)
527
+ parser .add_argument (
528
+ "--image_interpolation_mode" ,
529
+ type = str ,
530
+ default = "lanczos" ,
531
+ choices = [
532
+ f .lower () for f in dir (transforms .InterpolationMode ) if not f .startswith ("__" ) and not f .endswith ("__" )
533
+ ],
534
+ help = "The image interpolation method to use for resizing images." ,
535
+ )
527
536
528
537
if input_args is not None :
529
538
args = parser .parse_args (input_args )
@@ -601,9 +610,13 @@ def __init__(
601
610
else :
602
611
self .class_data_root = None
603
612
613
+ interpolation = getattr (transforms .InterpolationMode , args .image_interpolation_mode .upper (), None )
614
+ if interpolation is None :
615
+ raise ValueError (f"Unsupported interpolation mode { interpolation = } ." )
616
+
604
617
self .image_transforms = transforms .Compose (
605
618
[
606
- transforms .Resize (size , interpolation = transforms . InterpolationMode . BILINEAR ),
619
+ transforms .Resize (size , interpolation = interpolation ),
607
620
transforms .CenterCrop (size ) if center_crop else transforms .RandomCrop (size ),
608
621
transforms .ToTensor (),
609
622
transforms .Normalize ([0.5 ], [0.5 ]),
You can’t perform that action at this time.
0 commit comments