From d0c1b93403b595f3b120c2af7c7d8960fb9abb9b Mon Sep 17 00:00:00 2001
From: user
Date: Fri, 9 Aug 2024 07:45:59 -0700
Subject: [PATCH] update for torch 2.4 / python 3.12

- use next(iterator) instead of the removed Iterator.next() method
- explicitly set the legacy align_corners=True in F.grid_sample
- port the ResNet encoder to the torchvision weights API
- step the LR scheduler only after the optimizer has stepped
- replace the removed np.int alias with np.int32
- replace the removed Image.ANTIALIAS with Image.Resampling.LANCZOS
---
 datasets/mono_dataset.py   | 12 +++++--
 kitti_utils.py             |  2 +-
 networks/resnet_encoder.py | 69 ++++++++++++++++++++-----------------
 trainer.py                 | 16 ++++++--
 4 files changed, 59 insertions(+), 40 deletions(-)

diff --git a/datasets/mono_dataset.py b/datasets/mono_dataset.py
index a381934ca..79c400ae9 100644
--- a/datasets/mono_dataset.py
+++ b/datasets/mono_dataset.py
@@ -54,7 +54,8 @@ def __init__(self,
         self.height = height
         self.width = width
         self.num_scales = num_scales
-        self.interp = Image.ANTIALIAS
+        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
+        self.interp = Image.Resampling.LANCZOS
 
         self.frame_idxs = frame_idxs
 
@@ -173,8 +174,13 @@ def __getitem__(self, index):
             inputs[("inv_K", scale)] = torch.from_numpy(inv_K)
 
         if do_color_aug:
-            color_aug = transforms.ColorJitter.get_params(
-                self.brightness, self.contrast, self.saturation, self.hue)
+            # ColorJitter.get_params no longer returns a transform; use the module directly
+            color_aug = transforms.ColorJitter(
+                brightness=self.brightness,
+                contrast=self.contrast,
+                saturation=self.saturation,
+                hue=self.hue
+            )
         else:
             color_aug = (lambda x: x)
 
diff --git a/kitti_utils.py b/kitti_utils.py
index ac2fdc920..379fa0bc4 100644
--- a/kitti_utils.py
+++ b/kitti_utils.py
@@ -83,7 +83,7 @@ def generate_depth_map(calib_dir, velo_filename, cam=2, vel_depth=False):
 
     # project to image
     depth = np.zeros((im_shape[:2]))
-    depth[velo_pts_im[:, 1].astype(np.int), velo_pts_im[:, 0].astype(np.int)] = velo_pts_im[:, 2]
+    depth[velo_pts_im[:, 1].astype(np.int32), velo_pts_im[:, 0].astype(np.int32)] = velo_pts_im[:, 2]
 
     # find the duplicate points and choose the closest depth
     inds = sub2ind(depth.shape, velo_pts_im[:, 1], velo_pts_im[:, 0])
diff --git a/networks/resnet_encoder.py b/networks/resnet_encoder.py
index 9c94418d3..f3c103653 100644
--- a/networks/resnet_encoder.py
+++ b/networks/resnet_encoder.py
@@ -10,27 +10,16 @@
 
 import torch
 import torch.nn as nn
-import torchvision.models as models
 import torch.utils.model_zoo as model_zoo
+import torchvision.models as models
+from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet, ResNet18_Weights, ResNet50_Weights
 
-
-class ResNetMultiImageInput(models.ResNet):
-    """Constructs a resnet model with varying number of input images.
-    Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
-    """
-    def __init__(self, block, layers, num_classes=1000, num_input_images=1):
-        super(ResNetMultiImageInput, self).__init__(block, layers)
-        self.inplanes = 64
+class ResNetMultiImageInput(ResNet):
+    """Constructs a ResNet model with a varying number of input images."""
+    def __init__(self, block, layers, num_input_images=1, **kwargs):
+        super(ResNetMultiImageInput, self).__init__(block, layers, **kwargs)
         self.conv1 = nn.Conv2d(
             num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
-        self.bn1 = nn.BatchNorm2d(64)
-        self.relu = nn.ReLU(inplace=True)
-        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-        self.layer1 = self._make_layer(block, 64, layers[0])
-        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
-        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
-        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
-
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
@@ -38,26 +27,42 @@ def __init__(self, block, layers, num_classes=1000, num_input_images=1):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
 
-
-def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1):
-    """Constructs a ResNet model.
+def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1, progress=True):
+    """Constructs a ResNet model with varying number of input images.
     Args:
-        num_layers (int): Number of resnet layers. Must be 18 or 50
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-        num_input_images (int): Number of frames stacked as input
+        num_layers (int): Number of ResNet layers. Must be 18 or 50.
+        pretrained (bool): If True, returns a model pre-trained on ImageNet.
+        num_input_images (int): Number of frames stacked as input.
+        progress (bool): If True, displays a progress bar of the download.
""" - assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet" - blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers] - block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers] - model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images) + assert num_layers in [18, 50], "Can only run with 18 or 50 layer ResNet" + + layers = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers] + block = {18: BasicBlock, 50: Bottleneck}[num_layers] + + # Choose the correct weights based on the num_layers + if num_layers == 18: + weights_enum = ResNet18_Weights if pretrained else None + weights = ResNet18_Weights.DEFAULT if pretrained else None + elif num_layers == 50: + weights_enum = ResNet50_Weights if pretrained else None + weights = ResNet50_Weights.DEFAULT if pretrained else None + + model = ResNetMultiImageInput(block, layers, num_input_images=num_input_images) if pretrained: - loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)]) - loaded['conv1.weight'] = torch.cat( - [loaded['conv1.weight']] * num_input_images, 1) / num_input_images - model.load_state_dict(loaded) - return model + state_dict = weights.get_state_dict(progress=progress, check_hash=True) + + # Modify the conv1 weight to accommodate multiple input images + conv1_weight = state_dict['conv1.weight'] + if num_input_images > 1: + # Repeat the conv1 weights for the additional image channels + conv1_weight = conv1_weight.repeat(1, num_input_images, 1, 1) / num_input_images + + state_dict['conv1.weight'] = conv1_weight + + # Load the modified state dict into the model + model.load_state_dict(state_dict, strict=False) + return model class ResnetEncoder(nn.Module): """Pytorch module for a resnet encoder diff --git a/trainer.py b/trainer.py index f370603dc..8fc7ff273 100644 --- a/trainer.py +++ b/trainer.py @@ -56,10 +56,14 @@ def __init__(self, options): self.models["encoder"].to(self.device) self.parameters_to_train += list(self.models["encoder"].parameters()) + self.models["depth"] = networks.DepthDecoder( self.models["encoder"].num_ch_enc, self.opt.scales) + print("sending model to device") self.models["depth"].to(self.device) + print("done sending model to device") self.parameters_to_train += list(self.models["depth"].parameters()) + if self.use_pose_net: if self.opt.pose_model_type == "separate_resnet": @@ -193,7 +197,6 @@ def train(self): def run_epoch(self): """Run a single epoch of training and validation """ - self.model_lr_scheduler.step() print("Training") self.set_train() @@ -207,6 +210,7 @@ def run_epoch(self): self.model_optimizer.zero_grad() losses["loss"].backward() self.model_optimizer.step() + self.model_lr_scheduler.step() duration = time.time() - before_op_time @@ -322,10 +326,10 @@ def val(self): """ self.set_eval() try: - inputs = self.val_iter.next() + inputs = next(self.val_iter) except StopIteration: self.val_iter = iter(self.val_loader) - inputs = self.val_iter.next() + inputs = next(self.val_iter) with torch.no_grad(): outputs, losses = self.process_batch(inputs) @@ -384,7 +388,8 @@ def generate_images_pred(self, inputs, outputs): outputs[("color", frame_id, scale)] = F.grid_sample( inputs[("color", frame_id, source_scale)], outputs[("sample", frame_id, scale)], - padding_mode="border") + padding_mode="border", + align_corners=True) if not self.opt.disable_automasking: outputs[("color_identity", frame_id, scale)] = \