Skip to content

Support the merging of two ONNX models #135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,27 @@ python entry.py

> 当前版本已支持从numpy文件中读取initializer数据。点击“Open *.npy”按钮,在弹出的对话框中选择numpy文件,数据便会自动解析并呈现在上方的输入框中,也支持在读取的数据基础上进一步编辑。

## 合并两个onnx模型

1. 初始页面加载两个模型
将两个模型文件拖动到页面,或在模型选择页面选择两个模型。等待加载完成。模型会并列显示。
修改完成后点击`download`下载。

2. 在操作页面进行合并
点击load model mode:下方的下拉框,可以选择`new`,`tail`,`parallel`三种模式。
`new`会打开新的模型;`tail`会将新模型追加到上一个模型的末尾;`parallel`会将新模型和旧模型并列。
拖动或点击按钮选择模型后会显示合并后的模型。
修改完成后点击`download`下载。

> **已知问题:**
> 1. 只能复制已知节点,未知节点需要先修改`metadata.json`,即先变为已知节点。
> 2. 调试模式下按下Alt,进入断点或出现网页跳转后,松开Alt,返回该网页后Alt会被认为一直是按下状态。需要重新按一次解除。
> 3. 合并模型暂时不支持回退。

## 选择多个连续的节点,并执行复制或删除操作

按住`Alt`键,点击第一个节点,`Alt`不松手点击第二个节点,完成选择。按下`j`键并松开,实现复制。按下`l`键并松开,实现删除。松开`Alt`取消选中状态。

<img src="./docs/edit_initializer_from_npy.png" style="zoom:50%;" />

`onnx-modifier`正在活跃地更新中:hammer_and_wrench:。 欢迎使用,提issue,如果有帮助的话,感谢给个:star:~
Expand Down
37 changes: 36 additions & 1 deletion onnx_modifier/flask_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import logging
from flask import Flask, render_template, request
import time
from flask import Flask, render_template, request, send_file
from .onnx_modifier import onnxModifier
logging.basicConfig(level=logging.INFO)

Expand All @@ -22,6 +23,40 @@ def open_model():

return 'OK', 200


@app.route('/merge_model', methods=['POST'])
def merge_model():
    """Merge two uploaded ONNX models into one and send the result back.

    Expects multipart form fields ``file0`` and ``file1``. The merged model
    replaces the server-side global ``onnx_modifier`` and is returned to the
    client as a downloadable attachment.
    """
    first = request.files['file0']
    second = request.files['file1']
    # Timestamp-based prefix keeps names from the second model unique
    # after merging; the first model keeps its original names ("" prefix).
    stamp_prefix = str(int(time.time())) + "_"

    global onnx_modifier
    onnx_modifier, stream, merged_name = onnxModifier.merge(
        first.filename, first.stream,
        second.filename, second.stream,
        "", stamp_prefix)

    return send_file(
        stream,
        mimetype='application/octet-stream',
        as_attachment=True,
        download_name=merged_name)

@app.route('/append_model', methods=['POST'])
def append_model():
    """Append an uploaded ONNX model to the currently loaded one.

    Expects multipart form fields ``file`` (the model) and ``method``
    (append mode as an integer; 1 = chain to tail, otherwise parallel).
    Returns the combined model as a downloadable attachment.

    Fix: previously, when no model was loaded yet (``onnx_modifier`` falsy),
    ``stream``/``merged_name`` were never bound and the final ``send_file``
    raised a NameError (HTTP 500). Now an explicit 400 is returned instead.
    """
    method = request.form.get('method')
    onnx_file = request.files['file']
    timestamp = time.time()

    global onnx_modifier
    if not onnx_modifier:
        # Nothing to append to: fail loudly and clearly instead of NameError.
        return 'No model loaded to append to', 400

    onnx_modifier, stream, merged_name = onnx_modifier.append(
        onnx_file.filename, onnx_file.stream,
        str(int(timestamp)) + "_", int(method))

    return send_file(stream,
                     mimetype='application/octet-stream',
                     as_attachment=True,
                     download_name=merged_name)

@app.route('/download', methods=['POST'])
def modify_and_download_model():
modify_info = request.get_json()
Expand Down
78 changes: 77 additions & 1 deletion onnx_modifier/onnx_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# https://github.com/saurabh-shandilya/onnx-utils
# https://stackoverflow.com/questions/52402448/how-to-read-individual-layers-weight-bias-values-from-onnx-model

