Skip to content

Commit 6da008b

Browse files
committed
load and run onnx
1 parent 6ffe9b3 commit 6da008b

File tree

6 files changed

+32
-4
lines changed

6 files changed

+32
-4
lines changed

src/main/java-ml/io/mapsmessaging/selector/operators/functions/ml/MLFunction.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,12 @@ public static MLFunction parse(String functionName, List<String> list) {
8484
List<String> identifiers = new ArrayList<>();
8585
int startIdx;
8686

87-
if ("tensorFlow".equalsIgnoreCase(functionName) || list.size() <=1) {
87+
boolean isMl = ("tensorFlow".equalsIgnoreCase(functionName) ||
88+
"onnx".equalsIgnoreCase(functionName) ||
89+
list.size() <=1);
90+
91+
92+
if (isMl) {
8893
modelName = list.getFirst();
8994
startIdx = 1;
9095
} else {

src/main/java-ml/io/mapsmessaging/selector/operators/functions/ml/impl/functions/onnx/OnnxModelRegistry.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,7 @@ public static OnnxModelEntry getOrLoad(final String modelName,
5555
return cached;
5656
}
5757

58-
byte[] modelBytes = java.util.Objects.requireNonNull(
59-
modelStore.loadModel(modelName), "Model bytes are null for " + modelName);
58+
byte[] modelBytes = modelStore.loadModel(modelName);
6059

6160
// Try engines in the precomputed order until a session is created.
6261
OrtSession ortSession = null;

src/main/java-ml/io/mapsmessaging/selector/operators/functions/ml/impl/functions/onnx/OnnxOperation.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,4 +172,9 @@ private void warmUp() throws OrtException {
172172
}
173173
}
174174
}
175+
176+
@Override
177+
public String toString() {
178+
return "onnx ("+ super.toString() + ")";
179+
}
175180
}

src/main/java-ml/io/mapsmessaging/selector/operators/functions/ml/impl/functions/onnx/OnnxRuntimeGate.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ public class OnnxRuntimeGate {
3232
OrtEnvironment.getEnvironment().close(); // no-op safe
3333
ok = true;
3434
} catch (Throwable t) {
35+
t.printStackTrace();
3536
ok = false; // UnsatisfiedLinkError, missing deps, etc.
3637
}
3738
AVAILABLE = ok && !Boolean.getBoolean("maps.ml.onnx.disabled")

src/test/java-ml/io/mapsmessaging/selector/SelectorMLConformanceTest.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,14 @@
2020

2121
package io.mapsmessaging.selector;
2222

23+
import io.mapsmessaging.selector.ml.impl.store.FileModelStore;
24+
import io.mapsmessaging.selector.model.ModelStore;
2325
import io.mapsmessaging.selector.operators.ParserExecutor;
26+
import io.mapsmessaging.selector.operators.functions.ml.MLFunction;
2427
import io.mapsmessaging.selector.operators.functions.ml.impl.functions.onnx.OnnxRuntimeGate;
28+
import org.junit.jupiter.api.AfterAll;
2529
import org.junit.jupiter.api.Assertions;
30+
import org.junit.jupiter.api.BeforeAll;
2631
import org.junit.jupiter.params.ParameterizedTest;
2732
import org.junit.jupiter.params.provider.MethodSource;
2833

@@ -77,10 +82,11 @@ public static Stream<String> selectors() {
7782
"pca_cor (explainedvariance[3], model_pca_cor.arff) > 0.7",
7883
"pca_cor (explainedvariance[4], model_pca_cor.arff, temp, humidity) > 0.7",
7984
"tensorflow (sensor_safety_model, temp, humidity, co2) < 10",
80-
"onnx (sensor_safety_model, temp, humidity, co2) < 10"
85+
"onnx (model.onnx, temp, humidity, co2) < 10"
8186

8287
/*
8388
89+
8490
"svm (classify, model_svm.arff) = 1",
8591
"svm (classify, model_svm.arff, temp, humidity) = 1",
8692
@@ -92,6 +98,18 @@ public static Stream<String> selectors() {
9298

9399
}
94100

101+
private static ModelStore init;
102+
@BeforeAll
103+
static void setup(){
104+
init = MLFunction.getModelStore();
105+
MLFunction.setModelStore(new FileModelStore("./src/test/resources/"));
106+
}
107+
108+
@AfterAll
109+
static void tearDown(){
110+
MLFunction.setModelStore(init);
111+
}
112+
95113
@ParameterizedTest(name = "Syntax test for: {0}")
96114
@MethodSource("selectors")
97115
void syntaxTest(String selector) {

src/test/resources/model.onnx

1.9 KB
Binary file not shown.

0 commit comments

Comments
 (0)