Skip to content

Commit ad3914b

Browse files
committed
cat set env ONNX_PROVIDER , can print all/can use onnx providers
1 parent 99a8a4d commit ad3914b

File tree

3 files changed

+48
-12
lines changed

3 files changed

+48
-12
lines changed

onnx/README.md

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ ONNX(Open Neural Network Exchange),开放神经网络交换,用于在各种
5454

5555
### 目录
5656

57-
* model/ 存放下载的模型
58-
* onnx/ 存放导出的 onnx,下载的 onnx 也请解压到这里
57+
* `model/` 存放下载的模型
58+
* `onnx/` 存放导出的 onnx,下载的 onnx 也请解压到这里
5959

6060
### 测试
6161

@@ -80,9 +80,36 @@ onnxruntime 有很多版本可以选择,见[onnxruntime](https://onnxruntime.a
8080

8181
* [./test/onnx/onnx_img.py](./test/onnx/onnx_img.py) 生成图片向量 (norm 代表归一化的向量,可用于向量搜索)
8282
* [./test/onnx/onnx_txt.py](./test/onnx/onnx_txt.py) 生成文本向量
83-
* [./test/onnx/onnx_test.py](./test/onnx/onnx_test.py) 匹配图片向量和文本向量,进行零样本分类
83+
* [./test/onnx/onnx_test.py](./test/onnx/onnx_test.py)
8484

85-
可借助向量数据库,提升零样本分类的准确性,参见[ECCV 2022 | 无需下游训练,Tip-Adapter 大幅提升 CLIP 图像分类准确率](https://cloud.tencent.com/developer/article/2126102)
85+
匹配图片向量和文本向量,进行零样本分类
86+
87+
可借助向量数据库,提升零样本分类的准确性,参见[ECCV 2022 | 无需下游训练,Tip-Adapter 大幅提升 CLIP 图像分类准确率](https://cloud.tencent.com/developer/article/2126102)
88+
* [./test/onnx/onnx_load.py](./test/onnx/onnx_load.py)
89+
90+
onnx 模型的加载代码,运行它可以看到当前机器可用的 onnx provider。
91+
92+
比如苹果 M2 芯片的笔记本上运行如下:
93+
94+
```
95+
❯ ./onnx_load.py 2>/dev/null
96+
all providers :
97+
['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'MIGraphXExecutionProvider', 'ROCMExecutionProvider', 'OpenVINOExecutionProvider', 'DnnlExecutionProvider', 'TvmExecutionProvider', 'VitisAIExecutionProvider', 'NnapiExecutionProvider', 'CoreMLExecutionProvider', 'ArmNNExecutionProvider', 'ACLExecutionProvider', 'DmlExecutionProvider', 'RknpuExecutionProvider', 'XnnpackExecutionProvider', 'CANNExecutionProvider', 'CPUExecutionProvider']
98+
99+
now can use providers :
100+
['CoreMLExecutionProvider', 'CPUExecutionProvider']
101+
```
102+
103+
可以创建 FlagAI/onnx/.env ,设置环境变量 `ONNX_PROVIDER`,配置当前环境的 Onnx Execution Provider,方便测试对比性能。
104+
105+
设置的示例如下:
106+
107+
```
108+
❯ cat FlagAI/onnx/.env
109+
ONNX_PROVIDER=CoreMLExecutionProvider
110+
```
111+
112+
设置成功后,需要在 `FlagAI/onnx` 目录下运行 `direnv allow` 或者手工 `source .envrc` 让其在当前命令行中生效。
86113

87114
#### pytorch 模型
88115

onnx/test/onnx/onnx_img.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@ def img2vec(img):
1616
from os.path import join
1717
img = Image.open(join(IMG_DIR, 'cat.jpg'))
1818

19+
img_data = transform(img)
20+
import torch
21+
print('img data size', torch.tensor(img_data).size())
22+
1923
vec = img2vec(img)
2024
print('vec', vec)
2125
IMG_NORM = onnx_load('ImgNorm')
22-
print('norm', IMG_NORM.run(None, {'input': transform(img)})[0])
26+
print('norm', IMG_NORM.run(None, {'input': img_data})[0])

onnx/test/onnx/onnx_load.py

100644100755
Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import onnxruntime
44
from misc.config import ONNX_FP
55
from os.path import join
6+
import os
67

78
session = onnxruntime.SessionOptions()
89
option = onnxruntime.RunOptions()
@@ -11,13 +12,17 @@
1112

1213
def onnx_load(kind):
1314
fp = join(ONNX_FP, f'{kind}.onnx')
14-
15+
provider = os.getenv('ONNX_PROVIDER')
16+
providers = [provider] if provider else None
1517
sess = onnxruntime.InferenceSession(fp,
1618
sess_options=session,
17-
providers=[
18-
'TensorrtExecutionProvider',
19-
'CUDAExecutionProvider',
20-
'CoreMLExecutionProvider',
21-
'CPUExecutionProvider'
22-
])
19+
providers=providers)
2320
return sess
21+
22+
23+
if __name__ == '__main__':
24+
from onnxruntime import get_all_providers
25+
print('all providers :\n%s\n' % get_all_providers())
26+
sess = onnx_load('Txt')
27+
providers = sess.get_providers()
28+
print('now can use providers :\n%s\n' % providers)

0 commit comments

Comments
 (0)