from io import BytesIO
import os
import copy
import struct
Expand Down Expand Up @@ -39,13 +40,87 @@ def from_name_json_stream(cls, name, stream):
return cls(name, model_proto)

@classmethod
def from_name_protobuf_stream(cls, name, stream, prefix=""):
    """Build a modifier from a protobuf model stream.

    Args:
        name: display name for the loaded model.
        stream: binary stream containing the serialized ModelProto.
        prefix: optional string prepended to every name in the graph
            (used to keep names unique when merging models). The default
            "" keeps the original behavior/signature backward compatible.

    Returns:
        A new instance wrapping the loaded model.
    """
    # https://leimao.github.io/blog/ONNX-IO-Stream/
    logging.info("loading model...")
    stream.seek(0)
    model_proto = onnx.load_model(stream, "protobuf", load_external_data=False)
    if prefix:
        # Only rewrite the graph when a prefix was actually requested;
        # add_prefix rebuilds every name, which is wasted work for "".
        model_proto = onnx.compose.add_prefix(model_proto, prefix=prefix)
    logging.info("load done!")
    return cls(name, model_proto)

@classmethod
def merge(cls, name1, stream1, name2, stream2, prefix1, prefix2):
    """Merge two ONNX models side by side into a single graph.

    Both models are loaded from their streams, their names are prefixed
    (to avoid collisions), and the second model's inputs/nodes/initializers/
    outputs are appended to the first model's graph. Each model's nodes stay
    in their original (topological) order, so the result remains sorted.

    Args:
        name1, name2: file names of the two models (used for the merged name).
        stream1, stream2: binary streams with the serialized models.
        prefix1, prefix2: name prefixes applied to model 1 / model 2.

    Returns:
        Tuple of (new modifier instance, BytesIO with the serialized merged
        model positioned at offset 0, merged file name).
    """
    def _load_prefixed(stream, prefix):
        # Helper: load a ModelProto from a stream and prefix all its names.
        stream.seek(0)
        proto = onnx.load_model(stream, "protobuf", load_external_data=False)
        return onnx.compose.add_prefix(proto, prefix=prefix)

    model_proto1 = _load_prefixed(stream1, prefix1)
    model_proto2 = _load_prefixed(stream2, prefix2)

    # Concatenate the graphs: model 2's pieces are appended after model 1's.
    model_proto1.graph.input.extend(model_proto2.graph.input)
    model_proto1.graph.node.extend(model_proto2.graph.node)
    model_proto1.graph.initializer.extend(model_proto2.graph.initializer)
    model_proto1.graph.output.extend(model_proto2.graph.output)

    # Fix: carry over opset imports from model 2. Without this, nodes from
    # model 2 that use an operator domain/version model 1 never declared
    # would make the merged model invalid.
    known_domains = {entry.domain for entry in model_proto1.opset_import}
    for entry in model_proto2.opset_import:
        if entry.domain not in known_domains:
            model_proto1.opset_import.append(entry)

    merged_name = name1.split('.')[0] + "_" + name2.split('.')[0] + ".onnx"
    byte_stream = BytesIO()
    onnx.save_model(model_proto1, byte_stream)
    byte_stream.seek(0)
    return cls(merged_name, model_proto1), byte_stream, merged_name


def find_next_node_by_input(self, model, input_name):
    """Return the first node in the graph that consumes `input_name`.

    Scans `model.graph.node` in order and returns the first node whose
    input list contains `input_name`, or None when no node consumes it.
    """
    return next(
        (candidate for candidate in model.graph.node
         if input_name in candidate.input),
        None,
    )

def find_previous_node_by_output(self, model, output_name):
    """Return the LAST node in the graph that produces `output_name`.

    Matches the original semantics: when several nodes list `output_name`
    among their outputs, the one appearing latest in `model.graph.node`
    wins. Returns None when no node produces it.
    """
    for candidate in reversed(list(model.graph.node)):
        if output_name in candidate.output:
            return candidate
    return None

