diff --git a/src/frontends/pytorch/src/op/rand.cpp b/src/frontends/pytorch/src/op/rand.cpp
index cef77ee5811093..ef7bd8d80e0d79 100644
--- a/src/frontends/pytorch/src/op/rand.cpp
+++ b/src/frontends/pytorch/src/op/rand.cpp
@@ -6,15 +6,30 @@
 #include "openvino/frontend/common/random_normal_helper.hpp"
 #include "openvino/frontend/pytorch/node_context.hpp"
 #include "openvino/op/add.hpp"
+#include "openvino/op/broadcast.hpp"
+#include "openvino/op/concat.hpp"
 #include "openvino/op/constant.hpp"
 #include "openvino/op/convert.hpp"
 #include "openvino/op/convert_like.hpp"
 #include "openvino/op/cos.hpp"
+#include "openvino/op/cum_sum.hpp"
+#include "openvino/op/divide.hpp"
+#include "openvino/op/equal.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/op/greater.hpp"
+#include "openvino/op/less.hpp"
 #include "openvino/op/log.hpp"
+#include "openvino/op/logical_and.hpp"
+#include "openvino/op/maximum.hpp"
 #include "openvino/op/multiply.hpp"
+#include "openvino/op/power.hpp"
 #include "openvino/op/random_uniform.hpp"
+#include "openvino/op/reduce_logical_or.hpp"
+#include "openvino/op/reduce_sum.hpp"
+#include "openvino/op/select.hpp"
 #include "openvino/op/shape_of.hpp"
 #include "openvino/op/sqrt.hpp"
+#include "openvino/op/subtract.hpp"
 #include "pt_framework_node.hpp"
 #include "transformations/rt_info/disable_fp16_compression.hpp"
 #include "utils.hpp"
@@ -27,6 +42,8 @@ namespace op {
 using namespace ov::op;
 
 namespace {
+constexpr int64_t standard_gamma_trials = 16;
+constexpr float min_uniform_value = 1e-7f;
 OutputVector make_random_normal(const NodeContext& context,
                                 const Output<Node>& sizes,
                                 element::Type target_type,
@@ -267,6 +284,120 @@ OutputVector translate_randint(const NodeContext& context) {
     return {res};
 };
 
+OutputVector translate_standard_gamma(const NodeContext& context) {
+    // aten::_standard_gamma(Tensor self, *, Generator? generator=None) -> Tensor
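+    // Approach: OpenVINO has no native Gamma sampler, so the converter decomposes
+    // aten::_standard_gamma into the Marsaglia-Tsang rejection sampler. It draws a
+    // fixed batch of standard_gamma_trials candidates per element, keeps the first
+    // accepted one, and falls back to the last candidate if every trial is rejected.
+    // The per-trial acceptance rate of the sampler is well above 0.9 for
+    // concentration >= 1, so with 16 trials an all-rejected element is
+    // statistically negligible.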
+    num_inputs_check(context, 1, 2);
+    if (context.get_input_size() == 2) {
+        PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(1),
+                                    "aten::_standard_gamma conversion with generator is not supported");
+    }
+
+    auto input = context.get_input(0);
+    auto output_type = input.get_element_type();
+    auto concentration = input;
+    if (output_type != element::f32) {
+        concentration = context.mark_node(std::make_shared<v0::Convert>(input, element::f32));
+    }
+
+    auto shape_i32 = context.mark_node(std::make_shared<v3::ShapeOf>(concentration, element::i32));
+    auto shape = context.mark_node(std::make_shared<v0::Convert>(shape_i32, element::i64));
+    auto trials = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {standard_gamma_trials}));
+    auto expanded_shape = context.mark_node(std::make_shared<v0::Concat>(OutputVector{trials, shape}, 0));
+    auto axis_zero_i64 = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
+    auto axis_zero_i32 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
+
+    auto zero = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0.f}));
+    auto one = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1.f}));
+    auto half = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0.5f}));
+    auto third = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1.f / 3.f}));
+    auto nine = context.mark_node(v0::Constant::create(element::f32, Shape{}, {9.f}));
+    auto min_uniform = context.mark_node(v0::Constant::create(element::f32, Shape{}, {min_uniform_value}));
+
+    // Marsaglia-Tsang requires concentration >= 1: for alpha < 1 sample from
+    // Gamma(alpha + 1) and correct the result later with the boost transform.
+    auto lt_one_mask = context.mark_node(std::make_shared<v1::Less>(concentration, one));
+    auto conc_plus_one = context.mark_node(std::make_shared<v1::Add>(concentration, one));
+    auto conc_ge_one = context.mark_node(std::make_shared<v1::Select>(lt_one_mask, conc_plus_one, concentration));
+
+    // d = alpha - 1/3, c = 1 / sqrt(9 * d)
+    auto d = context.mark_node(std::make_shared<v1::Subtract>(conc_ge_one, third));
+    auto nine_d = context.mark_node(std::make_shared<v1::Multiply>(d, nine));
+    auto sqrt_term = context.mark_node(std::make_shared<v0::Sqrt>(nine_d));
+    auto c = context.mark_node(std::make_shared<v1::Divide>(one, sqrt_term));
+
+    auto scale = one;
+    auto mean = zero;
+    auto normals = make_random_normal(context, expanded_shape, element::f32, scale, mean)[0];
+    auto uniform_accept =
+        context.mark_node(std::make_shared<v8::RandomUniform>(expanded_shape, min_uniform, one, element::f32));
+
+    auto zero_bc = context.mark_node(std::make_shared<v3::Broadcast>(zero, expanded_shape));
+    auto one_bc = context.mark_node(std::make_shared<v3::Broadcast>(one, expanded_shape));
+    auto min_uniform_bc = context.mark_node(std::make_shared<v3::Broadcast>(min_uniform, expanded_shape));
+    auto d_bc = context.mark_node(std::make_shared<v3::Broadcast>(d, expanded_shape));
+    auto c_bc = context.mark_node(std::make_shared<v3::Broadcast>(c, expanded_shape));
+
+    // v = (1 + c * x)^3, clamped away from zero so log(v) stays finite
+    auto cx = context.mark_node(std::make_shared<v1::Multiply>(c_bc, normals));
+    auto one_plus_cx = context.mark_node(std::make_shared<v1::Add>(one_bc, cx));
+    auto v = context.mark_node(std::make_shared<v1::Multiply>(one_plus_cx, one_plus_cx));
+    v = context.mark_node(std::make_shared<v1::Multiply>(v, one_plus_cx));
+    auto safe_v = context.mark_node(std::make_shared<v1::Maximum>(v, min_uniform_bc));
+
+    auto log_v = context.mark_node(std::make_shared<v0::Log>(safe_v));
+    auto log_u = context.mark_node(std::make_shared<v0::Log>(uniform_accept));
+    auto x_sq = context.mark_node(std::make_shared<v1::Multiply>(normals, normals));
+    auto x_sq_half = context.mark_node(std::make_shared<v1::Multiply>(x_sq, half));
+
+    auto d_times_v = context.mark_node(std::make_shared<v1::Multiply>(d_bc, v));
+    auto d_minus_dv = context.mark_node(std::make_shared<v1::Subtract>(d_bc, d_times_v));
+    auto d_log_v = context.mark_node(std::make_shared<v1::Multiply>(d_bc, log_v));
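+    // Acceptance test (Marsaglia-Tsang without the fast "squeeze" branch):
+    // accept a candidate when v > 0 and log(u) < 0.5 * x^2 + d - d * v + d * log(v).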
+    auto rhs = context.mark_node(std::make_shared<v1::Add>(x_sq_half, d_minus_dv));
+    rhs = context.mark_node(std::make_shared<v1::Add>(rhs, d_log_v));
+
+    auto positive_mask = context.mark_node(std::make_shared<v1::Greater>(v, zero_bc));
+    auto compare_mask = context.mark_node(std::make_shared<v1::Less>(log_u, rhs));
+    auto accept_mask = context.mark_node(std::make_shared<v1::LogicalAnd>(positive_mask, compare_mask));
+
+    // Keep only the first accepted candidate along the trials axis: a running
+    // CumSum of the acceptance mask equals 1 exactly at the first acceptance.
+    auto candidate = context.mark_node(std::make_shared<v1::Multiply>(d_bc, v));
+    auto accept_i32 = context.mark_node(std::make_shared<v0::Convert>(accept_mask, element::i32));
+    auto prefix = context.mark_node(std::make_shared<v0::CumSum>(accept_i32, axis_zero_i32, false, false));
+    auto one_i32 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
+    auto one_i32_bc = context.mark_node(std::make_shared<v3::Broadcast>(one_i32, expanded_shape));
+    auto first_accept = context.mark_node(
+        std::make_shared<v1::LogicalAnd>(accept_mask,
+                                         context.mark_node(std::make_shared<v1::Equal>(prefix, one_i32_bc))));
+
+    auto first_accept_f = context.mark_node(std::make_shared<v0::Convert>(first_accept, element::f32));
+    auto selected = context.mark_node(std::make_shared<v1::Multiply>(candidate, first_accept_f));
+    auto gamma_candidates = context.mark_node(std::make_shared<v1::ReduceSum>(selected, axis_zero_i64, false));
+    auto any_accept = context.mark_node(std::make_shared<v1::ReduceLogicalOr>(accept_mask, axis_zero_i64, false));
+
+    // Fallback when every trial was rejected: reuse the last candidate.
+    auto last_index = context.mark_node(v0::Constant::create(element::i64, Shape{}, {standard_gamma_trials - 1}));
+    auto last_candidate = context.mark_node(std::make_shared<v8::Gather>(candidate, last_index, axis_zero_i64));
+    auto gamma_base = context.mark_node(std::make_shared<v1::Select>(any_accept, gamma_candidates, last_candidate));
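+    // Boost transform for concentration < 1: if G ~ Gamma(alpha + 1) and
+    // U ~ Uniform(0, 1), then G * U^(1 / alpha) ~ Gamma(alpha).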
+    auto frac_uniform = context.mark_node(std::make_shared<v8::RandomUniform>(shape, min_uniform, one, element::f32));
+    auto safe_alpha = context.mark_node(std::make_shared<v1::Maximum>(concentration, min_uniform));
+    auto alpha_for_inv = context.mark_node(std::make_shared<v1::Select>(lt_one_mask, safe_alpha, one));
+    auto inv_alpha = context.mark_node(std::make_shared<v1::Divide>(one, alpha_for_inv));
+    auto pow_term = context.mark_node(std::make_shared<v1::Power>(frac_uniform, inv_alpha));
+    auto adjusted = context.mark_node(std::make_shared<v1::Multiply>(gamma_base, pow_term));
+    auto gamma_fp32 = context.mark_node(std::make_shared<v1::Select>(lt_one_mask, adjusted, gamma_base));
+
+    Output<Node> result = gamma_fp32;
+    if (output_type != element::f32) {
+        result = context.mark_node(std::make_shared<v1::ConvertLike>(result, input));
+    }
+    return {result};
+};
+
 OutputVector translate_normal_(const NodeContext& context) {
     // aten::normal_(Tensor(a!) self, float mean=0., float std=1., *, Generator? generator=None) -> Tensor(a!)
     num_inputs_check(context, 3, 4);
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index bba8dd6fa81039..f854ad4e5a4bcb 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -252,6 +252,7 @@ OP_CONVERTER(translate_split_with_sizes);
 OP_CONVERTER(translate_square);
 OP_CONVERTER(translate_squeeze);
 OP_CONVERTER(translate_std);
+OP_CONVERTER(translate_standard_gamma);
 OP_CONVERTER(translate_std_mean);
 OP_CONVERTER(translate_stft);
 OP_CONVERTER(translate_sub);
@@ -385,6 +386,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
        {"aten::_pad_packed_sequence", op::translate_pad_packed_sequence},
        {"aten::_set_item", op::translate_set_item},
        {"aten::_shape_as_tensor", op::translate_shape_as_tensor},
+       {"aten::_standard_gamma", op::translate_standard_gamma},
        {"aten::_unique2", op::translate_unique2},
        {"aten::_upsample_bicubic2d_aa", op::translate_upsample_bicubic2d_aa},
        {"aten::_upsample_bilinear2d_aa", op::translate_upsample_bilinear2d_aa},
diff --git a/tests/layer_tests/pytorch_tests/test_standard_gamma.py b/tests/layer_tests/pytorch_tests/test_standard_gamma.py
new file mode 100644
index 00000000000000..0c53cd293c9fe5
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_standard_gamma.py
@@ -0,0 +1,107 @@
+# Copyright (C) 2018-2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import torch
+import openvino as ov
+
+
+class TestStandardGammaStatistics:
+    class AtenStandardGamma(torch.nn.Module):
+        def forward(self, alpha):
+            return torch._standard_gamma(alpha)
+
+    def _run_gamma_stat_test(
+        self,
+        alpha_value,
+        shape,
+        mean_rtol,
+        mean_atol,
+        var_rtol,
+        var_atol,
+        ie_device,
+        precision,
+    ):
+        torch.manual_seed(0)
+        np.random.seed(0)
+
+        model = self.AtenStandardGamma()
+        alpha_np = np.full(shape, alpha_value, dtype=np.float32)
+        alpha_tensor = torch.from_numpy(alpha_np)
+
+        ov_model = ov.convert_model(input_model=model, example_input=(alpha_tensor,))
+        config = (
+            {"INFERENCE_PRECISION_HINT": "f32"}
+            if ie_device == "GPU" and precision == "FP32"
+            else {}
+        )
+        compiled_model = ov.Core().compile_model(ov_model, ie_device, config)
+
+        with torch.no_grad():
+            fw_samples = model(alpha_tensor).detach().cpu().numpy().reshape(-1)
+
+        infer_request = compiled_model.create_infer_request()
+        infer_request.infer({compiled_model.input(0): alpha_np})
+        ov_samples = infer_request.get_output_tensor(0).data.reshape(-1)
+
+        assert np.isfinite(fw_samples).all(), "PyTorch gamma samples contain non-finite values"
+        assert np.isfinite(ov_samples).all(), "OpenVINO gamma samples contain non-finite values"
+
+        expected_mean = alpha_value
+        expected_var = alpha_value
+
+        np.testing.assert_allclose(
+            fw_samples.mean(),
+            expected_mean,
+            rtol=mean_rtol,
+            atol=mean_atol,
+        )
+        np.testing.assert_allclose(
+            fw_samples.var(),
+            expected_var,
+            rtol=var_rtol,
+            atol=var_atol,
+        )
+        np.testing.assert_allclose(
+            ov_samples.mean(),
+            expected_mean,
+            rtol=mean_rtol,
+            atol=mean_atol,
+        )
+        np.testing.assert_allclose(
+            ov_samples.var(),
+            expected_var,
+            rtol=var_rtol,
+            atol=var_atol,
+        )
+
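+    # The OpenVINO decomposition uses different RNG streams than PyTorch, so
+    # elementwise comparison of samples is meaningless. Instead the tests check
+    # sample statistics: for Gamma(alpha, 1) both mean and variance equal alpha.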
+    @pytest.mark.precommit
+    @pytest.mark.parametrize(
+        "alpha_value,shape,mean_rtol,mean_atol,var_rtol,var_atol",
+        [
+            (0.25, (10_000,), 2e-2, 2e-2, 2e-1, 2e-2),
+            (1.0, (10_000,), 2e-2, 2e-2, 2e-1, 2e-2),
+        ],
+    )
+    def test_standard_gamma_statistics_precommit(
+        self, alpha_value, shape, mean_rtol, mean_atol, var_rtol, var_atol, ie_device, precision
+    ):
+        self._run_gamma_stat_test(
+            alpha_value, shape, mean_rtol, mean_atol, var_rtol, var_atol, ie_device, precision
+        )
+
+    @pytest.mark.nightly
+    @pytest.mark.parametrize(
+        "alpha_value,shape,mean_rtol,mean_atol,var_rtol,var_atol",
+        [
+            (0.25, (200_000,), 5e-3, 5e-3, 1e-1, 2e-2),
+            (7.5, (50_000,), 1e-2, 1e-2, 1e-1, 2e-2),
+        ],
+    )
+    def test_standard_gamma_statistics_nightly(
+        self, alpha_value, shape, mean_rtol, mean_atol, var_rtol, var_atol, ie_device, precision
+    ):
+        self._run_gamma_stat_test(
+            alpha_value, shape, mean_rtol, mean_atol, var_rtol, var_atol, ie_device, precision
+        )