Skip to content

Commit 18ed1bb

Browse files
[add] merge two onnx and related
1 parent 067010f commit 18ed1bb

11 files changed

+622
-111
lines changed

README_zh-CN.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,19 @@
182182

183183
`onnx-modifer`正在活跃地更新中:hammer_and_wrench:。 欢迎使用,提issue,如果有帮助的话,感谢给个:star:~
184184

185+
## 合并两个onnx模型
186+
187+
将两个模型文件拖动到页面,或在模型选择页面选择两个模型。等待加载...完成。
188+
修改完成后点击`download`下载。
189+
190+
> **已知问题:**
191+
> 1. 只能复制已知节点,未知节点需要先修改`metadata.json`,即先变为已知节点。
192+
> 2. 调试模式下按下Alt,进入断点并松开Alt,退出调试后,Alt会被认为一直是按下状态。需要重新按一次解除。
193+
194+
## 复制或删除多个节点
195+
196+
按住`Alt`键,点击第一个节点,`Alt`不松手点击第二个节点,完成选择。按下`j`键并松开,实现复制。按下`l`键并松开,实现删除。松开`Alt`取消选中状态。
197+
185198
# 示例模型文件
186199

187200
为方便测试,以下提供一些典型的样例模型文件,主要来自于[onnx model zoo](https://github.com/onnx/models)

app.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import argparse
22
import logging
3-
from flask import Flask, render_template, request
3+
import time
4+
from flask import Flask, render_template, request, send_file
45
from onnx_modifier import onnxModifier
56
logging.basicConfig(level=logging.INFO)
67

@@ -24,6 +25,23 @@ def open_model():
2425

2526
return 'OK', 200
2627

28+
@app.route('/merge_model', methods=['POST'])
29+
def merge_model():
30+
31+
onnx_file1 = request.files['file0']
32+
onnx_file2 = request.files['file1']
33+
timestamp = time.time()
34+
global onnx_modifier
35+
onnx_modifier, stream ,merged_name = onnxModifier.merge(
36+
onnx_file1.filename, onnx_file1.stream,
37+
onnx_file2.filename, onnx_file2.stream,
38+
"", str(int(timestamp)) + "_")
39+
40+
return send_file(stream,
41+
mimetype='application/octet-stream',
42+
as_attachment=True,
43+
download_name=merged_name)
44+
2745

2846
@app.route('/download', methods=['POST'])
2947
def modify_and_download_model():

onnx_modifier.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# https://github.com/saurabh-shandilya/onnx-utils
44
# https://stackoverflow.com/questions/52402448/how-to-read-individual-layers-weight-bias-values-from-onnx-model
55

6+
from io import BytesIO
67
import os
78
import copy
89
import struct
@@ -39,13 +40,36 @@ def from_name_json_stream(cls, name, stream):
3940
return cls(name, model_proto)
4041

4142
@classmethod
42-
def from_name_protobuf_stream(cls, name, stream):
43+
def from_name_protobuf_stream(cls, name, stream, prefix = ""):
4344
# https://leimao.github.io/blog/ONNX-IO-Stream/
4445
logging.info("loading model...")
4546
stream.seek(0)
4647
model_proto = onnx.load_model(stream, "protobuf", load_external_data=False)
48+
model_proto = onnx.compose.add_prefix(model_proto, prefix=prefix)
4749
logging.info("load done!")
4850
return cls(name, model_proto)
51+
52+
@classmethod
53+
def merge(cls, name1, stream1, name2, stream2, prefix1, prefix2):
54+
stream1.seek(0)
55+
model_proto1 = onnx.load_model(stream1, "protobuf", load_external_data=False)
56+
model_proto1 = onnx.compose.add_prefix(model_proto1, prefix=prefix1)
57+
58+
stream2.seek(0)
59+
model_proto2 = onnx.load_model(stream2, "protobuf", load_external_data=False)
60+
model_proto2 = onnx.compose.add_prefix(model_proto2, prefix=prefix2)
61+
62+
model_proto1.graph.input.extend(model_proto2.graph.input)
63+
model_proto1.graph.node.extend(model_proto2.graph.node)
64+
model_proto1.graph.initializer.extend(model_proto2.graph.initializer)
65+
model_proto1.graph.output.extend(model_proto2.graph.output)
66+
67+
merged_name = name1.split('.')[0] + "_" + name2.split('.')[0] + ".onnx"
68+
byte_stream = BytesIO()
69+
onnx.save_model(model_proto1, byte_stream)
70+
byte_stream.seek(0)
71+
return cls(merged_name, model_proto1), byte_stream, merged_name
72+
4973

5074
def reload(self):
5175
self.model_proto = copy.deepcopy(self.model_proto_backup)
@@ -460,6 +484,7 @@ def modify(self, modify_info):
460484
logging.debug("=== modify_info ===\n", modify_info)
461485

462486
self.add_nodes(modify_info['added_node_info'], modify_info['node_states'])
487+
self.change_initializer(modify_info['added_tensor'])
463488
self.change_initializer(modify_info['changed_initializer'])
464489
self.change_node_io_name(modify_info['node_renamed_io'])
465490
self.edit_inputs(modify_info['added_inputs'], modify_info['rebatch_info'])

static/index.js

Lines changed: 80 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,80 @@ host.BrowserHost = class {
116116
});
117117
}
118118

