@@ -51,6 +51,8 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
5151    return  torch_upstream::ScalarType::Long;
5252  if  (type.isSignedInteger (32 ))
5353    return  torch_upstream::ScalarType::Int;
54+   if  (type.isSignedInteger (16 ))
55+     return  torch_upstream::ScalarType::Short;
5456  if  (type.isSignlessInteger (1 ))
5557    return  torch_upstream::ScalarType::Bool;
5658  if  (type.isBF16 ())
@@ -95,6 +97,8 @@ Torch::getTypeForScalarType(MLIRContext *context,
9597    return  IntegerType::get (context, 64 , mlir::IntegerType::Signed);
9698  case  torch_upstream::ScalarType::Int:
9799    return  IntegerType::get (context, 32 , mlir::IntegerType::Signed);
100+   case  torch_upstream::ScalarType::Short:
101+     return  IntegerType::get (context, 16 , mlir::IntegerType::Signed);
98102  case  torch_upstream::ScalarType::Bool:
99103    return  IntegerType::get (context, 1 );
100104  case  torch_upstream::ScalarType::BFloat16:
@@ -213,8 +217,8 @@ Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
213217                                               Location loc, float  value,
214218                                               Type dtype) {
215219  //  Creating constants satisfying backend contract.
216-   if  (dtype.isInteger (64 ) || dtype.isInteger (32 ) || dtype.isInteger (8 ) ||
217-       dtype.isInteger (1 ))
220+   if  (dtype.isInteger (64 ) || dtype.isInteger (32 ) || dtype.isInteger (16 ) ||
221+       dtype.isInteger (8 ) || dtype. isInteger ( 1 ))
218222    return  rewriter.create <ConstantIntOp>(
219223        loc, rewriter.getI64IntegerAttr ((int64_t )value));
220224  if  (dtype.isF64 () || dtype.isF32 () || dtype.isF16 () || dtype.isBF16 ())
0 commit comments