-
Notifications
You must be signed in to change notification settings - Fork 361
/
Copy path: App.tsx
81 lines (73 loc) · 2.57 KB
/
App.tsx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import { StatusBar } from 'expo-status-bar';
import { Alert, Button, StyleSheet, Text, View } from 'react-native';
import * as ort from 'onnxruntime-react-native';
import { Asset } from 'expo-asset';
// Note: These modules are used for reading model into bytes
// import RNFS from 'react-native-fs';
// import base64 from 'base64-js';
// Module-level ONNX session handle. It is undefined until loadModel() resolves
// and assigns it; runModel() reads it, so "Load model" must be pressed first.
let myModel: ort.InferenceSession;
/**
 * Loads the bundled MNIST ONNX model via expo-asset and creates an
 * InferenceSession stored in the module-level `myModel`.
 * Shows an Alert on success (with input/output names) or on failure,
 * and rethrows any caught error so callers/devtools still see it.
 */
async function loadModel() {
  try {
    // Note: `.onnx` model files can be viewed in Netron (https://github.com/lutzroeder/netron) to see
    // model inputs/outputs detail and data types, shapes of those, etc.
    const assets = await Asset.loadAsync(require('./assets/mnist.onnx'));
    const modelUri = assets[0].localUri;
    if (!modelUri) {
      // Fix: `${assets[0]}` would render as "[object Object]" — an Asset has no
      // useful toString(). Report its identifying fields instead.
      Alert.alert('failed to get model URI', `asset: ${assets[0].name}, uri: ${assets[0].uri}`);
    } else {
      // load model from model url path
      myModel = await ort.InferenceSession.create(modelUri);
      Alert.alert(
        'model loaded successfully',
        `input names: ${myModel.inputNames}, output names: ${myModel.outputNames}`);
      // loading model from bytes
      // const base64Str = await RNFS.readFile(modelUri, 'base64');
      // const uint8Array = base64.toByteArray(base64Str);
      // myModel = await ort.InferenceSession.create(uint8Array);
    }
  } catch (e) {
    Alert.alert('failed to load model', `${e}`);
    throw e;
  }
}
async function runModel() {
try {
// Prepare model input data
// Note: In real use case, you must set the inputData to the actual input values
const inputData = new Float32Array(28 * 28);
const feeds:Record<string, ort.Tensor> = {};
feeds[myModel.inputNames[0]] = new ort.Tensor(inputData, [1, 1, 28, 28]);
// Run inference session
const fetches = await myModel.run(feeds);
// Process output
const output = fetches[myModel.outputNames[0]];
if (!output) {
Alert.alert('failed to get output', `${myModel.outputNames[0]}`);
} else {
Alert.alert(
'model inference successfully',
`output shape: ${output.dims}, output data: ${output.data}`);
}
} catch (e) {
Alert.alert('failed to inference model', `${e}`);
throw e;
}
}
/**
 * Root component of the example app: a centered column containing a title,
 * one button to load the ONNX model and one to run inference, plus the
 * Expo status bar.
 */
export default function App() {
  return (
    <View style={styles.container}>
      <Text>ONNX Runtime React Native Basic Usage</Text>
      <Button title='Load model' onPress={loadModel} />
      <Button title='Run' onPress={runModel} />
      <StatusBar style="auto" />
    </View>
  );
}
// Styles for App: a full-screen white container that centers its children
// both horizontally (alignItems) and vertically (justifyContent).
const styles = StyleSheet.create({
  container: {
    flex: 1,
    backgroundColor: '#fff',
    alignItems: 'center',
    justifyContent: 'center',
  },
});