@@ -44,13 +44,13 @@ limitations under the License.
44
44
#include " mlir/IR/ImplicitLocOpBuilder.h"
45
45
#include " mlir/IR/OpDefinition.h"
46
46
#include " mlir/IR/Operation.h"
47
- #include " mlir/IR/PatternMatch.h"
48
47
#include " mlir/IR/Region.h"
49
48
#include " mlir/IR/Value.h"
50
49
#include " mlir/Pass/Pass.h"
51
50
#include " mlir/Support/LLVM.h"
52
51
#include " mlir/Support/LogicalResult.h"
53
52
#include " jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
53
+ #include " jaxlib/mosaic/dialect/tpu/util.h"
54
54
#include " jaxlib/mosaic/dialect/tpu/vreg_util.h"
55
55
56
56
namespace mlir ::tpu {
@@ -601,14 +601,10 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,
601
601
return op.emitOpError (" Vector/scalar mismatch between input and output" );
602
602
}
603
603
bool is_vector = static_cast <bool >(src_vty);
604
- unsigned src_bitwidth, dst_bitwidth;
605
- if (is_vector) {
606
- src_bitwidth = src_vty.getElementTypeBitWidth ();
607
- dst_bitwidth = dst_vty.getElementTypeBitWidth ();
608
- } else {
609
- src_bitwidth = op.getIn ().getType ().getIntOrFloatBitWidth ();
610
- dst_bitwidth = op.getType ().getIntOrFloatBitWidth ();
611
- }
604
+ FAILUREOR_ASSIGN_OR_RETURN (const unsigned src_bitwidth,
605
+ getElementTypeBitwidth (op.getIn ().getType ()));
606
+ FAILUREOR_ASSIGN_OR_RETURN (const unsigned dst_bitwidth,
607
+ getElementTypeBitwidth (op.getType ()));
612
608
if (dst_bitwidth > 32 ) {
613
609
return op.emitOpError (" Target bitwidth too large" );
614
610
}
@@ -623,6 +619,14 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,
623
619
op.erase ();
624
620
return success ();
625
621
}
622
+
623
+ if ((src_bitwidth < 32 || dst_bitwidth < 32 ) && !ctx.compatibility_mode ) {
624
+ return op.emitOpError (
625
+ " On this target float-to-integer conversions can only happen on "
626
+ " 32-bit values. Enable compatibility mode or upcast to float32, cast "
627
+ " to int32 and truncate to desired bitwidth." );
628
+ }
629
+
626
630
Value x = op.getIn ();
627
631
// Upcast the input to f32.
628
632
if (src_bitwidth < 32 ) {
@@ -634,11 +638,6 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,
634
638
}
635
639
}
636
640
if (dst_bitwidth < 32 ) {
637
- if (!ctx.compatibility_mode ) {
638
- return op.emitOpError (
639
- " On this target only float-to-integer conversions can only happen on "
640
- " 32-bit values. Enable compatibility mode or upcast to float32." );
641
- }
642
641
// Need to clip values to match XLA
643
642
auto clip = [&](Value x, Value low, Value high) {
644
643
x = builder.create <arith::MaximumFOp>(x, low);
@@ -666,19 +665,59 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,
666
665
x = builder.create <arith::FPToSIOp>(builder.getI32Type (), x);
667
666
}
668
667
if (dst_bitwidth < 32 ) {
669
- if (!ctx.compatibility_mode ) {
670
- return op.emitOpError (
671
- " On this target only float-to-integer conversions can only happen on "
672
- " 32-bit values. Enable compatibility mode or cast to int32 and "
673
- " truncate later." );
674
- }
675
668
x = builder.create <arith::TruncIOp>(op.getType (), x);
676
669
}
677
670
op.replaceAllUsesWith (x);
678
671
op.erase ();
679
672
return success ();
680
673
}
681
674
675
+ LogicalResult canonicalize_sitofp (const CanonicalizeContext &ctx,
676
+ Operation &raw_op) {
677
+ auto op = cast<arith::FPToSIOp>(raw_op);
678
+ ImplicitLocOpBuilder builder (op->getLoc (), op.getOperation ());
679
+ auto src_vty = dyn_cast<VectorType>(op.getIn ().getType ());
680
+ auto dst_vty = dyn_cast<VectorType>(op.getType ());
681
+ if (static_cast <bool >(src_vty) != static_cast <bool >(dst_vty)) {
682
+ return op.emitOpError (" Vector/scalar mismatch between input and output" );
683
+ }
684
+ bool is_vector = static_cast <bool >(src_vty);
685
+ FAILUREOR_ASSIGN_OR_RETURN (const unsigned src_bitwidth,
686
+ getElementTypeBitwidth (op.getIn ().getType ()));
687
+ FAILUREOR_ASSIGN_OR_RETURN (const unsigned dst_bitwidth,
688
+ getElementTypeBitwidth (op.getType ()));
689
+
690
+ if ((src_bitwidth < 32 || dst_bitwidth < 32 ) && !ctx.compatibility_mode ) {
691
+ return op.emitOpError (
692
+ " On this target integer-to-float conversions can only happen on "
693
+ " 32-bit values. Enable compatibility mode or upcast to int32, cast to "
694
+ " float32 and truncate to desired bitwidth." );
695
+ }
696
+
697
+ // Canonicalize (intX -> floatY) to (intX -> int32 -> float32 -> floatY).
698
+ Value x = op.getIn ();
699
+ if (src_bitwidth < 32 ) {
700
+ if (is_vector) {
701
+ x = builder.create <arith::ExtSIOp>(
702
+ VectorType::get (src_vty.getShape (), builder.getI32Type ()), x);
703
+ } else {
704
+ x = builder.create <arith::ExtSIOp>(builder.getI32Type (), x);
705
+ }
706
+ }
707
+ if (is_vector) {
708
+ x = builder.create <arith::SIToFPOp>(
709
+ VectorType::get (src_vty.getShape (), builder.getF32Type ()), x);
710
+ } else {
711
+ x = builder.create <arith::SIToFPOp>(builder.getF32Type (), x);
712
+ }
713
+ if (dst_bitwidth < 32 ) {
714
+ x = builder.create <arith::TruncFOp>(op.getType (), x);
715
+ }
716
+ op.replaceAllUsesWith (x);
717
+ op.erase ();
718
+ return success ();
719
+ }
720
+
682
721
LogicalResult canonicalize_repeat (const CanonicalizeContext &ctx,
683
722
Operation &raw_op) {
684
723
auto op = dyn_cast<tpu::RepeatOp>(raw_op);
@@ -727,6 +766,7 @@ const llvm::StringMap<canonicalize_rule_type> &rules() {
727
766
{vector::TransposeOp::getOperationName (), canonicalize_vector_transpose},
728
767
{arith::SelectOp::getOperationName (), canonicalize_select},
729
768
{arith::FPToSIOp::getOperationName (), canonicalize_fptosi},
769
+ {arith::SIToFPOp::getOperationName (), canonicalize_sitofp},
730
770
{tpu::RepeatOp::getOperationName (), canonicalize_repeat}};
731
771
return *rules;
732
772
}
0 commit comments