def append(self, name, stream, prefix, method = 1):
    """Combine a newly uploaded model with the currently loaded one.

    Args:
        name: file name of the new model (used for the combined name).
        stream: binary stream with the new model's serialized protobuf.
        prefix: name prefix applied to every name in the new model so it
            cannot collide with names already in the current model.
        method: 1 = chain the new model onto the tail of the current one
            (rewire the current model's last output into the new model's
            first input); any other value = place the models in parallel
            (plain graph concatenation below).

    Returns:
        Tuple of (new onnxModifier wrapping the combined model, BytesIO
        with the serialized combined model at offset 0, combined file name).
    """
    stream.seek(0)
    model_proto2 = onnx.load_model(stream, "protobuf", load_external_data=False)
    # Prefix all names of the incoming model to avoid clashes with model 1.
    model_proto2 = onnx.compose.add_prefix(model_proto2, prefix=prefix)

    model_proto1 = self.model_proto
    if method == 1:
        # Tail-append mode: connect model 1's last output to model 2's
        # first input. NOTE(review): this only wires graph.output[-1] to
        # graph.input[0]; multi-output/multi-input models keep their other
        # boundary tensors as-is — confirm that is the intended behavior.
        model1_last_node = self.find_previous_node_by_output(model_proto1, model_proto1.graph.output[-1].name)
        model2_first_node = self.find_next_node_by_input(model_proto2, model_proto2.graph.input[0].name)

        model1_last_node_output_name = model1_last_node.output[0]
        model2_first_node_input_name = model2_first_node.input[0]

        # Rewire: model 2's first node now consumes model 1's last output
        # instead of model 2's own graph input (remove then re-insert at
        # position 0 to keep the input order intact).
        model2_first_node.input.remove(model2_first_node_input_name)
        model2_first_node.input.insert(0, model1_last_node_output_name)

        # The tensor is now internal, so drop it from model 1's outputs.
        del model_proto1.graph.output[-1]

        # Drop model 2's first graph input only if no remaining node still
        # consumes it (another node may also have read the same input).
        if not self.find_next_node_by_input(model_proto2, model_proto2.graph.input[0].name):
            del model_proto2.graph.input[0]

    # Concatenate the (possibly rewired) graphs; runs for both modes.
    model_proto1.graph.input.extend(model_proto2.graph.input)
    model_proto1.graph.node.extend(model_proto2.graph.node)
    model_proto1.graph.initializer.extend(model_proto2.graph.initializer)
    model_proto1.graph.output.extend(model_proto2.graph.output)

    merged_name = self.model_name.split('.')[0] + "_" + name.split('.')[0] + ".onnx"

    onnx_mdf = onnxModifier(merged_name, model_proto1)
    byte_stream = BytesIO()
    onnx.save_model(onnx_mdf.model_proto, byte_stream)
    byte_stream.seek(0)

    return onnx_mdf, byte_stream, merged_name

def reload(self):
self.model_proto = copy.deepcopy(self.model_proto_backup)
Expand Down Expand Up @@ -460,6 +535,7 @@ def modify(self, modify_info):
logging.debug("=== modify_info ===\n", modify_info)

self.add_nodes(modify_info['added_node_info'], modify_info['node_states'])
self.change_initializer(modify_info['added_tensor'])
self.change_initializer(modify_info['changed_initializer'])
self.change_node_io_name(modify_info['node_renamed_io'])
self.edit_inputs(modify_info['added_inputs'], modify_info['rebatch_info'])
Expand Down
150 changes: 100 additions & 50 deletions onnx_modifier/static/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,89 @@ host.BrowserHost = class {
});
}

