
update TRT EP
chilo-ms committed Feb 9, 2025
1 parent df5aca9 commit 60d9599
Showing 3 changed files with 77 additions and 49 deletions.
59 changes: 11 additions & 48 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -2554,8 +2554,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
}

bool early_termination = false;
// supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination);
supported_nodes_vector = parser_nodes_vector;
supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination);
if (early_termination) {
supported_nodes_vector.clear();
}
@@ -2660,13 +2659,13 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
* Enable EP related L2+ graph optimizations with steps:
*
* 1. Call provider bridge API to lookup pre-defined optimizer by name and get selection function
* - Run selection function to get selection ComputeCapability
* - ComputeCapability.optimize_func would be set by the optimizer to the function that does the optimization
* 2. Run selection function to get selection ComputeCapability
* - ComputeCapability.optimize_func would be set by the optimizer to the function that does the optimization
*
*
*
* Current available optimizations:
* - (ConstantFoldingDQ) constant folding on DQ nodes -> Dequantize INT32, UINT16, INT16 constant to FP32.
* - (ConstantFoldingDQ) constant folding on DQ nodes, i.e. dequantize INT32, UINT16, INT16 constants to FP32.
*/

std::function<std::vector<std::unique_ptr<ComputeCapability>>(const GraphViewer&)> selection_func;
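For orientation, here is a minimal, self-contained sketch of the lookup-then-select flow the comment above describes: a pre-defined optimizer is resolved by name, its selection function is run over the graph, and each returned capability already carries the optimization callback. All names below (Capability, SelectionFn, the registry map) are hypothetical stand-ins, not ONNX Runtime's actual provider-bridge API.

// Sketch only: the types below are simplified stand-ins for ComputeCapability
// and the provider-bridge optimizer registry, not ONNX Runtime types.
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct Capability {
  std::vector<size_t> node_indices;     // nodes covered by this capability
  std::function<void()> optimize_func;  // set by the optimizer (e.g. ConstantFoldingDQ)
};

using SelectionFn = std::function<std::vector<Capability>(const std::vector<std::string>&)>;

int main() {
  // 1. Look up a pre-defined optimizer by name to get its selection function.
  std::map<std::string, SelectionFn> registry;
  registry["ConstantFoldingDQ"] = [](const std::vector<std::string>& nodes) {
    std::vector<Capability> result;
    for (size_t i = 0; i < nodes.size(); ++i) {
      if (nodes[i] == "DequantizeLinear") {
        Capability cap;
        cap.node_indices = {i};
        cap.optimize_func = [i] { std::cout << "fold DQ node " << i << "\n"; };
        result.push_back(cap);
      }
    }
    return result;
  };

  // 2. Run the selection function to get the selected capabilities, each of
  //    which already carries the optimize_func supplied by the optimizer.
  std::vector<std::string> graph = {"Conv", "DequantizeLinear", "MatMul"};
  for (auto& cap : registry["ConstantFoldingDQ"](graph)) {
    cap.optimize_func();  // the optimizer-supplied callback does the actual rewrite
  }
  return 0;
}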
@@ -2687,52 +2686,16 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,

SelectQualifiedDQNode(graph, trt_selection_node_set, consumer_to_dq);

// Include nodes that are filtered out by TRT parser.
auto update_supported_node_vector = [&](SubGraph_t& supported_node_vector, SubGraphCollection_t& supported_nodes_vector) -> void {
if (!consumer_to_dq.empty()) {
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder(1);
for (auto index : supported_node_vector.first) {
if (consumer_to_dq.find(node_index[index]) == consumer_to_dq.end()) {
continue;
}

auto dq_node_index = consumer_to_dq[node_index[index]];

// Check if DQ node is included in one of the subgraphs
auto in_the_subgraph_collection = [&](NodeIndex node_idx) -> bool {
for (auto& node_vector : supported_nodes_vector) {
if (!node_vector.second) {
continue;
}
for (auto index : node_vector.first) {
if (node_index[index] == node_idx) {
return true;
}
}
}
return false;
};
if (in_the_subgraph_collection(dq_node_index)) {
continue;
}
// Find the iterator pointing to the target element
auto it = std::find(node_index.begin(), node_index.end(), dq_node_index);
if (it != node_index.end()) {
// Calculate the index
int idx = std::distance(node_index.begin(), it);
supported_node_vector.first.push_back(static_cast<NodeIndex>(idx));
auto node = graph.GetNode(dq_node_index);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << node->Name() << " is included which is filtered out by TRT parser.";
}
}
}
};

// Create ComputeCapability
int number_of_trt_nodes = 0, subgraph_index = 0;
for (const auto& group : supported_nodes_vector) {
for (auto& group : supported_nodes_vector) {
if (!group.first.empty()) {
// TODO: Use consumer_to_dq table to include DQ node that is filtered out by TRT parser.

if (!selection_cc.empty()) {
// Include DQ nodes that are filtered out by TRT parser
UpdateSupportedNodeVectorForDQ(graph, group, supported_nodes_vector, consumer_to_dq);
}

std::unique_ptr<IndexedSubGraph> sub_graph = GetSubGraph(group, graph, model_hash, subgraph_index);
auto compute_capability = ComputeCapability::Create(std::move(sub_graph));

@@ -612,5 +612,13 @@ class TensorrtExecutionProvider : public IExecutionProvider {
std::unique_ptr<ComputeCapability> CreateOptimizationComputeCapability(ComputeCapability* selection_cc,
std::unordered_set<NodeIndex>& trt_selection_node_set,
ComputeCapability* trt_cc) const;
/**
 * This function adds back the DQ nodes that were filtered out by the TRT parser.
 * Those DQ nodes can still be constant folded (dequantized) by the ConstantFoldingDQ optimizer during ORT L2+ optimization.
 */
void UpdateSupportedNodeVectorForDQ(const GraphViewer& graph,
                                    SubGraph_t& supported_node_vector,
                                    SubGraphCollection_t& supported_nodes_vector,
                                    std::unordered_map<NodeIndex, NodeIndex> consumer_to_dq) const;
};
} // namespace onnxruntime
@@ -293,7 +293,7 @@ void TensorrtExecutionProvider::SelectQualifiedDQNode(const GraphViewer& graph,
const Node& consumer_node = *node->OutputNodesBegin();
selection_node_set.insert(index);
consumer_to_dq[consumer_node.Index()] = index;
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << consumer_node.Name() << " < -" << node->Name();
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << consumer_node.Name() << " <- " << node->Name();
}
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Total " << selection_node_set.size() << " DequantizeLinear node(s) are selected.";
@@ -330,4 +330,61 @@ std::unique_ptr<ComputeCapability> TensorrtExecutionProvider::CreateOptimization
compute_capability->copy_optimization_func(selection_cc);
return compute_capability;
}

/**
 * This function adds back the DQ nodes that were filtered out by the TRT parser.
 * Those DQ nodes can still be constant folded (dequantized) by the ConstantFoldingDQ optimizer during ORT L2+ optimization.
 */
void TensorrtExecutionProvider::UpdateSupportedNodeVectorForDQ(const GraphViewer& graph,
SubGraph_t& supported_node_vector,
SubGraphCollection_t& supported_nodes_vector,
std::unordered_map<NodeIndex, NodeIndex> consumer_to_dq) const {
if (consumer_to_dq.empty()) {
return;
}

if (!supported_node_vector.second) {
return;
}

const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder(1);
auto supported_nodes = supported_node_vector.first;
for (auto index : supported_nodes) {
if (consumer_to_dq.find(node_index[index]) == consumer_to_dq.end()) {
continue;
}

auto dq_node_index = consumer_to_dq[node_index[index]];

// Check if DQ node is included in one of the subgraphs
auto in_the_subgraph_collection = [&](NodeIndex node_idx) -> bool {
for (auto& node_vector : supported_nodes_vector) {
if (!node_vector.second) {
continue;
}
for (auto i : node_vector.first) {
if (node_index[i] == node_idx) {
return true;
}
}
}
return false;
};

// If the DQ node is already included in one of the subgraphs, skip it.
if (in_the_subgraph_collection(dq_node_index)) {
continue;
}

// Find the iterator pointing to the target element
auto it = std::find(node_index.begin(), node_index.end(), dq_node_index);
if (it != node_index.end()) {
// Calculate the index
int idx = std::distance(node_index.begin(), it);
supported_node_vector.first.push_back(static_cast<NodeIndex>(idx));
auto node = graph.GetNode(dq_node_index);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << node->Name() << " is included even though it was filtered out by the TRT parser.";
}
}
}
} // namespace onnxruntime
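One subtlety in UpdateSupportedNodeVectorForDQ above: supported_node_vector.first stores positions within graph.GetNodesInTopologicalOrder(1), not raw NodeIndex values, and the loop iterates over a copy (supported_nodes) so that the push_back cannot disturb the iteration. The standalone sketch below, using toy data rather than ORT types, shows just that NodeIndex-to-position translation with std::find and std::distance.

// Sketch: translate a NodeIndex into its position within a topological
// ordering before appending it to the supported-node list.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

using NodeIndex = std::size_t;

int main() {
  // Toy topological order: position in the vector -> NodeIndex of the node.
  std::vector<NodeIndex> node_index = {7, 3, 5, 9};

  NodeIndex dq_node_index = 5;  // the DQ node we want to add back

  // Supported nodes are stored as positions into node_index, not as NodeIndex values.
  std::vector<NodeIndex> supported_positions = {0, 3};  // i.e. nodes 7 and 9

  auto it = std::find(node_index.begin(), node_index.end(), dq_node_index);
  if (it != node_index.end()) {
    auto pos = static_cast<NodeIndex>(std::distance(node_index.begin(), it));
    supported_positions.push_back(pos);  // node 5 is added back via its position (2)
    std::cout << "DQ NodeIndex " << dq_node_index << " added at topological position " << pos << "\n";
  }
  return 0;
}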