119+
// open the the model at flask server and reload
120+
remote_open_model(obj_files) {
121+
var files = Array.from(obj_files);
122+
const acceptd_file = files.filter(file => this._view.accept(file.name)).slice(0, 2);
123+
// console.log(file)
124+
if(acceptd_file.length == 1)
125+
{
126+
var file = acceptd_file[0];
127+
this.upload_filename = file.name;
128+
var form = new FormData();
129+
form.append('file', file);
130+
131+
// https://stackoverflow.com/questions/66039996/javascript-fetch-upload-files-to-python-flask-restful
132+
fetch('/open_model', {
133+
method: 'POST',
134+
body: form
135+
}).then(function (response) {
136+
return response.text();
137+
}).then(function (text) {
138+
console.log('POST response: ');
139+
// Should be 'OK' if everything was successful
140+
console.log(text);
141+
});
119142

120143

144+
if (file) {
145+
this._open(file, files);
146+
this._view.modifier.clearGraph();
147+
}
148+
} else if (acceptd_file.length == 2)
149+
{
150+
var form = new FormData();
151+
for(var i = 0; i < acceptd_file.length; i++)
152+
{
153+
form.append('file'+i, acceptd_file[i]);
154+
}
155+
// console.log(file)
156+
// this.upload_filename = file.name;
157+
let filename = 'unknown';
158+
// https://stackoverflow.com/questions/66039996/javascript-fetch-upload-files-to-python-flask-restful
159+
this._view.showLoading();
160+
fetch('/merge_model', {
161+
method: 'POST',
162+
body: form
163+
}).then(function (response) {
164+
165+
const contentDisposition = response.headers.get('content-disposition');
166+
if (contentDisposition) {
167+
const matches = contentDisposition.match(/filename[^;=\n]*=((['"]).*?\2|[^;\n]*)/i);
168+
if (matches && matches[1]) {
169+
filename = decodeURIComponent(matches[1].replace(/['"]/g, ''));
170+
}
171+
}
172+
if (!response.ok) {
173+
throw new Error(`HTTP error! status: ${response.status}`);
174+
}
175+
var blob = response.blob();
176+
return blob;
177+
}).then(blob => {
178+
var file = new File([blob], filename);
179+
// console.log('POST response: ');
180+
// // Should be 'OK' if everything was successful
181+
// console.log(text);
182+
if (file) {
183+
this.upload_filename = file.name;
184+
files = [];
185+
files.push(file);
186+
this._open(file, files);
187+
this._view.modifier.clearGraph();
188+
}
189+
});
190+
}
191+
}
192+
121193
start() {
122194
this.window.addEventListener('error', (e) => {
123195
this.exception(e.error, true);
@@ -328,33 +400,14 @@ host.BrowserHost = class {
328400
});
329401
openFileDialog.addEventListener('change', (e) => {
330402
if (e.target && e.target.files && e.target.files.length > 0) {
331-
const files = Array.from(e.target.files);
332-
const file = files.find((file) => this._view.accept(file.name));
333-
// console.log(file)
334-
this.upload_filename = file.name;
335-
var form = new FormData();
336-
form.append('file', file);
337-
338-
// https://stackoverflow.com/questions/66039996/javascript-fetch-upload-files-to-python-flask-restful
339-
fetch('/open_model', {
340-
method: 'POST',
341-
body: form
342-
}).then(function (response) {
343-
return response.text();
344-
}).then(function (text) {
345-
console.log('POST response: ');
346-
// Should be 'OK' if everything was successful
347-
console.log(text);
348-
});
349-
350-
351-
if (file) {
352-
this._open(file, files);
353-
this._view.modifier.clearGraph();
354-
}
403+
this.remote_open_model(e.target.files);
355404
}
356405
});
406+
357407
}
408+
409+
410+
358411
const openModelButton = this.document.getElementById('load-model');
359412
if (openModelButton && openFileDialog) {
360413
openModelButton.addEventListener('click', () => {
@@ -379,27 +432,8 @@ host.BrowserHost = class {
379432
this.document.body.addEventListener('drop', (e) => {
380433
e.preventDefault();
381434
if (e.dataTransfer && e.dataTransfer.files && e.dataTransfer.files.length > 0) {
382-
const files = Array.from(e.dataTransfer.files);
383-
const file = files.find((file) => this._view.accept(file.name));
384-
this.upload_filename = file.name;
385-
var form = new FormData();
386-
form.append('file', file);
387-
388-
// https://stackoverflow.com/questions/66039996/javascript-fetch-upload-files-to-python-flask-restful
389-
fetch('/open_model', {
390-
method: 'POST',
391-
body: form
392-
}).then(function (response) {
393-
return response.text();
394-
}).then(function (text) {
395-
console.log('POST response: ');
396-
// Should be 'OK' if everything was successful
397-
console.log(text);
398-
});
399-
if (file) {
400-
this._open(file, files);
401-
this._view.modifier.clearGraph();
402-
}
435+
this.remote_open_model(e.dataTransfer.files);
436+
403437
}
404438
});
405439

@@ -588,6 +622,7 @@ host.BrowserHost = class {
588622
// 'modified_inputs_info' : this.arrayToObject(this.process_modified_inputs(this._view.modifier.inputModificationForSave,
589623
// this._view.modifier.renameMap, this._view.modifier.name2NodeStates)),
590624
'rebatch_info' : this.mapToObjectRec(this._view.modifier.reBatchInfo),
625+
'added_tensor' : this.mapToObjectRec(this._view.modifier.addedTensor),
591626
'changed_initializer' : this.mapToObjectRec(this._view.modifier.initializerEditInfo),
592627
'postprocess_args' : {'shapeInf' : this._view.modifier.downloadWithShapeInf, 'cleanUp' : this._view.modifier.downloadWithCleanUp}
593628
})

static/modifier.js

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ modifier.Modifier = class {
2323
this.downloadWithShapeInf = false;
2424
this.downloadWithCleanUp = false;
2525

26+
this.addedTensor = new Map();
27+
2628
}
2729

2830
loadModelGraph(model, graphs) {
@@ -79,13 +81,13 @@ modifier.Modifier = class {
7981
}
8082

8183

82-
try_get_node_name(op_type)
84+
try_get_node_name(op_type, input_node_id)
8385
{
84-
var node_id = (this.addNodeKey++).toString(); // in case input (onnx) node has no name
86+
var node_id = (input_node_id || this.addNodeKey++).toString(); // in case input (onnx) node has no name
8587
var modelNodeName = 'custom_added_' + op_type + node_id;
8688

8789
if (this.addedNode.has(modelNodeName) || this.name2NodeStates.get(modelNodeName) ){
88-
modelNodeName = this.randomString(16, 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ');
90+
modelNodeName = try_get_node_name(op_type, Date.parse(new Date()));//this.randomString(16, 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ');
8991
}
9092
return modelNodeName;
9193
}
@@ -103,6 +105,69 @@ modifier.Modifier = class {
103105
this.applyAndUpdateView();
104106
}
105107

108+
// duplicate a node with ( _cp + unique_id ) as param suffix
109+
duplicateNode(node_name, unique_id = "") {
110+
//avoid to add a existed name node
111+
var srcModelNode = this.name2ModelNode.get(node_name);
112+
if (!srcModelNode.type){
113+
return;
114+
}
115+
var dstModelNodeName = this.try_get_node_name(srcModelNode.type.name);
116+
117+
var properties = new Map();
118+
properties.set('domain', "ai.onnx");
119+
properties.set('op_type', srcModelNode.type.name);
120+
properties.set('name', dstModelNodeName);
121+
122+
123+
var attributes = new Map();
124+
for (const attribute of srcModelNode.attributes) {
125+
attributes.set(attribute.name, [
126+
attribute.value?attribute.value.toString():"undefined",
127+
attribute.type||"undefined"])
128+
// attributes.set(key, modelNode.attributes.get(key));
129+
}
130+
131+
var outputs = new Map();
132+
for (const output of srcModelNode.outputs) {
133+
var dstNameList = [];
134+
for (const srcArg of output.arguments) {
135+
var dstName = srcArg.name + "_cp" + unique_id;
136+
dstNameList.push([dstName, false]);
137+
this.graph.copy_tensor(dstName, srcArg.name);
138+
139+
}
140+
outputs.set(output.name, dstNameList);
141+
142+
}
143+
144+
var inputs = new Map();
145+
for (const input of srcModelNode.inputs) {
146+
var dstNameList = [];
147+
for (const srcArg of input.arguments) {
148+
var dstName = srcArg.name + "_cp" + unique_id;
149+
150+
if (this.graph._context._tensors &&
151+
this.graph._context._tensors.has(srcArg.name)) {
152+
this.graph.copy_tensor(dstName, srcArg.name);
153+
var initializer_info = this.graph.get_initializer_info(dstName);
154+
if(initializer_info) {
155+
this.addedTensor.set(dstName, initializer_info);
156+
}
157+
158+
}
159+
dstNameList.push([dstName, false])
160+
}
161+
inputs.set(input.name, dstNameList)
162+
163+
}
164+
165+
this.addedNode.set(dstModelNodeName,
166+
new view.LightNodeInfo(properties, attributes, inputs, outputs));
167+
this.applyAndUpdateView();
168+
}
169+
170+
106171
addModelOutput(node_name) {
107172
var modelNode = this.name2ModelNode.get(node_name);
108173
// use a output argument as a proxy
@@ -402,6 +467,10 @@ modifier.Modifier = class {
402467
for (const [modelNodeName, node_info] of this.addedNode) {
403468
// console.log(node_info)
404469
var node = this.graph.make_custom_added_node(node_info);
470+
if (!node) {
471+
console.log("node not supported yet");
472+
continue;
473+
}
405474
// console.log(node)
406475

407476
for (const input of node.inputs) {
@@ -489,6 +558,7 @@ modifier.Modifier = class {
489558
this.graph.reset_custom_added_node();
490559
this.graph.reset_custom_modified_outputs();
491560
this.graph.reset_custom_modified_inputs();
561+
this.graph.reset_custom_added_tensors();
492562
}
493563
// reset load location
494564
var container = this.view._getElementById('graph');

0 commit comments

Comments
 (0)