Skip to content

Commit 7042e10

Browse files
vfdev-5pytorchmergebot
authored andcommitted
Fixed issue with bicubic interpolation on uint8 input and antialising (pytorch#102296)
Description: - Fixed issue with bicubic interpolation on uint8 input and antialising, discovered by @NicolasHug - Unified `_separable_upsample_generic_Nd_kernel_impl_single_dim` on `antialis` arg. Pull Request resolved: pytorch#102296 Approved by: https://github.com/NicolasHug
1 parent 0f1621d commit 7042e10

File tree

2 files changed

+53
-11
lines changed

2 files changed

+53
-11
lines changed

aten/src/ATen/native/cpu/UpSampleKernel.cpp

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,7 +1159,8 @@ struct HelperInterpLinear : public HelperInterpBase {
11591159
int64_t ndims,
11601160
int64_t reshape_dim,
11611161
bool align_corners,
1162-
const c10::optional<double> opt_scale
1162+
const c10::optional<double> opt_scale,
1163+
bool antialias
11631164
) {
11641165

11651166
std::vector<Tensor> indices_weights;
@@ -1172,6 +1173,7 @@ struct HelperInterpLinear : public HelperInterpBase {
11721173
auto interp_size = HelperInterpLinear::interp_size;
11731174
int unused;
11741175
scalar_t unused_2;
1176+
auto align_corners_delta = (align_corners && !antialias) ? 0.5 : 0.0;
11751177

11761178
std::tie(indices_weights, unused, unused_2) = HelperInterpLinear::_compute_indices_weights_aa<scalar_t>(
11771179
input_size,
@@ -1182,8 +1184,8 @@ struct HelperInterpLinear : public HelperInterpBase {
11821184
scale,
11831185
interp_size,
11841186
&HelperInterpLinear::aa_filter<scalar_t>,
1185-
/*antialias=*/true,
1186-
/*align_corners_delta=*/0.0);
1187+
/*antialias=*/antialias,
1188+
/*align_corners_delta=*/align_corners_delta);
11871189
}
11881190
);
11891191
return indices_weights;
@@ -1293,7 +1295,8 @@ struct HelperInterpCubic : public HelperInterpBase {
12931295
int64_t ndims,
12941296
int64_t reshape_dim,
12951297
bool align_corners,
1296-
const c10::optional<double> opt_scale
1298+
const c10::optional<double> opt_scale,
1299+
bool antialias
12971300
) {
12981301

12991302
std::vector<Tensor> indices_weights;
@@ -1306,6 +1309,7 @@ struct HelperInterpCubic : public HelperInterpBase {
13061309
auto interp_size = HelperInterpCubic::interp_size;
13071310
int unused;
13081311
scalar_t unused_2;
1312+
auto align_corners_delta = (align_corners && !antialias) ? 0.5 : 0.0;
13091313

13101314
std::tie(indices_weights, unused, unused_2) = HelperInterpCubic::_compute_indices_weights_aa<scalar_t>(
13111315
input_size,
@@ -1316,8 +1320,8 @@ struct HelperInterpCubic : public HelperInterpBase {
13161320
scale,
13171321
interp_size,
13181322
&HelperInterpCubic::aa_filter<scalar_t>,
1319-
/*antialias=*/true,
1320-
/*align_corners_delta*/0.0);
1323+
/*antialias=*/antialias,
1324+
/*align_corners_delta*/align_corners_delta);
13211325
}
13221326
);
13231327
return indices_weights;
@@ -1475,22 +1479,22 @@ void _separable_upsample_generic_Nd_kernel_impl_single_dim(
14751479
unsigned int weights_precision = 0;
14761480
int unused;
14771481

1478-
if (input_scalar_type == at::kByte) {
1482+
if (F::interp_size == 2 && input_scalar_type == at::kByte) {
1483+
// This is special branch to provide uint8 dtype support for bilinear mode only
14791484
std::tie(indices_weights, unused, weights_precision) =
1480-
// TODO: change that to F:: once / if bicubic mode supports uint8 after all
14811485
HelperInterpLinear::compute_indices_int16_weights_aa(
14821486
input.size(interp_dim), oshape[interp_dim],
14831487
input.stride(interp_dim) * input.element_size(),
14841488
input.dim(), interp_dim, align_corners, scales[interp_dim - 2],
14851489
antialias);
14861490
TORCH_INTERNAL_ASSERT(weights_precision > 0);
14871491
} else {
1488-
TORCH_INTERNAL_ASSERT(antialias);
14891492
indices_weights =
14901493
F::compute_indices_weights_aa(
14911494
input_scalar_type, input.size(interp_dim), oshape[interp_dim],
14921495
input.stride(interp_dim) * input.element_size(),
1493-
input.dim(), interp_dim, align_corners, scales[interp_dim - 2]);
1496+
input.dim(), interp_dim, align_corners, scales[interp_dim - 2],
1497+
antialias);
14941498
}
14951499

14961500
TensorIteratorConfig config;
@@ -1801,6 +1805,11 @@ void upsample_bicubic2d_kernel_impl(
18011805
bool align_corners,
18021806
c10::optional<double> scales_h,
18031807
c10::optional<double> scales_w) {
1808+
1809+
// We explicitly checking for non-supported uint8 dtype
1810+
TORCH_CHECK(input.scalar_type() != at::kByte,
1811+
"'upsample_bicubic2d_aa_kernel_impl' not implemented for 'Byte'");
1812+
18041813
upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpCubic>(
18051814
output, input, align_corners, {scales_h, scales_w});
18061815
}
@@ -1812,6 +1821,10 @@ void upsample_bicubic2d_aa_kernel_impl(
18121821
c10::optional<double> scales_h,
18131822
c10::optional<double> scales_w) {
18141823

1824+
// We explicitly checking for non-supported uint8 dtype
1825+
TORCH_CHECK(input.scalar_type() != at::kByte,
1826+
"'upsample_bicubic2d_aa_kernel_impl' not implemented for 'Byte'");
1827+
18151828
separable_upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpCubic>(
18161829
output, input, align_corners, {scales_h, scales_w},
18171830
/*antialias=*/true);

test/test_nn.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from torch.nn.utils.fusion import fuse_linear_bn_weights
3434
from torch.nn import Parameter
3535
from torch.nn.parallel._functions import Broadcast
36-
from torch.testing._internal.common_dtype import integral_types, get_all_math_dtypes
36+
from torch.testing._internal.common_dtype import integral_types, get_all_math_dtypes, floating_types
3737
from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
3838
TEST_NUMPY, TEST_SCIPY, TEST_WITH_CROSSREF, TEST_WITH_ROCM, \
3939
download_file, get_function_arglist, load_tests, skipIfMps,\
@@ -9683,6 +9683,35 @@ def test_upsamplingBiMode2d(self, device, antialias, align_corners, mode, memory
96839683

96849684
self.assertEqual(a_cuda.grad, a_cpu.grad)
96859685

9686+
@parametrize_test("antialias", [True, False])
9687+
@parametrize_test("num_channels", [3, 5])
9688+
@parametrize_test("mode", ["nearest", "nearest-exact", "bilinear", "bicubic"])
9689+
@parametrize_test("dtype", integral_types() + floating_types())
9690+
@onlyNativeDeviceTypes
9691+
def test_upsamplingBiMode2d_nonsupported_dtypes(self, device, antialias, num_channels, mode, dtype):
9692+
x = torch.ones(1, num_channels, 32, 32, dtype=dtype, device=device)
9693+
9694+
should_raise_runtime_error = True
9695+
9696+
if "nearest" in mode:
9697+
if antialias:
9698+
raise SkipTest("Nearest mode does not have antialiasing")
9699+
if dtype in (torch.uint8, ) + floating_types():
9700+
should_raise_runtime_error = False
9701+
9702+
elif mode == "bilinear":
9703+
if dtype in floating_types() or (device == "cpu" and dtype == torch.uint8):
9704+
should_raise_runtime_error = False
9705+
elif mode == "bicubic":
9706+
if dtype in floating_types():
9707+
should_raise_runtime_error = False
9708+
9709+
if should_raise_runtime_error:
9710+
with self.assertRaisesRegex(RuntimeError, "not implemented for"):
9711+
F.interpolate(x, (12, 12), mode=mode, antialias=antialias)
9712+
else:
9713+
_ = F.interpolate(x, (12, 12), mode=mode, antialias=antialias)
9714+
96869715
@parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
96879716
def test_upsamplingBilinear2d_aa_correctness(self, device, memory_format):
96889717
t_in = torch.arange(3 * 8 * 8, dtype=torch.float, device=device).reshape(1, 3, 8, 8)

0 commit comments

Comments
 (0)