@@ -1159,7 +1159,8 @@ struct HelperInterpLinear : public HelperInterpBase {
11591159 int64_t ndims,
11601160 int64_t reshape_dim,
11611161 bool align_corners,
1162- const c10::optional<double > opt_scale
1162+ const c10::optional<double > opt_scale,
1163+ bool antialias
11631164 ) {
11641165
11651166 std::vector<Tensor> indices_weights;
@@ -1172,6 +1173,7 @@ struct HelperInterpLinear : public HelperInterpBase {
11721173 auto interp_size = HelperInterpLinear::interp_size;
11731174 int unused;
11741175 scalar_t unused_2;
1176+ auto align_corners_delta = (align_corners && !antialias) ? 0.5 : 0.0 ;
11751177
11761178 std::tie (indices_weights, unused, unused_2) = HelperInterpLinear::_compute_indices_weights_aa<scalar_t >(
11771179 input_size,
@@ -1182,8 +1184,8 @@ struct HelperInterpLinear : public HelperInterpBase {
11821184 scale,
11831185 interp_size,
11841186 &HelperInterpLinear::aa_filter<scalar_t >,
1185- /* antialias=*/ true ,
1186- /* align_corners_delta=*/ 0.0 );
1187+ /* antialias=*/ antialias ,
1188+ /* align_corners_delta=*/ align_corners_delta );
11871189 }
11881190 );
11891191 return indices_weights;
@@ -1293,7 +1295,8 @@ struct HelperInterpCubic : public HelperInterpBase {
12931295 int64_t ndims,
12941296 int64_t reshape_dim,
12951297 bool align_corners,
1296- const c10::optional<double > opt_scale
1298+ const c10::optional<double > opt_scale,
1299+ bool antialias
12971300 ) {
12981301
12991302 std::vector<Tensor> indices_weights;
@@ -1306,6 +1309,7 @@ struct HelperInterpCubic : public HelperInterpBase {
13061309 auto interp_size = HelperInterpCubic::interp_size;
13071310 int unused;
13081311 scalar_t unused_2;
1312+ auto align_corners_delta = (align_corners && !antialias) ? 0.5 : 0.0 ;
13091313
13101314 std::tie (indices_weights, unused, unused_2) = HelperInterpCubic::_compute_indices_weights_aa<scalar_t >(
13111315 input_size,
@@ -1316,8 +1320,8 @@ struct HelperInterpCubic : public HelperInterpBase {
13161320 scale,
13171321 interp_size,
13181322 &HelperInterpCubic::aa_filter<scalar_t >,
1319- /* antialias=*/ true ,
1320- /* align_corners_delta*/ 0.0 );
1323+ /* antialias=*/ antialias ,
1324+ /* align_corners_delta*/ align_corners_delta );
13211325 }
13221326 );
13231327 return indices_weights;
@@ -1475,22 +1479,22 @@ void _separable_upsample_generic_Nd_kernel_impl_single_dim(
14751479 unsigned int weights_precision = 0 ;
14761480 int unused;
14771481
1478- if (input_scalar_type == at::kByte ) {
1482+ if (F::interp_size == 2 && input_scalar_type == at::kByte ) {
1483+ // This is special branch to provide uint8 dtype support for bilinear mode only
14791484 std::tie (indices_weights, unused, weights_precision) =
1480- // TODO: change that to F:: once / if bicubic mode supports uint8 after all
14811485 HelperInterpLinear::compute_indices_int16_weights_aa (
14821486 input.size (interp_dim), oshape[interp_dim],
14831487 input.stride (interp_dim) * input.element_size (),
14841488 input.dim (), interp_dim, align_corners, scales[interp_dim - 2 ],
14851489 antialias);
14861490 TORCH_INTERNAL_ASSERT (weights_precision > 0 );
14871491 } else {
1488- TORCH_INTERNAL_ASSERT (antialias);
14891492 indices_weights =
14901493 F::compute_indices_weights_aa (
14911494 input_scalar_type, input.size (interp_dim), oshape[interp_dim],
14921495 input.stride (interp_dim) * input.element_size (),
1493- input.dim (), interp_dim, align_corners, scales[interp_dim - 2 ]);
1496+ input.dim (), interp_dim, align_corners, scales[interp_dim - 2 ],
1497+ antialias);
14941498 }
14951499
14961500 TensorIteratorConfig config;
@@ -1801,6 +1805,11 @@ void upsample_bicubic2d_kernel_impl(
18011805 bool align_corners,
18021806 c10::optional<double > scales_h,
18031807 c10::optional<double > scales_w) {
1808+
1809+ // We explicitly checking for non-supported uint8 dtype
1810+ TORCH_CHECK (input.scalar_type () != at::kByte ,
1811+ " 'upsample_bicubic2d_aa_kernel_impl' not implemented for 'Byte'" );
1812+
18041813 upsample_generic_Nd_kernel_impl<2 , scale_t , HelperInterpCubic>(
18051814 output, input, align_corners, {scales_h, scales_w});
18061815}
@@ -1812,6 +1821,10 @@ void upsample_bicubic2d_aa_kernel_impl(
18121821 c10::optional<double > scales_h,
18131822 c10::optional<double > scales_w) {
18141823
1824+ // We explicitly checking for non-supported uint8 dtype
1825+ TORCH_CHECK (input.scalar_type () != at::kByte ,
1826+ " 'upsample_bicubic2d_aa_kernel_impl' not implemented for 'Byte'" );
1827+
18151828 separable_upsample_generic_Nd_kernel_impl<2 , scale_t , HelperInterpCubic>(
18161829 output, input, align_corners, {scales_h, scales_w},
18171830 /* antialias=*/ true );
0 commit comments