Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions test/inductor/test_op_dtype_prop.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ def test_dtype_aware_codegen(self, op_name: str, load_upcast_to_fp32, input_dtyp
triton_op_name_overrides = {
"round": "nearbyint",
}
# ROCm uses fast_tanhf for all input types that are not float64
if torch.version.hip and input_dtype != torch.float64:
triton_op_name_overrides["tanh"] = "fast_tanhf"
override = triton_op_name_overrides.get(op_name)
triton_op_name = override if override is not None else torch_op_name

Expand Down
11 changes: 10 additions & 1 deletion torch/_inductor/codegen/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -1315,7 +1315,16 @@ def tan(x):
@staticmethod
@maybe_upcast_float32()
def tanh(x):
if torch.version.hip and get_triton_version() > (3, 2):
cse_var = V.kernel.cse.varname_map.get(x)
if cse_var and hasattr(cse_var, "dtype"):
dtype = cse_var.dtype
else:
dtype = None
if (
torch.version.hip
and get_triton_version() > (3, 2)
and dtype != torch.float64
):
# On ROCm, use fast_tanhf depending on Triton version
# Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+
return f"libdevice.fast_tanhf({x})"
Expand Down