Skip to content

Commit 9eb8c9e

Browse files
WindQAQGoogle-ML-Automation
authored andcommitted
[Mosaic] Move sitofp lowering to Mosaic.
Also, the compatibility check in fptosi is likely wrong - we should check if both src and dst bit widths < 32, not just dst. Correct it while I'm here. PiperOrigin-RevId: 756816853
1 parent e194d53 commit 9eb8c9e

File tree

3 files changed

+76
-23
lines changed

3 files changed

+76
-23
lines changed

jax/_src/pallas/mosaic/lowering.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2177,11 +2177,14 @@ def _convert_element_type_lowering_rule(
21772177
elif jnp.iinfo(old_dtype).bits == jnp.iinfo(new_dtype).bits:
21782178
# This case triggers when casting signed to unsigned or vice versa.
21792179
return x
2180-
# TODO(apaszke): Remove both_32bit constraints using the Mosaic canonicalizer.
21812180
elif _from(floating) and _to(signed):
21822181
return arith.fptosi(out_type, x)
2183-
elif _from(signed) and _to(floating) and both_32bit:
2184-
return arith.sitofp(out_type, x)
2182+
elif _from(signed) and _to(floating):
2183+
if (
2184+
not (ctx.forward_compatible or is_cloud_tpu_older_than(2025, 5, 12))
2185+
or both_32bit
2186+
):
2187+
return arith.sitofp(out_type, x)
21852188
elif old_dtype == jnp.bool_ and _to(integer) and new_dtype.itemsize == 4:
21862189
return arith.extui(out_type, x)
21872190
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),

jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc

Lines changed: 60 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,13 @@ limitations under the License.
4444
#include "mlir/IR/ImplicitLocOpBuilder.h"
4545
#include "mlir/IR/OpDefinition.h"
4646
#include "mlir/IR/Operation.h"
47-
#include "mlir/IR/PatternMatch.h"
4847
#include "mlir/IR/Region.h"
4948
#include "mlir/IR/Value.h"
5049
#include "mlir/Pass/Pass.h"
5150
#include "mlir/Support/LLVM.h"
5251
#include "mlir/Support/LogicalResult.h"
5352
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
53+
#include "jaxlib/mosaic/dialect/tpu/util.h"
5454
#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"
5555

5656
namespace mlir::tpu {
@@ -601,14 +601,10 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,
601601
return op.emitOpError("Vector/scalar mismatch between input and output");
602602
}
603603
bool is_vector = static_cast<bool>(src_vty);
604-
unsigned src_bitwidth, dst_bitwidth;
605-
if (is_vector) {
606-
src_bitwidth = src_vty.getElementTypeBitWidth();
607-
dst_bitwidth = dst_vty.getElementTypeBitWidth();
608-
} else {
609-
src_bitwidth = op.getIn().getType().getIntOrFloatBitWidth();
610-
dst_bitwidth = op.getType().getIntOrFloatBitWidth();
611-
}
604+
FAILUREOR_ASSIGN_OR_RETURN(const unsigned src_bitwidth,
605+
getElementTypeBitwidth(op.getIn().getType()));
606+
FAILUREOR_ASSIGN_OR_RETURN(const unsigned dst_bitwidth,
607+
getElementTypeBitwidth(op.getType()));
612608
if (dst_bitwidth > 32) {
613609
return op.emitOpError("Target bitwidth too large");
614610
}
@@ -623,6 +619,14 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,
623619
op.erase();
624620
return success();
625621
}
622+
623+
if ((src_bitwidth < 32 || dst_bitwidth < 32) && !ctx.compatibility_mode) {
624+
return op.emitOpError(
625+
"On this target float-to-integer conversions can only happen on "
626+
"32-bit values. Enable compatibility mode or upcast to float32, cast "
627+
"to int32 and truncate to desired bitwidth.");
628+
}
629+
626630
Value x = op.getIn();
627631
// Upcast the input to f32.
628632
if (src_bitwidth < 32) {
@@ -634,11 +638,6 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,
634638
}
635639
}
636640
if (dst_bitwidth < 32) {
637-
if (!ctx.compatibility_mode) {
638-
return op.emitOpError(
639-
"On this target only float-to-integer conversions can only happen on "
640-
"32-bit values. Enable compatibility mode or upcast to float32.");
641-
}
642641
// Need to clip values to match XLA
643642
auto clip = [&](Value x, Value low, Value high) {
644643
x = builder.create<arith::MaximumFOp>(x, low);
@@ -666,19 +665,59 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,
666665
x = builder.create<arith::FPToSIOp>(builder.getI32Type(), x);
667666
}
668667
if (dst_bitwidth < 32) {
669-
if (!ctx.compatibility_mode) {
670-
return op.emitOpError(
671-
"On this target only float-to-integer conversions can only happen on "
672-
"32-bit values. Enable compatibility mode or cast to int32 and "
673-
"truncate later.");
674-
}
675668
x = builder.create<arith::TruncIOp>(op.getType(), x);
676669
}
677670
op.replaceAllUsesWith(x);
678671
op.erase();
679672
return success();
680673
}
681674

