Skip to content

Commit 07c3e11

Browse files
[MLIR][TORCH] Add support for Short(si16) data type
Signed-Off By: Vivek Khandelwal <[email protected]>
1 parent fb21a85 commit 07c3e11

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

lib/Dialect/Torch/Utils/Utils.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
5151
return torch_upstream::ScalarType::Long;
5252
if (type.isSignedInteger(32))
5353
return torch_upstream::ScalarType::Int;
54+
if (type.isSignedInteger(16))
55+
return torch_upstream::ScalarType::Short;
5456
if (type.isSignlessInteger(1))
5557
return torch_upstream::ScalarType::Bool;
5658
if (type.isBF16())
@@ -95,6 +97,8 @@ Torch::getTypeForScalarType(MLIRContext *context,
9597
return IntegerType::get(context, 64, mlir::IntegerType::Signed);
9698
case torch_upstream::ScalarType::Int:
9799
return IntegerType::get(context, 32, mlir::IntegerType::Signed);
100+
case torch_upstream::ScalarType::Short:
101+
return IntegerType::get(context, 16, mlir::IntegerType::Signed);
98102
case torch_upstream::ScalarType::Bool:
99103
return IntegerType::get(context, 1);
100104
case torch_upstream::ScalarType::BFloat16:
@@ -213,8 +217,8 @@ Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
213217
Location loc, float value,
214218
Type dtype) {
215219
// Creating constants satisfying backend contract.
216-
if (dtype.isInteger(64) || dtype.isInteger(32) || dtype.isInteger(8) ||
217-
dtype.isInteger(1))
220+
if (dtype.isInteger(64) || dtype.isInteger(32) || dtype.isInteger(16) ||
221+
dtype.isInteger(8) || dtype.isInteger(1))
218222
return rewriter.create<ConstantIntOp>(
219223
loc, rewriter.getI64IntegerAttr((int64_t)value));
220224
if (dtype.isF64() || dtype.isF32() || dtype.isF16() || dtype.isBF16())

0 commit comments

Comments
 (0)