Skip to content

Commit 8759641

Browse files
committed
Addressed the comment about vectors' shape
1 parent 803a91b commit 8759641

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1928,18 +1928,18 @@ class DecomposeAtenOuterOp : public OpRewritePattern<AtenOuterOp> {
19281928

19291929
Value one = rewriter.create<Torch::ConstantIntOp>(
19301930
loc, rewriter.getI64IntegerAttr(1)); // Dimension index
1931-
SmallVector<int64_t, 2> inputMatrixShape = {inputShape[0], 1};
1931+
inputShape.push_back(1);
19321932
Type inputMatrixType = inputType.getWithSizesAndDtype(
1933-
inputMatrixShape, inputType.getOptionalDtype());
1933+
inputShape, inputType.getOptionalDtype());
19341934

19351935
Value inputMatrix =
19361936
rewriter.create<AtenUnsqueezeOp>(loc, inputMatrixType, input, one);
19371937

19381938
Value zero = rewriter.create<Torch::ConstantIntOp>(
19391939
loc, rewriter.getI64IntegerAttr(0));
1940-
SmallVector<int64_t, 2> vec2MatrixShape = {1, vec2Shape[0]};
1941-
Type vec2MatrixType = vec2Type.getWithSizesAndDtype(
1942-
vec2MatrixShape, vec2Type.getOptionalDtype());
1940+
vec2Shape.insert(vec2Shape.begin(), 1);
1941+
Type vec2MatrixType =
1942+
vec2Type.getWithSizesAndDtype(vec2Shape, vec2Type.getOptionalDtype());
19431943

19441944
Value vec2Matrix =
19451945
rewriter.create<AtenUnsqueezeOp>(loc, vec2MatrixType, vec2, zero);

0 commit comments

Comments
 (0)