-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathConstantSubgraphAnalyser.h
125 lines (99 loc) · 4.11 KB
/
ConstantSubgraphAnalyser.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
//===-- ConstantSubgraphAnalyser.h - Constant subgraph ----------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// This file implements constant subgraph analysis. In this file are:
/// 1. the lattice value class that represents operations with constant inputs
/// and outputs in the program, and
/// 2. a sparse constant subgraph analysis.
///
//===----------------------------------------------------------------------===//
#ifndef MLIR_ANALYSIS_DATAFLOW_CONSTANTSUBGRAPHANALYSER_H
#define MLIR_ANALYSIS_DATAFLOW_CONSTANTSUBGRAPHANALYSER_H
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
namespace mlir {
namespace dataflow {
//===----------------------------------------------------------------------===//
// IsConstantTensor
//===----------------------------------------------------------------------===//
/// This lattice value represents a boolean indicating if a value is constant.
///
/// It carries two flags: `initialized` tells whether any information has been
/// recorded at all, and `isConstantTensor` holds the recorded answer. A
/// default-constructed object is uninitialized.
class IsConstantTensor {
public:
  /// Construct as uninitialized.
  explicit IsConstantTensor() = default;

  /// Construct with a known state.
  ///
  /// \param initialized whether the state carries information.
  /// \param isConstantTensor the recorded "is constant" answer.
  explicit IsConstantTensor(bool initialized, bool isConstantTensor)
      : initialized(initialized), isConstantTensor(isConstantTensor) {}

  /// Get the state. Must be initialized before; asserts otherwise.
  bool getIsConstantTensor() const {
    assert(!isUninitialized() &&
           "IsConstantTensor queried before it was initialized");
    return isConstantTensor;
  }

  /// Compare. Two states are equal when both flags match.
  bool operator==(const IsConstantTensor &rhs) const {
    return initialized == rhs.initialized &&
           isConstantTensor == rhs.isConstantTensor;
  }

  /// Print the state to `os`. Defined out of line.
  void print(raw_ostream &os) const;

  /// Get uninitialized state. This happens when the
  /// state hasn't been set during the analysis.
  static IsConstantTensor getUninitialized() { return IsConstantTensor{}; }

  /// Whether the state is uninitialized.
  bool isUninitialized() const { return !initialized; }

  /// Get unknown state.
  ///
  /// NOTE(review): this builds a value with `initialized == false`, so it
  /// compares equal to getUninitialized() and join() treats it as "no
  /// information". If "unknown" is meant to be a pessimistic (non-constant)
  /// state it would need `initialized == true` - confirm against the callers
  /// before changing it.
  static IsConstantTensor getUnknown() {
    return IsConstantTensor{/*initialized=*/false,
                            /*isConstantTensor=*/false};
  }

  /// Join two states.
  ///
  /// If one side is uninitialized the other side is returned unchanged (so
  /// joining two uninitialized states yields an uninitialized state).
  /// Otherwise the result is initialized and is constant only if both inputs
  /// are constant.
  static IsConstantTensor join(const IsConstantTensor &lhs,
                               const IsConstantTensor &rhs) {
    // If one is uninitialized, use the other.
    if (lhs.isUninitialized())
      return rhs;
    if (rhs.isUninitialized())
      return lhs;
    // Both are initialized here, so the flags can be read directly and
    // intersected; no further fallback is reachable.
    return IsConstantTensor(/*initialized=*/true,
                            lhs.isConstantTensor && rhs.isConstantTensor);
  }

private:
  /// Whether any information has been recorded.
  bool initialized = false;
  /// The recorded answer; only meaningful when `initialized` is true.
  bool isConstantTensor = false;
};
//===----------------------------------------------------------------------===//
// ConstantSubgraphAnalyser
//===----------------------------------------------------------------------===//
/// Sparse forward data-flow analysis whose lattice element type is
/// `Lattice<IsConstantTensor>`. Only the declarations live in this header;
/// both overrides below are defined out of line.
class ConstantSubgraphAnalyser
: public SparseForwardDataFlowAnalysis<Lattice<IsConstantTensor>> {
public:
/// Reuse the constructors of the base analysis class.
using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
/// Override of the base-class hook, called with an operation together with
/// the lattices of its operands (read-only) and of its results (mutable).
/// NOTE(review): presumably derives the result states from the operand
/// states - confirm against the implementation file.
LogicalResult visitOperation(Operation *op,
ArrayRef<const Lattice<IsConstantTensor> *> operands,
ArrayRef<Lattice<IsConstantTensor> *> results) override;
/// Override of the base-class hook that puts `lattice` into its entry state.
/// NOTE(review): the concrete entry state is chosen in the implementation
/// file, which is not visible here.
void setToEntryState(Lattice<IsConstantTensor> *lattice) override;
};
//===----------------------------------------------------------------------===//
// RunConstantSubgraphAnalyser
//===----------------------------------------------------------------------===//
/// Runs constant subgraph analysis on the IR defined by `op`.
/// Owns the `DataFlowSolver` that holds the analysis state. All member
/// functions are defined out of line.
struct RunConstantSubgraphAnalyser {
public:
/// Defined out of line.
/// NOTE(review): presumably registers the analyses on `solver` - confirm.
RunConstantSubgraphAnalyser();
/// Run the analysis on `op`.
void run(Operation *op);
/// Query the analysis result for `val`.
/// NOTE(review): presumably must be called after run(); behaviour for a value
/// the solver has no state for is not visible from this header - confirm.
bool getIsConstantTensor(Value val);
private:
/// Stores the result of the analysis.
DataFlowSolver solver;
/// Internal helper taking a solver and the top-level function operation.
/// NOTE(review): presumably drives `solver` over `topFunc` - confirm against
/// the implementation file.
void getConstantSubgraph(DataFlowSolver &solver, Operation *topFunc);
};
} // end namespace dataflow
} // end namespace mlir
#endif // MLIR_ANALYSIS_DATAFLOW_CONSTANTSUBGRAPHANALYSER_H