675+
LogicalResult canonicalize_sitofp(const CanonicalizeContext &ctx,
676+
Operation &raw_op) {
677+
auto op = cast<arith::FPToSIOp>(raw_op);
678+
ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation());
679+
auto src_vty = dyn_cast<VectorType>(op.getIn().getType());
680+
auto dst_vty = dyn_cast<VectorType>(op.getType());
681+
if (static_cast<bool>(src_vty) != static_cast<bool>(dst_vty)) {
682+
return op.emitOpError("Vector/scalar mismatch between input and output");
683+
}
684+
bool is_vector = static_cast<bool>(src_vty);
685+
FAILUREOR_ASSIGN_OR_RETURN(const unsigned src_bitwidth,
686+
getElementTypeBitwidth(op.getIn().getType()));
687+
FAILUREOR_ASSIGN_OR_RETURN(const unsigned dst_bitwidth,
688+
getElementTypeBitwidth(op.getType()));
689+
690+
if ((src_bitwidth < 32 || dst_bitwidth < 32) && !ctx.compatibility_mode) {
691+
return op.emitOpError(
692+
"On this target integer-to-float conversions can only happen on "
693+
"32-bit values. Enable compatibility mode or upcast to int32, cast to "
694+
"float32 and truncate to desired bitwidth.");
695+
}
696+
697+
// Canonicalize (intX -> floatY) to (intX -> int32 -> float32 -> floatY).
698+
Value x = op.getIn();
699+
if (src_bitwidth < 32) {
700+
if (is_vector) {
701+
x = builder.create<arith::ExtSIOp>(
702+
VectorType::get(src_vty.getShape(), builder.getI32Type()), x);
703+
} else {
704+
x = builder.create<arith::ExtSIOp>(builder.getI32Type(), x);
705+
}
706+
}
707+
if (is_vector) {
708+
x = builder.create<arith::SIToFPOp>(
709+
VectorType::get(src_vty.getShape(), builder.getF32Type()), x);
710+
} else {
711+
x = builder.create<arith::SIToFPOp>(builder.getF32Type(), x);
712+
}
713+
if (dst_bitwidth < 32) {
714+
x = builder.create<arith::TruncFOp>(op.getType(), x);
715+
}
716+
op.replaceAllUsesWith(x);
717+
op.erase();
718+
return success();
719+
}
720+
682721
LogicalResult canonicalize_repeat(const CanonicalizeContext &ctx,
683722
Operation &raw_op) {
684723
auto op = dyn_cast<tpu::RepeatOp>(raw_op);
@@ -727,6 +766,7 @@ const llvm::StringMap<canonicalize_rule_type> &rules() {
727766
{vector::TransposeOp::getOperationName(), canonicalize_vector_transpose},
728767
{arith::SelectOp::getOperationName(), canonicalize_select},
729768
{arith::FPToSIOp::getOperationName(), canonicalize_fptosi},
769+
{arith::SIToFPOp::getOperationName(), canonicalize_sitofp},
730770
{tpu::RepeatOp::getOperationName(), canonicalize_repeat}};
731771
return *rules;
732772
}

jaxlib/mosaic/dialect/tpu/util.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,16 @@ FailureOr<int8_t> getTypeBitwidth(Type ty) {
180180
<< ty;
181181
}
182182

183+
// Returns the bitwidth of the element type. The function works for both
184+
// scalar and vector types.
185+
template <bool adjust_bool = false>
186+
inline FailureOr<int8_t> getElementTypeBitwidth(Type ty) {
187+
if (auto vty = dyn_cast<VectorType>(ty)) {
188+
return getTypeBitwidth<adjust_bool>(vty.getElementType());
189+
}
190+
return getTypeBitwidth<adjust_bool>(ty);
191+
}
192+
183193
template <typename T>
184194
ArrayRef<std::remove_const_t<T>> toArrayRef(absl::Span<T> span) {
185195
return ArrayRef<std::remove_const_t<T>>(span.data(), span.size());

0 commit comments

Comments
 (0)