22 changes: 22 additions & 0 deletions third_party/nvfuser/csrc/ops/arith.cpp
@@ -880,6 +880,28 @@ NVFUSER_DEFINE_BINARY_FLOAT_OP(div, Div)
NVFUSER_DEFINE_BINARY_FLOAT_OP(atan2, Atan2)
#undef NVFUSER_DEFINE_BINARY_FLOAT_OP

// These ops require full-precision floating point types (after float type
// promotion)
#define NVFUSER_DEFINE_BINARY_FLOAT_ONLY_OP(op_name, op_type) \
Val* op_name(Val* v1, Val* v2) { \
return binaryOp( \
BinaryOpType::op_type, v1, v2, TypePromotion::float_only_op_config); \
} \
TensorView* op_name(TensorView* v1, Val* v2) { \
return binaryOp( \
BinaryOpType::op_type, v1, v2, TypePromotion::float_only_op_config); \
} \
TensorView* op_name(Val* v1, TensorView* v2) { \
return binaryOp( \
BinaryOpType::op_type, v1, v2, TypePromotion::float_only_op_config); \
} \
TensorView* op_name(TensorView* v1, TensorView* v2) { \
return binaryOp( \
BinaryOpType::op_type, v1, v2, TypePromotion::float_only_op_config); \
}
NVFUSER_DEFINE_BINARY_FLOAT_ONLY_OP(nextafter, Nextafter)
#undef NVFUSER_DEFINE_BINARY_FLOAT_ONLY_OP
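For reference, this is roughly what the macro expands to for nextafter (written out here only for illustration; the shipped code relies on the macro above):

// Expansion of NVFUSER_DEFINE_BINARY_FLOAT_ONLY_OP(nextafter, Nextafter):
Val* nextafter(Val* v1, Val* v2) {
  return binaryOp(
      BinaryOpType::Nextafter, v1, v2, TypePromotion::float_only_op_config);
}
TensorView* nextafter(TensorView* v1, Val* v2) {
  return binaryOp(
      BinaryOpType::Nextafter, v1, v2, TypePromotion::float_only_op_config);
}
TensorView* nextafter(Val* v1, TensorView* v2) {
  return binaryOp(
      BinaryOpType::Nextafter, v1, v2, TypePromotion::float_only_op_config);
}
TensorView* nextafter(TensorView* v1, TensorView* v2) {
  return binaryOp(
      BinaryOpType::Nextafter, v1, v2, TypePromotion::float_only_op_config);
}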

#define NVFUSER_DEFINE_BINARY_CAST_OP(op_name, op_type) \
Val* op_name(Val* v1, Val* v2) { \
return binaryOp( \
6 changes: 6 additions & 0 deletions third_party/nvfuser/csrc/ops/arith.h
@@ -434,6 +434,12 @@ TORCH_CUDA_CU_API Val* sub(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* sub(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* sub(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* sub(TensorView* v1, TensorView* v2);
// nextafter: Only single- or double-precision
// floating point types (after promotion) are supported.
TORCH_CUDA_CU_API Val* nextafter(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* nextafter(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* nextafter(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* nextafter(TensorView* v1, TensorView* v2);
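As a rough sketch of how these overloads might be exercised when building a fusion from C++ (the scaffolding mirrors the existing nvfuser C++ tests, where makeSymbolicTensor is a test helper; treat the setup as an assumption, not part of this change):

// Includes and test scaffolding omitted; see the existing nvfuser C++ tests.
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* tv0 = makeSymbolicTensor(1, DataType::Double);
TensorView* tv1 = makeSymbolicTensor(1, DataType::Double);
fusion.addInput(tv0);
fusion.addInput(tv1);

// Resolves to the TensorView*/TensorView* overload declared above; the Val*
// overloads behave the same way and yield a scalar when both operands are
// scalars.
TensorView* tv2 = nextafter(tv0, tv1);
fusion.addOutput(tv2);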
// Integer binary ops
// mod
TORCH_CUDA_CU_API Val* mod(Val* v1, Val* v2);
@@ -571,6 +571,7 @@ void initNvFuserPythonBindings(PyObject* module) {
NVFUSER_PYTHON_BINDING_BINARY_OP("bitwise_xor", bitwise_xor)
NVFUSER_PYTHON_BINDING_BINARY_OP("bitwise_left_shift", bitwise_left_shift)
NVFUSER_PYTHON_BINDING_BINARY_OP("bitwise_right_shift", bitwise_left_shift)
NVFUSER_PYTHON_BINDING_BINARY_OP("nextafter", nextafter)
#undef NVFUSER_PYTHON_BINDING_BINARY_OP

#define NVFUSER_PYTHON_BINDING_BINARY_WITH_ALPHA_OP(op_str, op_name) \
2 changes: 2 additions & 0 deletions third_party/nvfuser/csrc/type.cpp
@@ -413,6 +413,8 @@ static const char* binary_op_type2string(BinaryOpType t) {
return "remainder";
case BinaryOpType::Sub:
return "sub";
case BinaryOpType::Nextafter:
return "nextafter";

// Integer Ops
case BinaryOpType::Mod:
1 change: 1 addition & 0 deletions third_party/nvfuser/csrc/type.h
@@ -252,6 +252,7 @@ enum class BinaryOpType {
Pow,
Remainder,
Sub,
Nextafter,
// TypeAs,

// Integer output ops. If changing modify isIntegerOp
10 changes: 10 additions & 0 deletions third_party/nvfuser/csrc/type_promotion.cpp
@@ -100,6 +100,16 @@ c10::ScalarType computeTypes(
c10::isIntegralType(common_dtype, /*includeBool=*/true)) {
common_dtype = c10::get_default_dtype_as_scalartype();
}

// Some ops, such as nextafter, are only implemented for single- and
// double-precision floating point types (after promotion)
if (config.require_full_precision_promoted) {
  TORCH_CHECK(
      common_dtype == c10::ScalarType::Float ||
          common_dtype == c10::ScalarType::Double,
      "Promoted type must be single or double precision float but found ",
      common_dtype);
}

Review comment (collaborator, PR author): I am not sure that this is the right place to perform this check.

Review comment (collaborator, PR author): Hmm. I got this wrong too. PyTorch actually supports bfloat16, just not float16 for this op. It was added with a manual implementation taken from musl: https://github.com/pytorch/pytorch/pull/61829/files#diff-ece04c31934b3504382e10ed3e9a69f03ffabd81ad1a2a890aab19b1642f53c0R120

return common_dtype;
}
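To make the musl-style approach mentioned in the review comments concrete, here is a minimal, self-contained C++ sketch of nextafter on raw bfloat16 bit patterns; the helper names and standalone form are illustrative assumptions, not the PyTorch or nvfuser implementation:

#include <cmath>
#include <cstdint>
#include <cstring>

// bfloat16 is the upper 16 bits of an IEEE-754 binary32 value, so widening a
// bit pattern to float is exact.
static float bf16_bits_to_float(uint16_t bits) {
  uint32_t u = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

// Step one representable bfloat16 value from `from` toward `to`, operating on
// raw bit patterns (hypothetical helper, for illustration only).
uint16_t nextafter_bf16_bits(uint16_t from, uint16_t to) {
  const float ffrom = bf16_bits_to_float(from);
  const float fto = bf16_bits_to_float(to);
  if (std::isnan(ffrom) || std::isnan(fto)) {
    return 0x7FC0; // canonical bfloat16 NaN
  }
  if (ffrom == fto) {
    return to; // also handles +0 vs -0: return `to` itself
  }
  if (ffrom == 0.0f) {
    // Step off zero to the smallest subnormal carrying the sign of `to`.
    return (fto > 0.0f) ? 0x0001 : 0x8001;
  }
  // For nonzero values, moving away from zero increments the bit pattern and
  // moving toward zero decrements it (this also steps the largest finite
  // value up to infinity, matching nextafter semantics).
  const bool away_from_zero = (ffrom > 0.0f) == (ffrom < fto);
  return away_from_zero ? static_cast<uint16_t>(from + 1)
                        : static_cast<uint16_t>(from - 1);
}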

4 changes: 4 additions & 0 deletions third_party/nvfuser/csrc/type_promotion.h
@@ -22,6 +22,7 @@ namespace nvfuser {
//!
struct TypePromotionConfig {
bool promote_integer_inputs_to_float = false;
bool require_full_precision_promoted = false;
TypePromotionConfig() = default;
};

Expand All @@ -31,6 +32,9 @@ static const TypePromotionConfig comparison_op_config;
static const TypePromotionConfig default_op_config;
static const TypePromotionConfig float_op_config{
/* promote_integer_inputs_to_float */ true};
static const TypePromotionConfig float_only_op_config{
/* promote_integer_inputs_to_float */ false,
/* require_full_precision_promoted */ true};

} // namespace TypePromotion

43 changes: 43 additions & 0 deletions third_party/nvfuser/python_tests/test_python_frontend.py
@@ -2,13 +2,15 @@

from copy import deepcopy
from functools import partial
import itertools
import re
from typing import List
import unittest

import torch
from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM, TestCase
from torch.testing._internal.jit_utils import RUN_CUDA
from torch.testing import make_tensor
import torch._refs as refs
import torch._prims as prims
# Will only create the nvfuser module if CUDA is available
@@ -987,5 +989,46 @@ def fusion_func(fd: FusionDefinition):
eager_out = torch.full([2, 2], 1.0) * 5.0
self.assertEqual(eager_out, nvf_out[0])

def test_nextafter(self):
inputs = [
# torch.nextafter is only defined for float{32,64} tensor inputs
make_tensor(4, device="cuda", dtype=torch.float32),
make_tensor(4, device="cuda", dtype=torch.float64),
]

def fusion_func(fd: FusionDefinition):
t0 = fd.from_pytorch(inputs[0])
t1 = fd.from_pytorch(inputs[1])

s0 = fd.define_constant(1.0, dtype=DataType.Float)
s1 = fd.define_constant(-1.0, dtype=DataType.Double)

t2 = fd.ops.add(t0, s0) # float
t3 = fd.ops.add(t1, s1) # double

for a, b in itertools.product(
[t0, t1, s0, s1],
[t0, t1, s0, s1],
):
# always enter the fusion...
t = fd.ops.nextafter(a, b)
if a in [t0, t1] or b in [t0, t1]:
# ...but skip outputting scalars, which we don't support
fd.add_output(t)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)

ab = [inputs[0], inputs[1], 1.0, -1.0]
i = 0
for a, b in itertools.product(ab, ab):
if not (isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor)):
continue
n = nvf_out[i]
i += 1
torch_out = torch.nextafter(
torch.as_tensor(a, device='cuda'), torch.as_tensor(b, device='cuda')
)
self.assertEqual(n, torch_out)

if __name__ == '__main__':
run_tests()
8 changes: 8 additions & 0 deletions third_party/nvfuser/runtime/helpers.cu
@@ -313,6 +313,14 @@ __device__ constexpr float fmod(float a, float b) {
return ::fmod(a, b);
}

__device__ constexpr double nextafter(double a, double b) {
return ::nextafter(a, b);
}

__device__ constexpr float nextafter(float a, float b) {
return ::nextafterf(a, b);
}
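These device overloads simply forward to the CUDA math library. As a quick host-side illustration of the same semantics (plain C++ libm, not nvfuser code), nextafter moves by exactly one ULP toward its second argument:

#include <cmath>
#include <cstdio>

int main() {
  // One ULP above 1.0f: adds FLT_EPSILON = 2^-23.
  std::printf("%.9g\n", std::nextafterf(1.0f, 2.0f)); // ~1.00000012
  // One ULP below 1.0f: subtracts 2^-24.
  std::printf("%.9g\n", std::nextafterf(1.0f, 0.0f)); // ~0.99999994
  // Double precision steps by 2^-52 around 1.0.
  std::printf("%.17g\n", std::nextafter(1.0, 2.0)); // ~1.0000000000000002
  return 0;
}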

template <typename T>
__device__ T pow(T a, T b) {
if (b < 0) {