Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorRT] Support Multiple EP Context #23294

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3,108 changes: 1,627 additions & 1,481 deletions docs/ContribOperators.md

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3365,6 +3365,7 @@ void RegisterContribSchemas() {
"tensor(int16)",
"tensor(int32)",
"tensor(int64)",
"tensor(bool)",
"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
Expand Down
93 changes: 34 additions & 59 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,18 @@ bool GraphHasCtxNode(const GraphViewer& graph_viewer) {
return false;
}

int FindCtxNodeInGraph(const GraphViewer& graph_viewer) {
// Assumes there's only 1 context node in this subgraph (graph_viewer)
// Returns index of node
for (int i = 0; i < graph_viewer.MaxNodeIndex(); ++i) {
auto node = graph_viewer.GetNode(i);
if (node != nullptr && node->OpType() == EPCONTEXT_OP) {
return i;
}
}
return -1;
}

const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer) {
// find the top level graph
const Graph* cur_graph = &graph_viewer.GetGraph();
Expand All @@ -40,38 +52,18 @@ const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer) {
return main_graph.ModelPath();
}

/*
* Update ep_cache_context attribute of the EP context node with the given engine binary data
*/
void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto,
char* engine_data,
size_t size) {
ONNX_NAMESPACE::GraphProto* graph_proto = model_proto->mutable_graph();
ONNX_NAMESPACE::NodeProto* node_proto = graph_proto->mutable_node(0);

for (int i = 0; i < node_proto->attribute_size(); ++i) {
ONNX_NAMESPACE::AttributeProto* attribute_proto = node_proto->mutable_attribute(i);
if (attribute_proto->name() == EP_CACHE_CONTEXT) {
std::string engine_data_str = "";
if (size > 0) {
engine_data_str.assign(engine_data, size);
}
attribute_proto->set_s(engine_data_str);
}
}
}

/*
* Create "EP context node" model where engine information is embedded
*/
ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,
const std::string engine_cache_path,
char* engine_data,
size_t size,
const int64_t embed_mode,
const std::string compute_capability,
const std::string onnx_model_path,
const logging::Logger* logger) {
std::unique_ptr<Model> CreateCtxModel(const GraphViewer& graph_viewer,
const std::string fused_subgraph_name,
const std::string engine_cache_path,
char* engine_data,
size_t size,
const int64_t embed_mode,
const std::string compute_capability,
const std::string onnx_model_path,
const logging::Logger* logger) {
auto model_build = graph_viewer.CreateModel(*logger);
auto& graph_build = model_build->MainGraph();

Expand Down Expand Up @@ -123,18 +115,11 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,
node_attributes->emplace(ONNX_MODEL_FILENAME, *attr_3);

// Create EP context node
graph_build.AddNode(EPCONTEXT_OP, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN);
ORT_ENFORCE(graph_build.Resolve().IsOK());

// Serialize modelproto to string
auto new_graph_viewer = graph_build.CreateGraphViewer();
auto& metadata = graph_viewer.GetGraph().GetModel().MetaData();
auto model = new_graph_viewer->CreateModel(*logger, metadata);
auto model_proto = model->ToProto();
new_graph_viewer->ToProto(*model_proto->mutable_graph(), true, true);
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);

return model_proto.release();
graph_build.AddNode(fused_subgraph_name, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN);
auto status = graph_build.Resolve();
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());

return model_build;
}

/*
Expand Down Expand Up @@ -203,17 +188,6 @@ std::string GetCtxModelPath(const std::string& ep_context_file_path,
return ctx_model_path;
}

/*
* Dump "EP context" model
*
*/
void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto,
const std::string& ctx_model_path) {
std::fstream dump(ctx_model_path, std::ios::out | std::ios::trunc | std::ios::binary);
model_proto->SerializeToOstream(dump);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Dumped " + ctx_model_path;
}

bool IsAbsolutePath(const std::string& path_string) {
#ifdef _WIN32
onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string);
Expand Down Expand Up @@ -267,11 +241,11 @@ bool IsWeightStrippedEngineCache(std::filesystem::path& engine_cache_path) {
return engine_cache_path.stem().extension().string() == ".stripped";
}

Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer) {
if (!ValidateEPCtxNode(graph_viewer)) {
Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer, const int ctx_node_idx) {
if (!ValidateEPCtxNode(graph_viewer, ctx_node_idx)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "It's not a valid EP Context node");
}
auto node = graph_viewer.GetNode(0);
auto node = graph_viewer.GetNode(ctx_node_idx);
auto& attrs = node->GetAttributes();

const int64_t embed_mode = attrs.at(EMBED_MODE).i();
Expand Down Expand Up @@ -381,14 +355,14 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
/*
* The sanity check for EP context contrib op.
*/
bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewer) {
bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewer, const int ctx_node_idx) {
assert(graph_viewer.NumberOfNodes() == 1);
assert(graph_viewer.GetNode(0)->OpType() == EPCONTEXT_OP);
auto node = graph_viewer.GetNode(0);
assert(graph_viewer.GetNode(ctx_node_idx)->OpType() == EPCONTEXT_OP);
auto node = graph_viewer.GetNode(ctx_node_idx);
auto& attrs = node->GetAttributes();

// Show the warning if compute capability is not matched
if (attrs.count(COMPUTE_CAPABILITY) > 0) {
if (attrs.find(COMPUTE_CAPABILITY) != attrs.end() && attrs.count(COMPUTE_CAPABILITY) > 0) {
std::string model_compute_capability = attrs.at(COMPUTE_CAPABILITY).s();
// Verify if engine was compiled with ampere+ hardware compatibility enabled
if (model_compute_capability == "80+") {
Expand All @@ -415,4 +389,5 @@ bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewe

return true;
}

} // namespace onnxruntime
28 changes: 13 additions & 15 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,23 @@ static const std::string EPCONTEXT_WARNING =
for the best model loading time";

bool GraphHasCtxNode(const GraphViewer& graph_viewer);
int FindCtxNodeInGraph(const GraphViewer& graph_viewer);

const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer);
std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path);
ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,
const std::string engine_cache_path,
char* engine_data,
size_t size,
const int64_t embed_mode,
const std::string compute_capability,
const std::string onnx_model_path,
const logging::Logger* logger);
std::unique_ptr<Model> CreateCtxModel(const GraphViewer& graph_viewer,
const std::string fused_subgraph_name,
const std::string engine_cache_path,
char* engine_data,
size_t size,
const int64_t embed_mode,
const std::string compute_capability,
const std::string onnx_model_path,
const logging::Logger* logger);
std::string GetCtxModelPath(const std::string& ep_context_file_path,
const std::string& original_model_path);
bool IsAbsolutePath(const std::string& path_string);
bool IsRelativePathToParentPath(const std::string& path_string);
void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto,
const std::string& ctx_model_path);
void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto,
char* engine_data,
size_t size);

class TensorRTCacheModelHandler {
public:
Expand All @@ -67,9 +65,9 @@ class TensorRTCacheModelHandler {
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler);

bool ValidateEPCtxNode(const GraphViewer& graph_viewer);
bool ValidateEPCtxNode(const GraphViewer& graph_viewer, const int ctx_node_idx);

Status GetEpContextFromGraph(const GraphViewer& graph_viewer);
Status GetEpContextFromGraph(const GraphViewer& graph_viewer, const int ctx_node_idx);

private:
std::unique_ptr<nvinfer1::ICudaEngine>* trt_engine_;
Expand Down
Loading
Loading