Skip to content

Commit 9a1c41d

Browse files
committed
support partial reduction
1 parent fb7aef4 commit 9a1c41d

File tree

6 files changed

+143
-107
lines changed

6 files changed

+143
-107
lines changed

include/gc/Dialect/Arith/Utils/EasyBuild.h

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,10 @@
1-
//===- EasyBuild.h - Easy Arith IR Builder utilities ------------*- C++ -*-===//
1+
//===-- EasyBuild.h - DESC --------------------------------------*- C++ -*-===//
22
//
3-
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8-
//
9-
// This header file defines the easy-build utilities for arith dialects. It
10-
// provides the utility functions, classes and operators to make it easir to
11-
// program arith dialect operations in C++
12-
//
13-
//===----------------------------------------------------------------------===//
14-
158
#ifndef MLIR_DIALECT_ARITH_UTILS_EASYBUILD_H
169
#define MLIR_DIALECT_ARITH_UTILS_EASYBUILD_H
1710
#include "gc/IR/EasyBuild.h"
@@ -28,12 +21,8 @@ namespace impl {
2821

2922
template <std::size_t size> struct ToFloatType {};
3023

31-
template <> struct ToFloatType<4> {
32-
using type = Float32Type;
33-
};
34-
template <> struct ToFloatType<8> {
35-
using type = Float64Type;
36-
};
24+
template <> struct ToFloatType<4> { using type = Float32Type; };
25+
template <> struct ToFloatType<8> { using type = Float64Type; };
3726

3827
inline Type getElementType(Value v) {
3928
auto type = v.getType();

include/gc/IR/EasyBuild.h

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
1-
//===- EasyBuild.h - Easy IR Builder utilities ------------------*- C++ -*-===//
1+
//===-- EasyBuild.h - DESC --------------------------------------*- C++ -*-===//
22
//
3-
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8-
//
9-
// This header file defines the easy-build utilities core data structures for
10-
// building IR.
11-
//
12-
//===----------------------------------------------------------------------===//
13-
148
#ifndef MLIR_IR_EASYBUILD_H
159
#define MLIR_IR_EASYBUILD_H
1610
#include "mlir/IR/Builders.h"

include/gc/IR/EasyBuildSCF.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
//===- EasyBuildSCF.h - Easy IR Builder for general control flow *- C++ -*-===//
1+
//===-- EasyBuildSCF.h - DESC -----------------------------------*- C++ -*-===//
22
//
3-
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8+
89
//
910
// This header file defines the helper classes, functions and macros to help to
1011
// build general structured control flow. Developers can use the utilities in

lib/gc/Transforms/DeepTileContractionNamedOp.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
//===-- DeepTileContractionNamedOp.cpp - DESC -------------------*- C++ -*-===//
2-
//
2+
//
33
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6-
//
6+
//
77
//===----------------------------------------------------------------------===//
88

99
#include "./Tiling.hpp"
@@ -273,9 +273,19 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
273273
b.setInsertionPoint(currentOp);
274274
if (auto partialInterface =
275275
dyn_cast<PartialReductionOpInterface>(currentOp.getOperation())) {
276+
for (auto [idx, tile] : llvm::enumerate(tileSizes)) {
277+
if (isConstantIntValue(tile, 0)) {
278+
tileSizes[idx] = loopRanges[idx].size;
279+
}
280+
}
281+
282+
SmallVector<OpFoldResult> newParallelDims;
283+
for (auto i = 0UL; i < reductionDims.size(); i++) {
284+
newParallelDims.push_back(getAsIndexOpFoldResult(b.getContext(), i));
285+
}
276286
auto tilingResult = linalgX::tileAllUsingForall(
277-
b, cast<PartialReductionOpInterface>(currentOp.getOperation()),
278-
numThreads, tileSizes, std::nullopt);
287+
b, cast<PartialReductionOpInterface>(currentOp.getOperation()), {},
288+
tileSizes, newParallelDims, std::nullopt);
279289
if (failed(tilingResult))
280290
return failure();
281291
currentOp = dyn_cast<linalg::LinalgOp>(tilingResult->parallelTiledOp);

0 commit comments

Comments
 (0)