Skip to content

Commit 703556c

Browse files
Fix
Signed-off-by: Oleg Goncharov <[email protected]>
1 parent b764dea commit 703556c

File tree

1 file changed

+1
-8
lines changed

1 file changed

+1
-8
lines changed

transformer_engine/common/cast/fp8/gated_fp8.cuh

Lines changed: 1 addition & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -281,13 +281,6 @@ void cast_gated_tma(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
281 281
using namespace kernel;
282 282
checkCuDriverContext(stream);
283 283

284-
if (output->has_data()) {
285-
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated.");
286-
}
287-
if (output->has_columnwise_data()) {
288-
NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated.");
289-
}
290-
291 284
NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function.");
292 285
const size_t rows = gated_input.flat_first_dim();
293 286
const size_t cols = gated_input.flat_last_dim() / 2;
@@ -305,7 +298,7 @@ void cast_gated_tma(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
305 298

306 299
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
307 300
gated_input.dtype(), IType,
308-
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
301+
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
309 302
output->dtype(), OType,
310 303

311 304
alignas(64) CUtensorMap tensor_map_grad{};

0 commit comments

Comments
 (0)