diff --git a/configs/e2e_faster_rcnn_R-50-C4_2x.yaml b/configs/e2e_faster_rcnn_R-50-C4_2x.yaml index d84fa49c..adcafe43 100644 --- a/configs/e2e_faster_rcnn_R-50-C4_2x.yaml +++ b/configs/e2e_faster_rcnn_R-50-C4_2x.yaml @@ -4,7 +4,7 @@ MODEL: FASTER_RCNN: True RESNETS: IMAGENET_PRETRAINED_WEIGHTS: 'data/pretrained_model/resnet50_caffe.pth' -NUM_GPUS: 8 +NUM_GPUS: 4 SOLVER: WEIGHT_DECAY: 0.0001 LR_POLICY: steps_with_decay diff --git a/configs/e2e_light_head_rcnn_R-50-C4.yml b/configs/e2e_light_head_rcnn_R-50-C4.yml new file mode 100644 index 00000000..9a3cb5c4 --- /dev/null +++ b/configs/e2e_light_head_rcnn_R-50-C4.yml @@ -0,0 +1,36 @@ +MODEL: + TYPE: generalized_rcnn + CONV_BODY: ResNet.ResNet50_conv4_body + LIGHT_HEAD_RCNN: True +RESNETS: + IMAGENET_PRETRAINED_WEIGHTS: 'data/pretrained_model/resnet50_caffe.pth' +NUM_GPUS: 4 +SOLVER: + TYPE: 'SGD' + WEIGHT_DECAY: 0.0001 + LR_POLICY: steps_with_decay + BASE_LR: 0.005 + GAMMA: 0.1 + # 2x schedule (note TRAIN.IMS_PER_BATCH: 1) + MAX_ITER: 360000 + STEPS: [0, 240000, 320000] +RPN: + RPN_ON: True +# CLS_ACTIVATION: softmax + SIZES: (32, 64, 128, 256, 512) +LIGHT_HEAD_RCNN: + ROI_XFORM_RESOLUTION: 7 + ROI_XFORM_METHOD: PSRoIPool +TRAIN: + SCALES: (800,) + MAX_SIZE: 1333 + IMS_PER_BATCH: 1 + FG_THRESH: 0.7 + BG_THRESH_HI: 0.5 + BATCH_SIZE_PER_IM: 1024 +TEST: + SCALE: 800 + MAX_SIZE: 1333 + NMS: 0.5 + RPN_PRE_NMS_TOP_N: 6000 + RPN_POST_NMS_TOP_N: 1000 diff --git a/configs/e2e_mask_rcnn_R-101-FPN_2x.yaml b/configs/e2e_mask_rcnn_R-101-FPN_2x.yaml index 49563e03..505c6eb3 100644 --- a/configs/e2e_mask_rcnn_R-101-FPN_2x.yaml +++ b/configs/e2e_mask_rcnn_R-101-FPN_2x.yaml @@ -5,7 +5,7 @@ MODEL: MASK_ON: True RESNETS: IMAGENET_PRETRAINED_WEIGHTS: 'data/pretrained_model/resnet101_caffe.pth' -NUM_GPUS: 8 +NUM_GPUS: 4 SOLVER: WEIGHT_DECAY: 0.0001 LR_POLICY: steps_with_decay @@ -41,3 +41,4 @@ TEST: NMS: 0.5 RPN_PRE_NMS_TOP_N: 1000 # Per FPN level RPN_POST_NMS_TOP_N: 1000 +VIS: True diff --git a/lib/core/config.py b/lib/core/config.py index f90f9e8a..9caafa0f 100644 --- a/lib/core/config.py +++ b/lib/core/config.py @@ -430,6 +430,8 @@ # Indicates the model makes instance mask predictions (as in Mask R-CNN) __C.MODEL.MASK_ON = False +__C.MODEL.LIGHT_HEAD_RCNN = False + # Indicates the model makes keypoint predictions (as in Mask R-CNN for # keypoints) __C.MODEL.KEYPOINTS_ON = False @@ -633,6 +635,14 @@ __C.FAST_RCNN.ROI_XFORM_RESOLUTION = 14 +# ---------------------------------------------------------------------------- # +#hw LIGHT_HEAD_RCNN options +# ---------------------------------------------------------------------------- # +__C.LIGHT_HEAD_RCNN = AttrDict() +__C.LIGHT_HEAD_RCNN.ROI_BOX_HEAD = '' +__C.LIGHT_HEAD_RCNN.MLP_HEAD_DIM = 1024 +__C.LIGHT_HEAD_RCNN.ROI_XFORM_RESOLUTION = 7 +__C.LIGHT_HEAD_RCNN.ROI_XFORM_METHOD = 'PSRoIPool' # ---------------------------------------------------------------------------- # # RPN options # ---------------------------------------------------------------------------- # diff --git a/lib/core/test.py b/lib/core/test.py index 3ca84908..89bde689 100644 --- a/lib/core/test.py +++ b/lib/core/test.py @@ -116,7 +116,7 @@ def im_detect_bbox(model, im, target_scale, target_max_size, boxes=None): inputs, im_scale = _get_blobs(im, boxes, target_scale, target_max_size) - if cfg.DEDUP_BOXES > 0 and not cfg.MODEL.FASTER_RCNN: + if cfg.DEDUP_BOXES > 0 and not cfg.MODEL.FASTER_RCNN and not cfg.MODEL.LIGHT_HEAD_RCNN: v = np.array([1, 1e3, 1e6, 1e9, 1e12]) hashes = np.round(inputs['rois'] * cfg.DEDUP_BOXES).dot(v) _, index, inv_index = 
np.unique( @@ -126,7 +126,7 @@ def im_detect_bbox(model, im, target_scale, target_max_size, boxes=None): boxes = boxes[index, :] # Add multi-level rois for FPN - if cfg.FPN.MULTILEVEL_ROIS and not cfg.MODEL.FASTER_RCNN: + if cfg.FPN.MULTILEVEL_ROIS and not cfg.MODEL.FASTER_RCNN and not cfg.MODEL.LIGHT_HEAD_RCNN: _add_multilevel_rois_for_test(inputs, 'rois') if cfg.PYTORCH_VERSION_LESS_THAN_040: @@ -138,7 +138,7 @@ def im_detect_bbox(model, im, target_scale, target_max_size, boxes=None): return_dict = model(**inputs) - if cfg.MODEL.FASTER_RCNN: + if cfg.MODEL.FASTER_RCNN or cfg.MODEL.LIGHT_HEAD_RCNN: rois = return_dict['rois'].data.cpu().numpy() # unscale back to raw image space boxes = rois[:, 1:5] / im_scale @@ -168,7 +168,7 @@ def im_detect_bbox(model, im, target_scale, target_max_size, boxes=None): # Simply repeat the boxes, once for each class pred_boxes = np.tile(boxes, (1, scores.shape[1])) - if cfg.DEDUP_BOXES > 0 and not cfg.MODEL.FASTER_RCNN: + if cfg.DEDUP_BOXES > 0 and not cfg.MODEL.FASTER_RCNN and not cfg.MODEL.LIGHT_HEAD_RCNN: # Map scores and predictions back to the original set of boxes scores = scores[inv_index, :] pred_boxes = pred_boxes[inv_index, :] diff --git a/lib/make.sh b/lib/make.sh index 8470621a..4b3f9bf2 100755 --- a/lib/make.sh +++ b/lib/make.sh @@ -50,6 +50,15 @@ nvcc -c -o roi_crop_cuda_kernel.cu.o roi_crop_cuda_kernel.cu \ cd ../ python build.py +# compile psroi_pooling +cd ../../ +cd model/psroi_pooling/src +echo "Compiling psroi pooling kernels by nvcc..." +nvcc -c -o psroi_pooling_kernel.cu.o psroi_pooling_kernel.cu \ + -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC $CUDA_ARCH +cd ../ +python build.py + # compile roi_align (based on Caffe2's implementation) cd ../../ cd modeling/roi_xfrom/roi_align/src @@ -58,3 +67,7 @@ nvcc -c -o roi_align_kernel.cu.o roi_align_kernel.cu \ -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC $CUDA_ARCH cd ../ python build.py + + + + diff --git a/lib/model/psroi_align_pooling/__init__.py b/lib/model/psroi_align_pooling/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/model/psroi_align_pooling/_ext/__init__.py b/lib/model/psroi_align_pooling/_ext/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/model/psroi_align_pooling/_ext/psroi_align_pooling/__init__.py b/lib/model/psroi_align_pooling/_ext/psroi_align_pooling/__init__.py new file mode 100644 index 00000000..ad91e104 --- /dev/null +++ b/lib/model/psroi_align_pooling/_ext/psroi_align_pooling/__init__.py @@ -0,0 +1,15 @@ + +from torch.utils.ffi import _wrap_function +from ._psroi_align_pooling import lib as _lib, ffi as _ffi + +__all__ = [] +def _import_symbols(locals): + for symbol in dir(_lib): + fn = getattr(_lib, symbol) + if callable(fn): + locals[symbol] = _wrap_function(fn, _ffi) + else: + locals[symbol] = fn + __all__.append(symbol) + +_import_symbols(locals()) diff --git a/lib/model/psroi_align_pooling/build.py b/lib/model/psroi_align_pooling/build.py new file mode 100644 index 00000000..9124d8e0 --- /dev/null +++ b/lib/model/psroi_align_pooling/build.py @@ -0,0 +1,37 @@ +import os +import torch +from torch.utils.ffi import create_extension + +sources = [] +headers = [] +defines = [] +with_cuda = False + + +if torch.cuda.is_available(): + print('Including CUDA code.') + sources += ['src/psroi_align_pooling_cuda.c'] + headers += ['src/psroi_align_pooling_cuda.h'] + defines += [('WITH_CUDA', None)] + with_cuda = True + + +this_file = os.path.dirname(os.path.realpath(__file__)) +print(this_file) +extra_objects = 
['src/psroi_align_pooling_kernel.cu.o'] +extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] +print(extra_objects) + +ffi = create_extension( + '_ext.psroi_align_pooling', + headers=headers, + sources=sources, + define_macros=defines, + relative_to=__file__, + with_cuda=with_cuda, + extra_objects=extra_objects, + extra_compile_args=['-DDEBUG'] +) + +if __name__ == '__main__': + ffi.build() \ No newline at end of file diff --git a/lib/model/psroi_align_pooling/functions/__init__.py b/lib/model/psroi_align_pooling/functions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/model/psroi_align_pooling/functions/psroi_align_pooling.py b/lib/model/psroi_align_pooling/functions/psroi_align_pooling.py new file mode 100644 index 00000000..bb70d08c --- /dev/null +++ b/lib/model/psroi_align_pooling/functions/psroi_align_pooling.py @@ -0,0 +1,69 @@ +import torch +from torch.autograd import Function +from .._ext import psroi_align_pooling + +class PSRoiAlignPoolingFunction(Function): + def __init__(self,pooled_height,pooled_width,sample_height,sample_width,spatial_scale,group_size): + self.pooled_height = int(pooled_height) + self.pooled_width = int(pooled_width) + self.sample_height = int(sample_height) + self.sample_width = int(sample_width) + self.spatial_scale = float(spatial_scale) + self.group_size = int(group_size) + self.output = None + self.mapping_channel = None + self.argmax_position = None + self.rois = None + self.feature_size = None + self.output_dim = None + def forward(self, features, rois): + batch_size, num_channels, data_height, data_width = features.size() + self.output_dim = num_channels // self.pooled_height // self.pooled_width + # self.output_dim = num_channels + num_rois = rois.size()[0] + output = torch.zeros(num_rois, self.output_dim, self.pooled_height, self.pooled_width) + mapping_channel = torch.cuda.IntTensor(num_rois, self.output_dim, self.pooled_height, self.pooled_width).zero_() + argmax_position = torch.cuda.IntTensor(num_rois, self.output_dim, self.pooled_height, self.pooled_width).zero_() + output = output.cuda() + psroi_align_pooling.psroi_align_pooling_forward_cuda(self.pooled_height, + self.pooled_width, + self.sample_height, + self.sample_width, + self.spatial_scale, + self.group_size, + self.output_dim, + features, + rois, + output, + mapping_channel, + argmax_position + ) + self.output = output + self.mapping_channel = mapping_channel + self.argmax_position = argmax_position + self.rois = rois + self.feature_size = features.size() + + return output + + def backward(self, grad_output): + assert(self.feature_size is not None and grad_output.is_cuda) + + batch_size, num_channels, data_height, data_width = self.feature_size + + grad_input = torch.zeros(batch_size, num_channels, data_height, data_width).cuda() + # import pdb + # pdb.set_trace() + psroi_align_pooling.psroi_align_pooling_backward_cuda(self.pooled_height, + self.pooled_width, + self.sample_height, + self.sample_width, + self.spatial_scale, + self.group_size, + self.output_dim, + grad_output, + self.rois, + grad_input, + self.mapping_channel, + self.argmax_position) + return grad_input, None diff --git a/lib/model/psroi_align_pooling/make.sh b/lib/model/psroi_align_pooling/make.sh new file mode 100755 index 00000000..0a6bc492 --- /dev/null +++ b/lib/model/psroi_align_pooling/make.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash +CUDA_PATH=/usr/local/cuda/ + +cd src +echo "Compiling psroi_align_pooling kernels by nvcc..." 
+nvcc -c -o psroi_align_pooling_kernel.cu.o psroi_align_pooling_kernel.cu.cc -x cu -Xcompiler -fPIC -arch=sm_60 + +cd ../ +python build.py diff --git a/lib/model/psroi_align_pooling/modules/__init__.py b/lib/model/psroi_align_pooling/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/model/psroi_align_pooling/modules/psroi_align_pooling.py b/lib/model/psroi_align_pooling/modules/psroi_align_pooling.py new file mode 100644 index 00000000..ab4ced27 --- /dev/null +++ b/lib/model/psroi_align_pooling/modules/psroi_align_pooling.py @@ -0,0 +1,17 @@ +from torch.nn.modules.module import Module +from ..functions.psroi_align_pooling import PSRoiAlignPoolingFunction + +class PSRoIAlignPool(Module): + def __init__(self, pooled_height, pooled_width,sample_height,sample_width, spatial_scale, group_size, output_dim): + super(PSRoIAlignPool, self).__init__() + + self.pooled_width = int(pooled_width) + self.pooled_height = int(pooled_height) + self.sample_height = int(sample_height) + self.sample_width = int(sample_width) + self.spatial_scale = float(spatial_scale) + self.group_size = int(group_size) + self.output_dim = int(output_dim) + + def forward(self, features, rois): + return PSRoiAlignPoolingFunction(self.pooled_height, self.pooled_width,self.sample_height,self.sample_width, self.spatial_scale, self.group_size, self.output_dim)(features, rois) diff --git a/lib/model/psroi_align_pooling/src/psroi_align_pooling_cuda.c b/lib/model/psroi_align_pooling/src/psroi_align_pooling_cuda.c new file mode 100644 index 00000000..7e6ebf3e --- /dev/null +++ b/lib/model/psroi_align_pooling/src/psroi_align_pooling_cuda.c @@ -0,0 +1,104 @@ +#include +#include +#include "psroi_align_pooling_kernel.h" + +extern THCState* state; + +int psroi_align_pooling_forward_cuda( + const int pooled_height, const int pooled_width, + const int sample_height, const int sample_width, + const float spatial_scale, + const int group_size, const int output_dim, + THCudaTensor* features, + THCudaTensor* rois, + THCudaTensor* output, + THCudaIntTensor* mapping_channel, + THCudaIntTensor* argmax_position) +{ + float* bottom_data = THCudaTensor_data(state, features); + float* bottom_rois = THCudaTensor_data(state, rois); + float* top_data = THCudaTensor_data(state, output); + int* top_mapping_channel = THCudaIntTensor_data(state, mapping_channel); + int* top_argmax_position = THCudaIntTensor_data(state, argmax_position); + //Get # of Rois + int num_rois = THCudaTensor_size(state, rois, 0); + int size_rois = THCudaTensor_size(state, rois, 1); + if(size_rois !=5){ + return 0; + } + + //Get # of batch_size + int batch_size = THCudaTensor_size(state, features, 0); + if (batch_size!=1) + { + return 0; + } + int num_channels = THCudaTensor_size(state, features, 1); + int data_height = THCudaTensor_size(state, features, 2); + int data_width = THCudaTensor_size(state, features, 3); + + cudaStream_t stream = THCState_getCurrentStream(state); + + + //Forward Kernel + return PSAlignPoolForwardLauncher( + bottom_data, spatial_scale, num_rois, + num_channels, data_height, data_width, + pooled_height, pooled_width, + sample_height, sample_width, + bottom_rois, + output_dim, group_size, top_data, + top_mapping_channel, top_argmax_position, + stream); +} + + +int psroi_align_pooling_backward_cuda( + const int pooled_height, const int pooled_width, + const int sample_height, const int sample_width, + const float spatial_scale, + const int group_size, const int output_dim, + THCudaTensor* top_grad, + THCudaTensor* rois, + 
THCudaTensor* bottom_grad, + THCudaIntTensor* mapping_channel, + THCudaIntTensor* argmax_position) +{ + float* top_grad_flat = THCudaTensor_data(state, top_grad); + float* rois_flat = THCudaTensor_data(state, rois); + + float* bottom_grad_flat = THCudaTensor_data(state, bottom_grad); + int* mapping_channel_flat = THCudaIntTensor_data(state,mapping_channel); + int* argmax_position_flat = THCudaIntTensor_data(state,argmax_position); + + int num_rois = THCudaTensor_size(state, rois, 0); + int size_rois = THCudaTensor_size(state, rois, 1); + if (size_rois != 5) + { + return 0; + } + + int batch_size = THCudaTensor_size(state, bottom_grad, 0); + if (batch_size != 1) + { + return 0; + } + int data_height = THCudaTensor_size(state, bottom_grad, 2); + int data_width = THCudaTensor_size(state, bottom_grad, 3); + int num_channels = THCudaTensor_size(state, bottom_grad, 1); + + cudaStream_t stream = THCState_getCurrentStream(state); + + //Backward Kernel + return PSAlignPoolBackwardLauncher( + top_grad_flat, + mapping_channel_flat, argmax_position_flat, + batch_size, + num_rois, spatial_scale, num_channels, + data_height, data_width, + pooled_height, pooled_width, + sample_height, sample_width, + output_dim, bottom_grad_flat, + rois_flat, stream); +} + diff --git a/lib/model/psroi_align_pooling/src/psroi_align_pooling_cuda.h b/lib/model/psroi_align_pooling/src/psroi_align_pooling_cuda.h new file mode 100644 index 00000000..7ee79823 --- /dev/null +++ b/lib/model/psroi_align_pooling/src/psroi_align_pooling_cuda.h @@ -0,0 +1,23 @@ + +int psroi_align_pooling_forward_cuda( + const int pooled_height, const int pooled_width, + const int sample_height, const int sample_width, + const float spatial_scale, + const int group_size, const int output_dim, + THCudaTensor* features, + THCudaTensor* rois, + THCudaTensor* output, + THCudaIntTensor* mapping_channel, + THCudaIntTensor* argmax_position); + +int psroi_align_pooling_backward_cuda( + const int pooled_height, const int pooled_width, + const int sample_height, const int sample_width, + const float spatial_scale, + const int group_size, const int output_dim, + THCudaTensor* top_grad, + THCudaTensor* rois, + THCudaTensor* bottom_grad, + THCudaIntTensor* mapping_channel, + THCudaIntTensor* argmax_position); + diff --git a/lib/model/psroi_align_pooling/src/psroi_align_pooling_kernel.cu.cc b/lib/model/psroi_align_pooling/src/psroi_align_pooling_kernel.cu.cc new file mode 100644 index 00000000..f36efb72 --- /dev/null +++ b/lib/model/psroi_align_pooling/src/psroi_align_pooling_kernel.cu.cc @@ -0,0 +1,294 @@ +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include +#include +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ + i += blockDim.x * gridDim.x) + +//using std::max; +//using std::min; + +__device__ static float ROIAlignGetCoeff(float dh, float dw) +{ + dw = dw > 0 ? dw : -dw; + dh = dh > 0 ? dh : -dh; + return (1.0f - dh) * (1.0f - dw); +} +//bilinear interpolation +__device__ static float ROIAlignGetInterpolating(const float* data, const float h, + const float w, const int height, const int width, const int channels) +{ + float retVal = 0.0f; + int h1 = floorf(h); + int w1 = floorf(w); + bool overflow = (h1<0) || (w1<0) || (h1 >=height) || (w1>=width); + retVal += overflow ? 0.0f : data[(h1 * width + w1) * channels] * ROIAlignGetCoeff(h - float(h1), w - float(w1)); + h1 = ceilf(h); + w1 = floorf(w); + overflow = (h1<0) || (w1<0) || (h1 >=height) || (w1>=width); + retVal += overflow? 
0.0f : data[(h1 * width + w1) * channels] * ROIAlignGetCoeff(h - float(h1), w - float(w1)); + h1 = floorf(h); + w1 = ceilf(w); + overflow = (h1<0) || (w1<0) || (h1 >=height) || (w1>=width); + retVal += overflow? 0.0f : data[(h1 * width + w1) * channels] * ROIAlignGetCoeff(h - float(h1), w - float(w1)); + h1 = ceilf(h); + w1 = ceilf(w); + overflow = (h1<0) || (w1<0) || (h1 >=height) || (w1>=width); + retVal += overflow? 0.0f : data[(h1 * width + w1) * channels] * ROIAlignGetCoeff(h - float(h1), w - float(w1)); + return retVal; +} +//the derivative of bilinear interpolation +__device__ static void ROIAlignDistributeDiff(float* diff, const float top_diff, + const float h, const float w, const int height, const int width, + const int channels) +{ + int h1 = floorf(h); + int w1 = floorf(w); + bool overflow = (h1<0) || (w1<0) || (h1 >=height) || (w1>=width); + if (!overflow) + { + atomicAdd(diff + (h1 * width + w1) * channels, + top_diff * ROIAlignGetCoeff(h - float(h1), w - float(w1))); + } + h1 = ceilf(h); + w1 = floorf(w); + overflow = (h1<0) || (w1<0) || (h1 >=height) || (w1>=width); + if (!overflow) + { + atomicAdd(diff + (h1 * width + w1) * channels, + top_diff * ROIAlignGetCoeff(h - float(h1), w - float(w1))); + } + h1 = floorf(h); + w1 = ceilf(w); + overflow = (h1<0) || (w1<0) || (h1 >=height) || (w1>=width); + if (!overflow) + { + atomicAdd(diff + (h1 * width + w1) * channels, + top_diff * ROIAlignGetCoeff(h - float(h1), w - float(w1))); + } + h1 = ceilf(h); + w1 = ceilf(w); + overflow = (h1<0) || (w1<0) || (h1 >=height) || (w1>=width); + if (!overflow) + { + atomicAdd(diff + (h1 * width + w1) * channels, + top_diff * ROIAlignGetCoeff(h - float(h1), w - float(w1))); + } +} + +__global__ void PSAlignPoolingForward( + const int nthreads, + const float* bottom_data, + const float spatial_scale, + const int channels, + const int height, const int width, + const int pooled_height, const int pooled_width, + const int sample_height, const int sample_width, + const float* bottom_rois, + const int output_dim, + const int group_size, + float* top_data, + int* mapping_channel, + int* argmax_position) { + CUDA_KERNEL_LOOP(index, nthreads) + { + int n = index; + int pw = n % pooled_width; + n /= pooled_width; + int ph = n % pooled_height; + n /= pooled_height; + int ctop = n % output_dim; + n /= output_dim; + + + bottom_rois += n * 5; + int roi_batch_ind = bottom_rois[0]; + + float roi_start_w = static_cast(bottom_rois[1]) * spatial_scale; + float roi_start_h = static_cast(bottom_rois[2]) * spatial_scale; + float roi_end_w = static_cast(bottom_rois[3]) * spatial_scale; + float roi_end_h = static_cast(bottom_rois[4]) * spatial_scale; + + + float roi_width = max(roi_end_w - roi_start_w, 0.0); + float roi_height = max(roi_end_h - roi_start_h, 0.0); + + // Compute w and h at bottom + float bin_size_h = roi_height / static_cast(pooled_height); + float bin_size_w = roi_width / static_cast(pooled_width); + + //ps align max pooling + bottom_data += roi_batch_ind * channels * height * width; + float sample_h_rate = 1.0f / float(sample_height); + float sample_w_rate = 1.0f / float(sample_width); + float hcenter; + float wcenter; + int gw = floor(static_cast(pw) * group_size / pooled_width); + int gh = floor(static_cast(ph)* group_size / pooled_height); + gw = min(max(gw, 0), group_size - 1); + gh = min(max(gh, 0), group_size - 1); + int c = (ctop*group_size + gh)*group_size + gw; + + float tmp = float(-1e20); + float tmp2; + int buf_value = -1; + for (int h_iter = 0; h_iter < sample_height; ++h_iter) + { + 
for (int w_iter = 0; w_iter < sample_width; ++w_iter) + { + hcenter = roi_start_h + bin_size_h * (ph + sample_h_rate * (h_iter + 0.5f)); + wcenter = roi_start_w + bin_size_w * (pw + sample_w_rate * (w_iter + 0.5f)); + tmp2 = ROIAlignGetInterpolating( + bottom_data + c, hcenter, wcenter, height, width, channels); + if (tmp2 > tmp) + { + tmp = tmp2; + buf_value = w_iter + h_iter * sample_width; + } + } + } + top_data[index] = tmp; + argmax_position[index] = buf_value; + mapping_channel[index] = c; + } +} +int PSAlignPoolForwardLauncher( + const float* bottom_data, const float spatial_scale, const int num_rois, + const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, + const int sample_height, const int sample_width, + const float* bottom_rois, const int output_dim, + const int group_size, float* top_data, + int* mapping_channel, int* argmax_position, + cudaStream_t stream) +{ + const int kThreadsPerBlock = 1024; + const int output_size = num_rois * pooled_height * pooled_width * output_dim; + cudaError_t err; + + PSAlignPoolingForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(output_size, bottom_data, spatial_scale, channels, height, width, + pooled_height, pooled_width, sample_height, sample_width, + bottom_rois, output_dim, group_size, + top_data, mapping_channel, argmax_position); + + err = cudaGetLastError(); + if(cudaSuccess != err) + { + fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); + exit( -1 ); + } + + return 1; +} + +__global__ void PSAlignPoolingBackwardAtomic( + const int nthreads, + const float* top_diff, + const int* mapping_channel, + const int* argmax_position, + const int num_rois, + const float spatial_scale, + const int channels, + const int height, const int width, + const int pooled_height, const int pooled_width, + const int sample_height, const int sample_width, + const int output_dim, + float* bottom_diff, + const float* bottom_rois) +{ + CUDA_KERNEL_LOOP(index, nthreads) + { + // The output is in order (n, ctop, ph, pw) + + /* + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int n = index / pooled_width / pooled_height / output_dim; + */ + + int n = index; + int pw = n % pooled_width; + n /= pooled_width; + int ph = n % pooled_height; + n /= pooled_height; + //int ctop = n % output_dim; + n /= output_dim; + + bottom_rois += n * 5; + int roi_batch_ind = bottom_rois[0]; + float roi_start_w = static_cast(bottom_rois[1]) * spatial_scale; + float roi_start_h = static_cast(bottom_rois[2]) * spatial_scale; + float roi_end_w = static_cast(bottom_rois[3]) * spatial_scale; + float roi_end_h = static_cast(bottom_rois[4]) * spatial_scale; + + float roi_width = max(roi_end_w - roi_start_w, (float)0); + float roi_height = max(roi_end_h - roi_start_h, (float)0); + + float bin_size_h = roi_height / static_cast(pooled_height); + float bin_size_w = roi_width / static_cast(pooled_width); + + /*new roi align*/ + int c = mapping_channel[index]; + bottom_diff += roi_batch_ind * channels * height * width; + + float sample_h_rate = 1.0f / float(sample_height); + float sample_w_rate = 1.0f / float(sample_width); + + float tmp = top_diff[index]; + int buffer_value = argmax_position[index]; + int w_iter = buffer_value % sample_width; + int h_iter = buffer_value / sample_width; + float hcenter = roi_start_h + bin_size_h * (ph + sample_h_rate * (h_iter + 0.5f)); + float wcenter = roi_start_w + bin_size_w * (pw + sample_w_rate * 
(w_iter + 0.5f)); + ROIAlignDistributeDiff(bottom_diff + c, tmp, hcenter, + wcenter, height, width, channels); + + } +} + +int PSAlignPoolBackwardLauncher( + const float* top_diff, + const int* mapping_channel, const int* argmax_position, + const int batch_size, + const int num_rois, const float spatial_scale, const int channels, + const int height, const int width, + const int pooled_height, const int pooled_width, + const int sample_height, const int sample_width, + const int output_dim, float* bottom_diff, + const float* bottom_rois, cudaStream_t stream) +{ + const int kThreadsPerBlock = 1024; + const int output_size = num_rois * pooled_height * pooled_width * output_dim; + const int bottom_size = batch_size * height * width * channels; + cudaError_t err; + + cudaMemsetAsync(bottom_diff, 0, sizeof(float) * bottom_size, stream); + +// SetZero<<<(bottom_size + kThreadsPerBlock - 1) / kThreadsPerBlock, +// kThreadsPerBlock, 0, stream>>>(bottom_size, bottom_diff); + + PSAlignPoolingBackwardAtomic<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, + kThreadsPerBlock, 0, stream>>>( + output_size, top_diff, mapping_channel, argmax_position, + num_rois, spatial_scale, channels, height, width, + pooled_height, pooled_width, sample_height, sample_width, + output_dim, bottom_diff, bottom_rois); + + err = cudaGetLastError(); + if(cudaSuccess != err) + { + fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); + exit( -1 ); + } + + return 1; +} + + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/lib/model/psroi_align_pooling/src/psroi_align_pooling_kernel.h b/lib/model/psroi_align_pooling/src/psroi_align_pooling_kernel.h new file mode 100644 index 00000000..926e32ce --- /dev/null +++ b/lib/model/psroi_align_pooling/src/psroi_align_pooling_kernel.h @@ -0,0 +1,39 @@ +#ifndef _PSROI_ALIGN_KERNEL +#define _PSROI_ALIGN_KERNEL + +#ifdef __cplusplus +extern "C" { +#endif + + +int PSAlignPoolForwardLauncher( + const float* bottom_data, const float spatial_scale, const int num_rois, + const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, + const int sample_height, const int sample_width, + const float* bottom_rois, const int output_dim, + const int group_size, float* top_data, + int* mapping_channel, int* argmax_position, + cudaStream_t stream); + +int PSAlignPoolBackwardLauncher( + const float* top_diff, + const int* mapping_channel, const int* argmax_position, + const int batch_size, + const int num_rois, const float spatial_scale, const int channels, + const int height, const int width, + const int pooled_height, const int pooled_width, + const int sample_height, const int sample_width, + const int output_dim, float* bottom_diff, + const float* bottom_rois, cudaStream_t stream); + + + + + + +#ifdef __cplusplus +} +#endif + +#endif \ No newline at end of file diff --git a/lib/model/psroi_pooling/__init__.py b/lib/model/psroi_pooling/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/model/psroi_pooling/_ext/__init__.py b/lib/model/psroi_pooling/_ext/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/model/psroi_pooling/_ext/psroi_pooling/__init__.py b/lib/model/psroi_pooling/_ext/psroi_pooling/__init__.py new file mode 100644 index 00000000..390ceb66 --- /dev/null +++ b/lib/model/psroi_pooling/_ext/psroi_pooling/__init__.py @@ -0,0 +1,15 @@ + +from torch.utils.ffi import _wrap_function +from ._psroi_pooling import lib as _lib, ffi as _ffi + +__all__ = 
[] +def _import_symbols(locals): + for symbol in dir(_lib): + fn = getattr(_lib, symbol) + if callable(fn): + locals[symbol] = _wrap_function(fn, _ffi) + else: + locals[symbol] = fn + __all__.append(symbol) + +_import_symbols(locals()) diff --git a/lib/model/psroi_pooling/build.py b/lib/model/psroi_pooling/build.py new file mode 100644 index 00000000..2d7648c4 --- /dev/null +++ b/lib/model/psroi_pooling/build.py @@ -0,0 +1,33 @@ +import os +import torch +from torch.utils.ffi import create_extension + +sources = [] +headers = [] +defines = [] +with_cuda = False + +if torch.cuda.is_available(): + print('Including CUDA code.') + sources += ['src/psroi_pooling_cuda.c'] + headers += ['src/psroi_pooling_cuda.h'] + defines += [('WITH_CUDA', None)] + with_cuda = True + +this_file = os.path.dirname(os.path.realpath(__file__)) +print(this_file) +extra_objects = ['src/psroi_pooling_kernel.cu.o'] +extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] + +ffi = create_extension( + '_ext.psroi_pooling', + headers=headers, + sources=sources, + define_macros=defines, + relative_to=__file__, + with_cuda=with_cuda, + extra_objects=extra_objects +) + +if __name__ == '__main__': + ffi.build() diff --git a/lib/model/psroi_pooling/functions/__init__.py b/lib/model/psroi_pooling/functions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/model/psroi_pooling/functions/psroi_pooling.py b/lib/model/psroi_pooling/functions/psroi_pooling.py new file mode 100644 index 00000000..d50d4999 --- /dev/null +++ b/lib/model/psroi_pooling/functions/psroi_pooling.py @@ -0,0 +1,58 @@ +import torch +from torch.autograd import Function +from .._ext import psroi_pooling + + +class PSRoIPoolingFunction(Function): + def __init__(self, pooled_height, pooled_width, spatial_scale, group_size): + self.pooled_width = int(pooled_width) + self.pooled_height = int(pooled_height) + self.spatial_scale = float(spatial_scale) + self.group_size = int(group_size) + self.output = None + self.mappingchannel = None + self.rois = None + self.feature_size = None + self.output_dim = None + + def forward(self, features, rois): + batch_size, num_channels, data_height, data_width = features.size() + self.output_dim = num_channels // self.pooled_height // self.pooled_width + num_rois = rois.size()[0] + output = torch.zeros(num_rois, self.output_dim, self.pooled_height, self.pooled_width) + mappingchannel = torch.IntTensor(num_rois, self.output_dim, self.pooled_height, self.pooled_width).zero_() + output = output.cuda() + mappingchannel = mappingchannel.cuda() + psroi_pooling.psroi_pooling_forward_cuda(self.pooled_height, + self.pooled_width, + self.spatial_scale, + self.group_size, + self.output_dim, + features, + rois, + output, + mappingchannel) + self.output = output + self.mappingchannel = mappingchannel + self.rois = rois + self.feature_size = features.size() + + return output + + def backward(self, grad_output): + + assert(self.feature_size is not None and grad_output.is_cuda) + + batch_size, num_channels, data_height, data_width = self.feature_size + + grad_input = torch.zeros(batch_size, num_channels, data_height, data_width).cuda() + + psroi_pooling.psroi_pooling_backward_cuda(self.pooled_height, + self.pooled_width, + self.spatial_scale, + self.output_dim, + grad_output, + self.rois, + grad_input, + self.mappingchannel) + return grad_input, None diff --git a/lib/model/psroi_pooling/make.sh b/lib/model/psroi_pooling/make.sh new file mode 100755 index 00000000..375df734 --- /dev/null +++ 
b/lib/model/psroi_pooling/make.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +CUDA_PATH=/usr/local/cuda/ + +cd src +echo "Compiling my_lib kernels by nvcc..." +nvcc -c -o psroi_pooling_kernel.cu.o psroi_pooling_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52 + +cd ../ +python build.py diff --git a/lib/model/psroi_pooling/modules/__init__.py b/lib/model/psroi_pooling/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/model/psroi_pooling/modules/psroi_pool.py b/lib/model/psroi_pooling/modules/psroi_pool.py new file mode 100644 index 00000000..a00d4b2e --- /dev/null +++ b/lib/model/psroi_pooling/modules/psroi_pool.py @@ -0,0 +1,17 @@ +from torch.nn.modules.module import Module +import sys +from ..functions.psroi_pooling import PSRoIPoolingFunction + + +class PSRoIPool(Module): + def __init__(self, pooled_height, pooled_width, spatial_scale, group_size, output_dim): + super(PSRoIPool, self).__init__() + + self.pooled_width = int(pooled_width) + self.pooled_height = int(pooled_height) + self.spatial_scale = float(spatial_scale) + self.group_size = int(group_size) + self.output_dim = int(output_dim) + + def forward(self, features, rois): + return PSRoIPoolingFunction(self.pooled_height, self.pooled_width, self.spatial_scale, self.group_size, self.output_dim)(features, rois) diff --git a/lib/model/psroi_pooling/src/psroi_pooling_cuda.c b/lib/model/psroi_pooling/src/psroi_pooling_cuda.c new file mode 100644 index 00000000..4250d108 --- /dev/null +++ b/lib/model/psroi_pooling/src/psroi_pooling_cuda.c @@ -0,0 +1,75 @@ +#include +#include +#include "psroi_pooling_kernel.h" + + + +extern THCState* state; + +int psroi_pooling_forward_cuda(int pooled_height, int pooled_width, float spatial_scale, int group_size, int output_dim,THCudaTensor *features, THCudaTensor* rois, THCudaTensor* output, THCudaIntTensor* mappingchannel){ + float* data_in = THCudaTensor_data(state, features); + float* rois_in = THCudaTensor_data(state, rois); + float* output_out = THCudaTensor_data(state, output); + int* mappingchannel_out = THCudaIntTensor_data(state, mappingchannel); + //Get # of Rois + int num_rois = THCudaTensor_size(state, rois, 0); + int size_rois = THCudaTensor_size(state, rois, 1); + if (size_rois!=5) + { + return 0; + } + + //Get # of batch_size + int batch_size = THCudaTensor_size(state, features, 0); + if (batch_size!=1) + { + return 0; + } + + int data_height = THCudaTensor_size(state, features, 2); + int data_width = THCudaTensor_size(state, features, 3); + int num_channels = THCudaTensor_size(state, features, 1); + + cudaStream_t stream = THCState_getCurrentStream(state); + + // call the gpu kernel for psroi_pooling + PSROIPoolForwardLauncher(data_in, spatial_scale, num_rois, data_height, data_width, num_channels, pooled_height, pooled_width,rois_in, group_size, + output_dim, output_out, mappingchannel_out,stream); + return 1; +} + + +int psroi_pooling_backward_cuda(int pooled_height, int pooled_width, float spatial_scale, int output_dim, +THCudaTensor* top_grad, THCudaTensor* rois, THCudaTensor* bottom_grad, THCudaIntTensor* mappingchannel) +{ + float *top_grad_flat = THCudaTensor_data(state, top_grad); + float *rois_flat = THCudaTensor_data(state, rois); + + float *bottom_grad_flat = THCudaTensor_data(state, bottom_grad); + int *mappingchannel_flat = THCudaIntTensor_data(state, mappingchannel); + + // Number of ROIs + int num_rois = THCudaTensor_size(state, rois, 0); + int size_rois = THCudaTensor_size(state, rois, 1); + if (size_rois != 5) + { + return 0; + } + // batch size 
+ int batch_size = THCudaTensor_size(state, bottom_grad, 0); + if (batch_size != 1) + { + return 0; + } + // data height + int data_height = THCudaTensor_size(state, bottom_grad, 2); + // data width + int data_width = THCudaTensor_size(state, bottom_grad, 3); + // Number of channels + int num_channels = THCudaTensor_size(state, bottom_grad, 1); + + cudaStream_t stream = THCState_getCurrentStream(state); + + PSROIPoolBackwardLauncher(top_grad_flat, mappingchannel_flat, batch_size, num_rois, spatial_scale, num_channels, data_height, data_width, pooled_width, pooled_height, output_dim, bottom_grad_flat, rois_flat, stream); + return 1; +} diff --git a/lib/model/psroi_pooling/src/psroi_pooling_cuda.h b/lib/model/psroi_pooling/src/psroi_pooling_cuda.h new file mode 100644 index 00000000..65d8bf07 --- /dev/null +++ b/lib/model/psroi_pooling/src/psroi_pooling_cuda.h @@ -0,0 +1,5 @@ +int psroi_pooling_forward_cuda( int pooled_height, int pooled_width, float spatial_scale,int group_size, int output_dim, + THCudaTensor * features, THCudaTensor * rois, THCudaTensor * output, THCudaIntTensor * mappingchannel); + +int psroi_pooling_backward_cuda(int pooled_height, int pooled_width, float spatial_scale, int output_dim, + THCudaTensor * top_grad, THCudaTensor * rois, THCudaTensor * bottom_grad, THCudaIntTensor * mappingchannel); diff --git a/lib/model/psroi_pooling/src/psroi_pooling_kernel.cu b/lib/model/psroi_pooling/src/psroi_pooling_kernel.cu new file mode 100644 index 00000000..1f26d97e --- /dev/null +++ b/lib/model/psroi_pooling/src/psroi_pooling_kernel.cu @@ -0,0 +1,201 @@ +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include +#include +#include "psroi_pooling_kernel.h" + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ + i += blockDim.x * gridDim.x) + + +__global__ void PSROIPoolForward(const int nthreads, const float* bottom_data, + const float spatial_scale, const int height, const int width, + const int channels, const int pooled_height, const int pooled_width, + const int group_size, const int output_dim, + const float* bottom_rois, float* top_data, int* mapping_channel) +{ + CUDA_1D_KERNEL_LOOP(index, nthreads) + { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int ctop = (index / pooled_width / pooled_height) % output_dim; + int n = index / pooled_width / pooled_height / output_dim; + + bottom_rois += n * 5; + int roi_batch_ind = bottom_rois[0]; + float roi_start_w = + static_cast(round(bottom_rois[1])) * spatial_scale; + float roi_start_h = + static_cast(round(bottom_rois[2])) * spatial_scale; + float roi_end_w = + static_cast(round(bottom_rois[3]) + 1.) * spatial_scale; + float roi_end_h = + static_cast(round(bottom_rois[4]) + 1.) 
* spatial_scale; + + // Force malformed ROIs to be 1x1 + float roi_width = max(roi_end_w - roi_start_w, 0.1); // avoid 0 + float roi_height = max(roi_end_h - roi_start_h, 0.1); + + float bin_size_h = (float)(roi_height) / (float)(pooled_height); + float bin_size_w = (float)(roi_width) / (float)(pooled_width); + + int hstart = floor(static_cast(ph) * bin_size_h + + roi_start_h); + int wstart = floor(static_cast(pw)* bin_size_w + + roi_start_w); + int hend = ceil(static_cast(ph + 1) * bin_size_h + + roi_start_h); + int wend = ceil(static_cast(pw + 1) * bin_size_w + + roi_start_w); + + // Add roi offsets and clip to input boundaries + hstart = min(max(hstart, 0), height); + hend = min(max(hend, 0), height); + wstart = min(max(wstart, 0), width); + wend = min(max(wend, 0), width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + int gw = floor(static_cast(pw) * group_size / pooled_width); + int gh = floor(static_cast(ph)* group_size / pooled_height); + gw = min(max(gw, 0), group_size - 1); + gh = min(max(gh, 0), group_size - 1); + int c = (ctop*group_size + gh)*group_size + gw; + + bottom_data += (roi_batch_ind * channels + c) * height * width; + float out_sum = 0; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int bottom_index = h*width + w; + out_sum += bottom_data[bottom_index]; + } + } + float bin_area = (hend - hstart)*(wend - wstart); + top_data[index] = is_empty? 0. : out_sum/bin_area; + mapping_channel[index] = c; + } +} + + +int PSROIPoolForwardLauncher( + const float* bottom_data, const float spatial_scale, const int num_rois, const int height, + const int width, const int channels, const int pooled_height, + const int pooled_width, const float* bottom_rois, + const int group_size, const int output_dim, + float* top_data, int* mapping_channel, cudaStream_t stream) +{ + const int kThreadsPerBlock = 1024; + const int output_size = output_dim * pooled_height * pooled_width * num_rois; + cudaError_t err; + + + PSROIPoolForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>( + output_size, bottom_data, spatial_scale, height, width, channels, pooled_height, + pooled_width, group_size, output_dim, bottom_rois, top_data, mapping_channel); + + err = cudaGetLastError(); + if(cudaSuccess != err) + { + fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); + exit( -1 ); + } + + return 1; +} + + +__global__ void PSROIPoolBackward(const int nthreads, const float* top_diff, + const int* mapping_channel, const int num_rois, const float spatial_scale, + const int height, const int width, const int channels, + const int pooled_height, const int pooled_width, const int output_dim, float* bottom_diff, + const float* bottom_rois) { + CUDA_1D_KERNEL_LOOP(index, nthreads) + { + + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int n = index / pooled_width / pooled_height / output_dim; + + // [start, end) interval for spatial sampling + bottom_rois += n * 5; + int roi_batch_ind = bottom_rois[0]; + float roi_start_w = + static_cast(round(bottom_rois[1])) * spatial_scale; + float roi_start_h = + static_cast(round(bottom_rois[2])) * spatial_scale; + float roi_end_w = + static_cast(round(bottom_rois[3]) + 1.) * spatial_scale; + float roi_end_h = + static_cast(round(bottom_rois[4]) + 1.) 
* spatial_scale; + + // Force too small ROIs to be 1x1 + float roi_width = max(roi_end_w - roi_start_w, 0.1); // avoid 0 + float roi_height = max(roi_end_h - roi_start_h, 0.1); + + // Compute w and h at bottom + float bin_size_h = roi_height / static_cast(pooled_height); + float bin_size_w = roi_width / static_cast(pooled_width); + + int hstart = floor(static_cast(ph)* bin_size_h + + roi_start_h); + int wstart = floor(static_cast(pw)* bin_size_w + + roi_start_w); + int hend = ceil(static_cast(ph + 1) * bin_size_h + + roi_start_h); + int wend = ceil(static_cast(pw + 1) * bin_size_w + + roi_start_w); + // Add roi offsets and clip to input boundaries + hstart = min(max(hstart, 0), height); + hend = min(max(hend, 0), height); + wstart = min(max(wstart, 0), width); + wend = min(max(wend, 0), width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + // Compute c at bottom + int c = mapping_channel[index]; + float* offset_bottom_diff = bottom_diff + + (roi_batch_ind * channels + c) * height * width; + float bin_area = (hend - hstart)*(wend - wstart); + float diff_val = is_empty ? 0. : top_diff[index] / bin_area; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int bottom_index = h*width + w; + //caffe_gpu_atomic_add(diff_val, offset_bottom_diff + bottom_index); + atomicAdd(offset_bottom_diff + bottom_index, diff_val); + } + } + } +} + +int PSROIPoolBackwardLauncher(const float* top_diff, const int* mapping_channel, const int batch_size, const int num_rois, const float spatial_scale, const int channels, + const int height, const int width, const int pooled_width, + const int pooled_height, const int output_dim, + float* bottom_diff, const float* bottom_rois, cudaStream_t stream) +{ + const int kThreadsPerBlock = 1024; + //const int output_size = output_dim * height * width * channels; + const int output_size = output_dim * pooled_height * pooled_width * num_rois; + cudaError_t err; + + PSROIPoolBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>( + output_size, top_diff, mapping_channel, num_rois, spatial_scale, height, width, channels, pooled_height, + pooled_width, output_dim, bottom_diff, bottom_rois); + + err = cudaGetLastError(); + if(cudaSuccess != err) + { + fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); + exit( -1 ); + } + + return 1; +} + + +#ifdef __cplusplus +} +#endif diff --git a/lib/model/psroi_pooling/src/psroi_pooling_kernel.h b/lib/model/psroi_pooling/src/psroi_pooling_kernel.h new file mode 100644 index 00000000..03a07f52 --- /dev/null +++ b/lib/model/psroi_pooling/src/psroi_pooling_kernel.h @@ -0,0 +1,21 @@ +#ifndef PS_ROI_POOLING_KERNEL +#define PS_ROI_POOLING_KERNEL + +#ifdef __cplusplus +extern "C" { +#endif + +int PSROIPoolForwardLauncher( + const float* bottom_data, const float spatial_scale, const int num_rois, const int height, + const int width, const int channels, const int pooled_height, const int pooled_width, + const float* bottom_rois, const int group_size, const int output_dim, float* top_data, int* mapping_channel, cudaStream_t stream); + + +int PSROIPoolBackwardLauncher(const float* top_diff, const int* mapping_channel, const int batch_size, const int num_rois, const float spatial_scale, const int channels, const int height, const int width, const int pooled_width, const int pooled_height, const int output_dim, float* bottom_diff, const float* bottom_rois, cudaStream_t stream); + +#ifdef __cplusplus +} + +#endif + +#endif diff --git 
a/lib/modeling/ResNet.py b/lib/modeling/ResNet.py
index 311d1c83..fb173577 100644
--- a/lib/modeling/ResNet.py
+++ b/lib/modeling/ResNet.py
@@ -134,7 +134,7 @@ def detectron_weight_mapping(self):
             residual_stage_detectron_mapping(self.res5, 'res5', 3, 5)
         return mapping_to_detectron, orphan_in_detectron
 
-    def forward(self, x, rpn_ret):
+    def forward(self, x, rpn_ret):#hw
         x = self.roi_xform(
             x, rpn_ret,
             blob_rois='rois',
diff --git a/lib/modeling/fast_rcnn_heads.py b/lib/modeling/fast_rcnn_heads.py
index cc4f55d5..af03080f 100644
--- a/lib/modeling/fast_rcnn_heads.py
+++ b/lib/modeling/fast_rcnn_heads.py
@@ -13,8 +13,8 @@ class fast_rcnn_outputs(nn.Module):
     def __init__(self, dim_in):
         super().__init__()
         self.cls_score = nn.Linear(dim_in, cfg.MODEL.NUM_CLASSES)
-        if cfg.MODEL.CLS_AGNOSTIC_BBOX_REG:
-            self.bbox_pred = nn.Linear(dim_in, 4)
+        if cfg.MODEL.CLS_AGNOSTIC_BBOX_REG:  # class-agnostic regression: one box each for bg and fg
+            self.bbox_pred = nn.Linear(dim_in, 4 * 2)
         else:
             self.bbox_pred = nn.Linear(dim_in, 4 * cfg.MODEL.NUM_CLASSES)
 
diff --git a/lib/modeling/light_head_rcnn_heads.py b/lib/modeling/light_head_rcnn_heads.py
new file mode 100644
index 00000000..2adfac29
--- /dev/null
+++ b/lib/modeling/light_head_rcnn_heads.py
@@ -0,0 +1,148 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+from torch.autograd import Variable
+
+from core.config import cfg
+import utils.net as net_utils
+
+# # light head (reference: the original MXNet implementation of the Light-Head R-CNN head)
+# conv_new_1 = mx.sym.Convolution(data=relu1, kernel=(15, 1), pad=(7, 0), num_filter=256, name="conv_new_1", lr_mult=3.0)
+# relu_new_1 = mx.sym.Activation(data=conv_new_1, act_type='relu', name='relu1')
+# conv_new_2 = mx.sym.Convolution(data=relu_new_1, kernel=(1, 15), pad=(0, 7), num_filter=10 * 7 * 7, name="conv_new_2",
+#                                 lr_mult=3.0)
+# relu_new_2 = mx.sym.Activation(data=conv_new_2, act_type='relu', name='relu2')
+# conv_new_3 = mx.sym.Convolution(data=relu1, kernel=(1, 15), pad=(0, 7), num_filter=256, name="conv_new_3", lr_mult=3.0)
+# relu_new_3 = mx.sym.Activation(data=conv_new_3, act_type='relu', name='relu3')
+# conv_new_4 = mx.sym.Convolution(data=relu_new_3, kernel=(15, 1), pad=(7, 0), num_filter=10 * 7 * 7, name="conv_new_4",
+#                                 lr_mult=3.0)
+# relu_new_4 = mx.sym.Activation(data=conv_new_4, act_type='relu', name='relu4')
+# light_head = mx.symbol.broadcast_add(name='light_head', *[relu_new_2, relu_new_4])
+# roi_pool = mx.contrib.sym.PSROIPooling(name='roi_pool', data=light_head, rois=rois, group_size=7, pooled_size=7,
+#                                        output_dim=10, spatial_scale=0.0625)
+# fc_new_1 = mx.symbol.FullyConnected(name='fc_new_1', data=roi_pool, num_hidden=2048)
+# fc_new_1_relu = mx.sym.Activation(data=fc_new_1, act_type='relu', name='fc_new_1_relu')
+# cls_score = mx.symbol.FullyConnected(name='cls_score', data=fc_new_1_relu, num_hidden=num_classes)
+# bbox_pred = mx.symbol.FullyConnected(name='bbox_pred', data=fc_new_1_relu, num_hidden=num_reg_classes * 4)
+from .ResNet import add_stage, freeze_params, mynn, residual_stage_detectron_mapping
+class large_separable_conv(nn.Module):
+    def __init__(self, chl_in, ks=15, chl_mid=256, chl_out=1024):
+        super().__init__()
+        pad = (ks - 1) // 2
+        self.col_max = nn.Conv2d(chl_in, chl_mid, (ks, 1), padding=(pad, 0))
+        self.col = nn.Conv2d(chl_mid, chl_out, (1, ks), padding=(0, pad))
+        self.row_max = nn.Conv2d(chl_in, chl_mid, (1, ks), padding=(0, pad))
+        self.row = nn.Conv2d(chl_mid, chl_out, (ks, 1), padding=(pad, 0))
+
+    def detectron_weight_mapping(self):
+        detectron_weight_mapping = {
+            'xconv.col_max.weight': 'col_max_w',
+            'xconv.col_max.bias': 'col_max_b',
'xconv.col.weight': 'col_w', + 'xconv.col.bias': 'col_b', + 'xconv.row_max.weight': 'row_max_w', + 'xconv.row_max.bias': 'row_max_b', + 'xconv.row.weight': 'row_w', + 'xconv.row.bias': 'row_b', + } + return detectron_weight_mapping,[] + + def forward(self, x): + y1 = self.col(self.col_max(x)) + y2 = self.row(self.row_max(x)) + return y1+y2 + + +class ResNet_Conv5_light_head(nn.Module): + def __init__(self,dim_in): + super().__init__() + dim_bottleneck = cfg.RESNETS.NUM_GROUPS * cfg.RESNETS.WIDTH_PER_GROUP + stride_init = cfg.LIGHT_HEAD_RCNN.ROI_XFORM_RESOLUTION // 7 + self.res5, self.dim_out = add_stage(dim_in, dim_bottleneck * 8, 3, stride_init) + assert self.dim_out == 2048 + self.ps_chl = 7 * 7 * 10 + self.xconv = large_separable_conv(chl_in=self.dim_out, ks=15, chl_mid=256, chl_out=self.ps_chl) + self._init_modules() + + def _init_modules(self): + # Freeze all bn (affine) layers !!! + self.apply(lambda m: freeze_params(m) if isinstance(m, mynn.AffineChannel2d) else None) + + def detectron_weight_mapping(self): + mapping_to_detectron, orphan_in_detectron = \ + residual_stage_detectron_mapping(self.res5, 'res5', 3, 5) + return mapping_to_detectron, orphan_in_detectron + + def forward(self,x): + res5_feat = self.res5(x) + return self.xconv(res5_feat) + + + + +class light_head_rcnn_outputs(nn.Module): + def __init__(self, dim_in,roi_xform_func,spatial_scale): + super().__init__() + self.ps_fc_1 = nn.Linear(dim_in, 2048) + self.cls_score = nn.Linear(2048, cfg.MODEL.NUM_CLASSES) + if cfg.MODEL.CLS_AGNOSTIC_BBOX_REG: + self.bbox_pred = nn.Linear(2048, 4) + else: + self.bbox_pred = nn.Linear(2048, 4 * cfg.MODEL.NUM_CLASSES) + self.roi_xform = roi_xform_func + self.spatial_scale = spatial_scale + self._init_weights() + + def _init_weights(self): + init.normal(self.ps_fc_1.weight, std=0.01) + init.constant(self.ps_fc_1.bias, 0) + init.normal(self.cls_score.weight, std=0.01) + init.constant(self.cls_score.bias, 0) + init.normal(self.bbox_pred.weight, std=0.001) + init.constant(self.bbox_pred.bias, 0) + + def detectron_weight_mapping(self): + detectron_weight_mapping = { + 'ps_fc_1.weight': 'ps_fc_1_w', + 'ps_fc_1.bias': 'ps_fc_1_b', + 'cls_score.weight': 'cls_score_w', + 'cls_score.bias': 'cls_score_b', + 'bbox_pred.weight': 'bbox_pred_w', + 'bbox_pred.bias': 'bbox_pred_b' + } + orphan_in_detectron = [] + return detectron_weight_mapping, orphan_in_detectron + + def forward(self, x, rpn_ret): + x = self.roi_xform( + x, rpn_ret, + blob_rois='rois', + method=cfg.LIGHT_HEAD_RCNN.ROI_XFORM_METHOD, + resolution=cfg.LIGHT_HEAD_RCNN.ROI_XFORM_RESOLUTION, + spatial_scale=self.spatial_scale + ) + x = F.relu(self.ps_fc_1(x.view(x.size(0),-1)),inplace=True) + cls_score = self.cls_score(x) + if not self.training: + cls_score = F.softmax(cls_score, dim=1) + bbox_pred = self.bbox_pred(x) + + return cls_score, bbox_pred + + +def light_head_rcnn_losses(cls_score, bbox_pred, label_int32, bbox_targets, + bbox_inside_weights, bbox_outside_weights): + device_id = cls_score.get_device() + rois_label = Variable(torch.from_numpy(label_int32.astype('int64'))).cuda(device_id) + # logits_rois_label = Variable(torch.cuda.IntTensor(cls_score.size()[0],cls_score.size()[1]).zero_()) + # for i, val in enumerate(rois_label.data): + # logits_rois_label.data[i, val] = 1 if val > 0 else 0 + loss_cls =F.cross_entropy(cls_score, rois_label)# F.binary_cross_entropy_with_logits(cls_score,logits_rois_label) + bbox_targets = Variable(torch.from_numpy(bbox_targets)).cuda(device_id) + bbox_inside_weights = 
Variable(torch.from_numpy(bbox_inside_weights)).cuda(device_id) + bbox_outside_weights = Variable(torch.from_numpy(bbox_outside_weights)).cuda(device_id) + loss_bbox = net_utils.smooth_l1_loss( + bbox_pred, bbox_targets, bbox_inside_weights, bbox_outside_weights) + return loss_cls, loss_bbox*5 + diff --git a/lib/modeling/model_builder.py b/lib/modeling/model_builder.py index ab20387b..ff8d29a7 100644 --- a/lib/modeling/model_builder.py +++ b/lib/modeling/model_builder.py @@ -9,10 +9,13 @@ from core.config import cfg from model.roi_pooling.functions.roi_pool import RoIPoolFunction from model.roi_crop.functions.roi_crop import RoICropFunction +from model.psroi_pooling.functions.psroi_pooling import PSRoIPoolingFunction #hw +from model.psroi_align_pooling.functions.psroi_align_pooling import PSRoiAlignPoolingFunction #hw from modeling.roi_xfrom.roi_align.functions.roi_align import RoIAlignFunction import modeling.rpn_heads as rpn_heads import modeling.fast_rcnn_heads as fast_rcnn_heads import modeling.mask_rcnn_heads as mask_rcnn_heads +import modeling.light_head_rcnn_heads as light_head_rcnn_heads import modeling.keypoint_rcnn_heads as keypoint_rcnn_heads import utils.blob as blob_utils import utils.net as net_utils @@ -72,12 +75,19 @@ def __init__(self): # may include extra scales that are used for RPN proposals, but not for RoI heads. self.Conv_Body.spatial_scale = self.Conv_Body.spatial_scale[-self.num_roi_levels:] - # BBOX Branch + # BBOX Branch #hw RPN output to ROIPool layer if not cfg.MODEL.RPN_ONLY: - self.Box_Head = get_func(cfg.FAST_RCNN.ROI_BOX_HEAD)( - self.RPN.dim_out, self.roi_feature_transform, self.Conv_Body.spatial_scale) - self.Box_Outs = fast_rcnn_heads.fast_rcnn_outputs( + if cfg.MODEL.FASTER_RCNN: + self.Box_Head = get_func(cfg.FAST_RCNN.ROI_BOX_HEAD)( + # hw ResNet_roi_conv5_head RPN网络输出roi和prob到ROIAlgin层 + self.RPN.dim_out, self.roi_feature_transform, self.Conv_Body.spatial_scale) + self.Box_Outs = fast_rcnn_heads.fast_rcnn_outputs(#hw 最后输出avgpool并输入到进入全连接层输出box和prob self.Box_Head.dim_out) + elif cfg.MODEL.LIGHT_HEAD_RCNN:#hw + from modeling.light_head_rcnn_heads import ResNet_Conv5_light_head + self.LightHead = ResNet_Conv5_light_head(self.Conv_Body.dim_out) + self.Box_Outs = light_head_rcnn_heads.light_head_rcnn_outputs(490,self.roi_feature_transform,self.Conv_Body.spatial_scale) + # Mask Branch if cfg.MODEL.MASK_ON: @@ -99,7 +109,8 @@ def __init__(self): def _init_modules(self): if cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS: - resnet_utils.load_pretrained_imagenet_weights(self) + if cfg.MODEL.FASTER_RCNN:#hw + resnet_utils.load_pretrained_imagenet_weights(self) # Check if shared weights are equaled if cfg.MODEL.MASK_ON and getattr(self.Mask_Head, 'SHARE_RES5', False): assert self.Mask_Head.res5.state_dict() == self.Box_Head.res5.state_dict() @@ -130,9 +141,7 @@ def _forward(self, data, im_info, roidb=None, **rpn_kwargs): rpn_ret = self.RPN(blob_conv, im_info, roidb) - # if self.training: - # # can be used to infer fg/bg ratio - # return_dict['rois_label'] = rpn_ret['labels_int32'] + if cfg.FPN.FPN_ON: # Retain only the blobs that will be used for RoI heads. 
`blob_conv` may include @@ -145,9 +154,17 @@ def _forward(self, data, im_info, roidb=None, **rpn_kwargs): if not cfg.MODEL.RPN_ONLY: if cfg.MODEL.SHARE_RES5 and self.training: box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret) - else: + elif cfg.MODEL.FASTER_RCNN: #hw FAST_RCNN or LIGHT_HEAD_RCNN网络预测类别和坐标 box_feat = self.Box_Head(blob_conv, rpn_ret) - cls_score, bbox_pred = self.Box_Outs(box_feat) + cls_score, bbox_pred = self.Box_Outs(box_feat) + return_dict['cls_score'] = cls_score + return_dict['bbox_pred'] = bbox_pred + elif cfg.MODEL.LIGHT_HEAD_RCNN: + box_feat = F.relu(self.LightHead(blob_conv), inplace=True) + cls_score, bbox_pred = self.Box_Outs(box_feat,rpn_ret) + return_dict['cls_score'] = cls_score + return_dict['bbox_pred'] = bbox_pred + else: # TODO: complete the returns for RPN only situation pass @@ -159,7 +176,7 @@ def _forward(self, data, im_info, roidb=None, **rpn_kwargs): rpn_kwargs.update(dict( (k, rpn_ret[k]) for k in rpn_ret.keys() if (k.startswith('rpn_cls_logits') or k.startswith('rpn_bbox_pred')) - )) + )) #hw RPN网络类别和坐标的损失值计算 loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(**rpn_kwargs) if cfg.FPN.FPN_ON: for i, lvl in enumerate(range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1)): @@ -169,13 +186,23 @@ def _forward(self, data, im_info, roidb=None, **rpn_kwargs): return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox - # bbox loss - loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses( + # bbox loss #hw RCNN网络类别和坐标的损失值计算 + if cfg.MODEL.FASTER_RCNN: + # bbox loss + loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses( + cls_score, bbox_pred, rpn_ret['labels_int32'], rpn_ret['bbox_targets'], + rpn_ret['bbox_inside_weights'], rpn_ret['bbox_outside_weights']) + return_dict['losses']['loss_cls'] = loss_cls + return_dict['losses']['loss_bbox'] = loss_bbox + return_dict['metrics']['accuracy_cls'] = accuracy_cls + elif cfg.MODEL.LIGHT_HEAD_RCNN: # hw + loss_cls, loss_bbox = light_head_rcnn_heads.light_head_rcnn_losses( cls_score, bbox_pred, rpn_ret['labels_int32'], rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'], rpn_ret['bbox_outside_weights']) - return_dict['losses']['loss_cls'] = loss_cls - return_dict['losses']['loss_bbox'] = loss_bbox - return_dict['metrics']['accuracy_cls'] = accuracy_cls + return_dict['losses']['loss_cls'] = loss_cls + return_dict['losses']['loss_bbox'] = loss_bbox + + if cfg.MODEL.MASK_ON: if getattr(self.Mask_Head, 'SHARE_RES5', False): @@ -232,7 +259,7 @@ def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoI - Use of FPN or not - Specifics of the transform method """ - assert method in {'RoIPoolF', 'RoICrop', 'RoIAlign'}, \ + assert method in {'RoIPoolF', 'RoICrop', 'RoIAlign','PSRoIPool','PSRoIAlignPool'}, \ 'Unknown pooling method: {}'.format(method) if isinstance(blobs_in, list): @@ -263,8 +290,16 @@ def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoI elif method == 'RoIAlign': xform_out = RoIAlignFunction( resolution, resolution, sc, sampling_ratio)(bl_in, rois) + elif method == 'PSRoIPool':#hw add Light-Head-RCNN + xform_out = PSRoIPoolingFunction( + resolution, resolution, sc,resolution)(bl_in, rois) + elif method == 'PSRoIAlignPool': # hw add Light-Head-RCNN + xform_out = PSRoiAlignPoolingFunction( + resolution, resolution,2,2, sc, resolution)(bl_in, rois) bl_out_list.append(xform_out) + + # The pooled features from all levels are concatenated along the # batch dimension 
diff --git a/lib/modeling/rpn_heads.py b/lib/modeling/rpn_heads.py
index 5edcf3a7..7e78f8a1 100644
--- a/lib/modeling/rpn_heads.py
+++ b/lib/modeling/rpn_heads.py
@@ -58,6 +58,7 @@ def __init__(self, dim_in, spatial_scale):
         self.RPN_GenerateProposals = GenerateProposalsOp(anchors, spatial_scale)
         self.RPN_GenerateProposalLabels = GenerateProposalLabelsOp()
+
         self._init_weights()

     def _init_weights(self):
@@ -95,7 +96,7 @@ def forward(self, x, im_info, roidb=None):
         return_dict = {
             'rpn_cls_logits': rpn_cls_logits, 'rpn_bbox_pred': rpn_bbox_pred}

-        if not self.training or cfg.MODEL.FASTER_RCNN:
+        if not self.training or cfg.MODEL.FASTER_RCNN or cfg.MODEL.LIGHT_HEAD_RCNN:
             # Proposals are needed during:
             #  1) inference (== not model.train) for RPN only and Faster R-CNN
             #  OR
@@ -108,15 +109,15 @@ def forward(self, x, im_info, roidb=None):
                 rpn_cls_prob = rpn_cls_prob[:, 1].squeeze(dim=1)
             else:
                 rpn_cls_prob = F.sigmoid(rpn_cls_logits)
-
+            #hw obtain the region proposals
             rpn_rois, rpn_rois_prob = self.RPN_GenerateProposals(
                 rpn_cls_prob, rpn_bbox_pred, im_info)

             return_dict['rpn_rois'] = rpn_rois
             return_dict['rpn_roi_probs'] = rpn_rois_prob

-        if cfg.MODEL.FASTER_RCNN :
-            if self.training:
+        if cfg.MODEL.FASTER_RCNN or cfg.MODEL.LIGHT_HEAD_RCNN:
+            if self.training:  #hw build the dict of ground-truth labels and predictions for the proposals
                 # Add op that generates training labels for in-network RPN proposals
                 blobs_out = self.RPN_GenerateProposalLabels(rpn_rois, roidb, im_info)
                 return_dict.update(blobs_out)
@@ -124,6 +125,10 @@ def forward(self, x, im_info, roidb=None):
                 # Alias rois to rpn_rois for inference
                 return_dict['rois'] = return_dict['rpn_rois']

+
+
+
+
         return return_dict

diff --git a/lib/roi_data/fast_rcnn.py b/lib/roi_data/fast_rcnn.py
index afc09b40..21ac7cff 100644
--- a/lib/roi_data/fast_rcnn.py
+++ b/lib/roi_data/fast_rcnn.py
@@ -212,6 +212,10 @@ def _compute_targets(ex_rois, gt_rois, labels):
     targets = box_utils.bbox_transform_inv(ex_rois, gt_rois,
                                            cfg.MODEL.BBOX_REG_WEIGHTS)
+    # Use class "1" for all fg boxes if using class_agnostic_bbox_reg
+
+    if cfg.MODEL.CLS_AGNOSTIC_BBOX_REG:
+        labels.clip(max=1, out=labels)
     return np.hstack((labels[:, np.newaxis], targets)).astype(
         np.float32, copy=False)
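A quick illustration of the class-agnostic clipping added to _compute_targets above: with CLS_AGNOSTIC_BBOX_REG enabled, every foreground class id collapses to 1, so all foreground boxes share one set of regression targets. The label values below are made up.

    import numpy as np

    labels = np.array([0, 3, 7, 0, 12], dtype=np.int32)  # hypothetical per-roi class labels
    labels.clip(max=1, out=labels)                        # same call as in the hunk above
    print(labels)  # [0 1 1 0 1] -> background stays 0, every foreground class becomes 1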
diff --git a/lib/utils/net.py b/lib/utils/net.py
index 9497ec28..1d1b3dca 100644
--- a/lib/utils/net.py
+++ b/lib/utils/net.py
@@ -19,12 +19,13 @@ def smooth_l1_loss(bbox_pred, bbox_targets, bbox_inside_weights, bbox_outside_we
     1 / N * sum_i alpha_out[i] * SmoothL1(alpha_in[i] * (y_hat[i] - y[i])).
     N is the number of batch elements in the input predictions
     """
+    beta_2 = beta**2
     box_diff = bbox_pred - bbox_targets
     in_box_diff = bbox_inside_weights * box_diff
     abs_in_box_diff = torch.abs(in_box_diff)
-    smoothL1_sign = (abs_in_box_diff < beta).detach().float()
-    in_loss_box = smoothL1_sign * 0.5 * torch.pow(in_box_diff, 2) / beta + \
-                  (1 - smoothL1_sign) * (abs_in_box_diff - (0.5 * beta))
+    smoothL1_sign = (abs_in_box_diff < beta_2).detach().float()
+    in_loss_box = smoothL1_sign * 0.5 * torch.pow(in_box_diff, 2) / beta_2 + \
+                  (1 - smoothL1_sign) * (abs_in_box_diff - (0.5 * beta_2))
     out_loss_box = bbox_outside_weights * in_loss_box
     loss_box = out_loss_box
     N = loss_box.size(0)  # batch size
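For reference, a standalone restatement of the piecewise term the patched smooth_l1_loss now evaluates, with the switch point written as beta**2 exactly as in the hunk above (for the default beta = 1 this coincides with the usual Smooth L1). Illustrative only; it is not part of the patch.

    import torch

    def smooth_l1_elementwise(x, beta=1.0):
        beta_2 = beta ** 2
        quad = (x.abs() < beta_2).float()
        return quad * 0.5 * x.pow(2) / beta_2 + (1 - quad) * (x.abs() - 0.5 * beta_2)

    x = torch.tensor([0.3, 2.0])
    print(smooth_l1_elementwise(x))  # ~[0.0450, 1.5000]: quadratic near zero, linear in the tail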
diff --git a/tools/download_imagenet_weights.py b/tools/download_imagenet_weights.py
index f01d0a62..449e54e1 100644
--- a/tools/download_imagenet_weights.py
+++ b/tools/download_imagenet_weights.py
@@ -36,6 +36,7 @@ def parse_args():
         'resnet101_caffe.pth': '1x2fTMqLrn63EMW0VuK4GEa2eQKzvJ_7l',
         'resnet152_caffe.pth': '1NSCycOb7pU0KzluH326zmyMFUU55JslF',
         'vgg16_caffe.pth': '19UphT53C0Ua9JAtICnw84PPTa3sZZ_9k',
+        'vgg_SCNN_DULR_w9.t7': '1Wv3r3dCYNBwJdKl_WPEfrEOt-XGaROKu',
     }
diff --git a/tools/infer_simple.py b/tools/infer_simple.py
index dad16529..b2d0aea9 100644
--- a/tools/infer_simple.py
+++ b/tools/infer_simple.py
@@ -127,9 +127,11 @@ def main():
         print("loading detectron weights %s" % args.load_detectron)
         load_detectron_weight(maskRCNN, args.load_detectron)

+
     maskRCNN = mynn.DataParallel(maskRCNN, cpu_keywords=['im_info', 'roidb'],
                                  minibatch=True, device_ids=[0])  # only support single GPU
+
     maskRCNN.eval()
     if args.image_dir:
         imglist = misc_utils.get_imagelist_from_dir(args.image_dir)
@@ -160,7 +162,7 @@ def main():
             box_alpha=0.3,
             show_class=True,
             thresh=0.7,
-            kp_thresh=2
+            kp_thresh=2,
         )

     if args.merge_pdfs and num_images > 1:
diff --git a/tools/train_net.py b/tools/train_net.py
index 499ba39f..7e1df13b 100644
--- a/tools/train_net.py
+++ b/tools/train_net.py
@@ -38,6 +38,65 @@
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

+
+def plot_graph(top_var, fname, params=None):
+    """
+    This method doesn't support release v0.1.12 because of a bug fixed in:
+    https://github.com/pytorch/pytorch/pull/1016
+    So if you want to use `plot_graph`, you have to build from the master branch
+    or wait for the next release.
+
+    Plot the graph. Make sure that requires_grad=True and volatile=False.
+    :param top_var: network output Variable
+    :param fname: file name
+    :param params: dict of (name, Variable) used to label parameter nodes
+    :return: png filename
+    """
+    from graphviz import Digraph
+    import pydot
+    dot = Digraph(comment='LRP',
+                  node_attr={'style': 'filled', 'shape': 'box'})
+                  # , 'fillcolor': 'lightblue'})
+
+    seen = set()
+
+    if params is not None:
+        assert isinstance(params.values()[0], Variable)
+        param_map = {id(v): k for k, v in params.items()}
+
+    def size_to_str(size):
+        return '(' + (', ').join(['%d' % v for v in size]) + ')'
+
+    def add_nodes(var):
+        if var not in seen:
+            if torch.is_tensor(var):
+                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
+            elif hasattr(var, 'variable'):
+                u = var.variable
+                name = '{}\n '.format(param_map[id(u)]) if params is not None else ''
+                node_name = '{}{}'.format(name, size_to_str(u.size()))
+                dot.node(str(id(var)), node_name, fillcolor='lightblue')
+            else:
+                dot.node(str(id(var)), str(type(var).__name__))
+            seen.add(var)
+            if hasattr(var, 'next_functions'):
+                for u in var.next_functions:
+                    if u[0] is not None:
+                        dot.edge(str(id(u[0])), str(id(var)))
+                        add_nodes(u[0])
+            if hasattr(var, 'saved_tensors'):
+                for t in var.saved_tensors:
+                    dot.edge(str(id(t)), str(id(var)))
+                    add_nodes(t)
+
+    for o in top_var:
+        add_nodes(o.grad_fn)
+    dot.save(fname)
+    (graph,) = pydot.graph_from_dot_file(fname)
+    im_name = '{}.png'.format(fname)
+    graph.write_png(im_name)
+    print(im_name)
+
+    return im_name
+
+
 # RuntimeError: received 0 items of ancdata. Issue: pytorch/pytorch#973
 rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
 resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
@@ -159,7 +218,7 @@ def main():
         cfg.MODEL.NUM_CLASSES = 2
     else:
         raise ValueError("Unexpected args.dataset: {}".format(args.dataset))
-
+    #hw merge settings from the yaml file into cfg
     cfg_from_file(args.cfg_file)
     if args.set_cfgs is not None:
         cfg_from_list(args.set_cfgs)
@@ -222,6 +281,7 @@ def main():
     ### Model ###
     maskRCNN = Generalized_RCNN()
+
     if cfg.CUDA:
         maskRCNN.cuda()
@@ -297,6 +357,7 @@ def main():
     if args.use_tfboard:
         from tensorboardX import SummaryWriter
+        from tensorboardX import FileWriter
         # Set the Tensorboard logger
         tblogger = SummaryWriter(output_dir)
@@ -330,11 +391,13 @@ def main():
                 if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                     input_data[key] = list(map(Variable, input_data[key]))

+
                 training_stats.IterTic()
                 net_outputs = maskRCNN(**input_data)
                 training_stats.UpdateIterStats(net_outputs)
                 loss = net_outputs['total_loss']
                 optimizer.zero_grad()
+
                 loss.backward()
                 optimizer.step()
                 training_stats.IterToc()
@@ -345,6 +408,7 @@ def main():
                 if args.step % args.disp_interval == 0:
                     log_training_stats(training_stats, global_step, lr)

+
                 global_step += 1

             # ---- End of epoch ----
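A possible way to call the plot_graph helper added above from inside main(), shown only as a sketch: the output filename and the choice of plotting total_loss are made up, graphviz and pydot must be installed, and (as the docstring warns) the PyTorch build must include pytorch/pytorch#1016.

    # Hypothetical usage inside the training loop, after a forward pass:
    net_outputs = maskRCNN(**input_data)
    plot_graph([net_outputs['total_loss']], os.path.join(output_dir, 'net_graph.dot'))
    # writes net_graph.dot and net_graph.dot.png alongside the other training outputs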