// open the the model at flask server and reload
// open mode 0: open new model, 1: add to tail, 2: parallel
remote_open_model(obj_files, append_model = 0) {
var files = Array.from(obj_files);
const acceptd_file = files.filter(file => this._view.accept(file.name)).slice(0, 2);
// console.log(file)
if(acceptd_file.length == 1 && append_model == 0)
{
var file = acceptd_file[0];
this.upload_filename = file.name;
var form = new FormData();
form.append('file', file);

// https://stackoverflow.com/questions/66039996/javascript-fetch-upload-files-to-python-flask-restful
fetch('/open_model', {
method: 'POST',
body: form
}).then(function (response) {
return response.text();
}).then(function (text) {
console.log('POST response: ');
// Should be 'OK' if everything was successful
console.log(text);
});


if (file) {
this._open(file, files);
this._view.modifier.clearGraph();
}
} else if (acceptd_file.length == 2 || append_model != 0) {
var form = new FormData();
var url;
if (append_model == 0) {
url = '/merge_model';
for(var i = 0; i < acceptd_file.length; i++)
{
form.append('file' + i, acceptd_file[i]);
}
}
else if (append_model != 0) {
url = '/append_model';
form.append('file', acceptd_file[0]);
}
form.append('method', Number(append_model));
// console.log(file)
// this.upload_filename = file.name;
let filename = 'unknown';
// https://stackoverflow.com/questions/66039996/javascript-fetch-upload-files-to-python-flask-restful
this._view.showLoading();
fetch(url, {
method: 'POST',
body: form
}).then((response) =>{

const contentDisposition = response.headers.get('content-disposition');
if (contentDisposition) {
const matches = contentDisposition.match(/filename[^;=\n]*=((['"]).*?\2|[^;\n]*)/i);
if (matches && matches[1]) {
filename = decodeURIComponent(matches[1].replace(/['"]/g, ''));
}
}
if (!response.ok) {
this._view.hideLoading();
throw new Error(`HTTP error! status: ${response.status}`);
}
var blob = response.blob();
return blob;
}).then(blob => {
var file = new File([blob], filename);
// console.log('POST response: ');
// // Should be 'OK' if everything was successful
// console.log(text);
if (file) {
this.upload_filename = file.name;
files = [];
files.push(file);
this._open(file, files);
this._view.modifier.clearGraph();
}
});
}
}

start() {
this.window.addEventListener('error', (e) => {
Expand Down Expand Up @@ -328,40 +410,26 @@ host.BrowserHost = class {
});
openFileDialog.addEventListener('change', (e) => {
if (e.target && e.target.files && e.target.files.length > 0) {
const files = Array.from(e.target.files);
const file = files.find((file) => this._view.accept(file.name));
// console.log(file)
this.upload_filename = file.name;
var form = new FormData();
form.append('file', file);

// https://stackoverflow.com/questions/66039996/javascript-fetch-upload-files-to-python-flask-restful
fetch('/open_model', {
method: 'POST',
body: form
}).then(function (response) {
return response.text();
}).then(function (text) {
console.log('POST response: ');
// Should be 'OK' if everything was successful
console.log(text);
});


if (file) {
this._open(file, files);
this._view.modifier.clearGraph();
}
this.remote_open_model(e.target.files, this.loadModelMode);
}
});

}
const openModelButton = this.document.getElementById('load-model');
if (openModelButton && openFileDialog) {
openModelButton.addEventListener('click', () => {
openFileDialog.value = '';
openFileDialog.click();

var loadModelDropDown = this.document.getElementById('load-model-dropdown');
if (loadModelDropDown) {
loadModelDropDown.addEventListener('change', (e) => {
this.loadModelMode = loadModelDropDown.selectedIndex;
});
}

// const openModelButton = this.document.getElementById('load-model');
// if (openModelButton && openFileDialog) {
// openModelButton.addEventListener('click', () => {
// openFileDialog.value = '';
// openFileDialog.click();
// });
// }
const githubButton = this.document.getElementById('github-button');
const githubLink = this.document.getElementById('logo-github');
if (githubButton && githubLink) {
Expand All @@ -379,27 +447,8 @@ host.BrowserHost = class {
this.document.body.addEventListener('drop', (e) => {
e.preventDefault();
if (e.dataTransfer && e.dataTransfer.files && e.dataTransfer.files.length > 0) {
const files = Array.from(e.dataTransfer.files);
const file = files.find((file) => this._view.accept(file.name));
this.upload_filename = file.name;
var form = new FormData();
form.append('file', file);

// https://stackoverflow.com/questions/66039996/javascript-fetch-upload-files-to-python-flask-restful
fetch('/open_model', {
method: 'POST',
body: form
}).then(function (response) {
return response.text();
}).then(function (text) {
console.log('POST response: ');
// Should be 'OK' if everything was successful
console.log(text);
});
if (file) {
this._open(file, files);
this._view.modifier.clearGraph();
}
this.remote_open_model(e.dataTransfer.files, this.loadModelMode);

}
});

Expand Down Expand Up @@ -588,6 +637,7 @@ host.BrowserHost = class {
// 'modified_inputs_info' : this.arrayToObject(this.process_modified_inputs(this._view.modifier.inputModificationForSave,
// this._view.modifier.renameMap, this._view.modifier.name2NodeStates)),
'rebatch_info' : this.mapToObjectRec(this._view.modifier.reBatchInfo),
'added_tensor' : this.mapToObjectRec(this._view.modifier.addedTensor),
'changed_initializer' : this.mapToObjectRec(this._view.modifier.initializerEditInfo),
'postprocess_args' : {'shapeInf' : this._view.modifier.downloadWithShapeInf, 'cleanUp' : this._view.modifier.downloadWithCleanUp}
})
Expand Down
Loading