From 99f612800517911a985cf7f97980707ba205c056 Mon Sep 17 00:00:00 2001
From: Ionut Mihalcea
Date: Wed, 27 Nov 2024 20:30:49 +0100
Subject: [PATCH 1/5] add ops registry

---
 candle-onnx/src/lib.rs              |   2 +
 candle-onnx/src/ops/compute_node.rs |  30 +++++++
 candle-onnx/src/ops/mod.rs          |   5 ++
 candle-onnx/src/ops/onnxop.rs       | 126 ++++++++++++++++++++++++++++
 4 files changed, 163 insertions(+)
 create mode 100644 candle-onnx/src/ops/compute_node.rs
 create mode 100644 candle-onnx/src/ops/mod.rs
 create mode 100644 candle-onnx/src/ops/onnxop.rs

diff --git a/candle-onnx/src/lib.rs b/candle-onnx/src/lib.rs
index efd6f7600f..da453801ad 100644
--- a/candle-onnx/src/lib.rs
+++ b/candle-onnx/src/lib.rs
@@ -6,6 +6,8 @@ pub mod onnx {
 }
 
 pub mod eval;
+mod ops;
+
 pub use eval::{dtype, simple_eval};
 
 pub fn read_file<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> {
diff --git a/candle-onnx/src/ops/compute_node.rs b/candle-onnx/src/ops/compute_node.rs
new file mode 100644
index 0000000000..2b9994ae77
--- /dev/null
+++ b/candle-onnx/src/ops/compute_node.rs
@@ -0,0 +1,30 @@
+use crate::onnx::NodeProto;
+use candle::Tensor;
+use std::collections::HashMap;
+
+// This struct is used to represent a node in the computation graph.
+// The idea is not to use the NodeProto directly in the computation graph.
+// In the longer term, this can lead to a more optimized representation of the computation graph.
+// For now, it is just a wrapper around the NodeProto and the context.
+pub struct ComputeNode<'a> {
+    node_proto: &'a NodeProto,
+    context: &'a HashMap<String, Tensor>,
+}
+
+impl<'a> ComputeNode<'a> {
+    pub fn new(node_proto: &'a NodeProto, context: &'a HashMap<String, Tensor>) -> Self {
+        ComputeNode {
+            node_proto,
+            context,
+        }
+    }
+
+    pub fn get_input(&self, index: usize) -> Option<&Tensor> {
+        let input_name = self.node_proto.input.get(index)?;
+        self.context.get(input_name)
+    }
+
+    pub fn get_output(&self, index: usize) -> Option<&String> {
+        self.node_proto.output.get(index)
+    }
+}
diff --git a/candle-onnx/src/ops/mod.rs b/candle-onnx/src/ops/mod.rs
new file mode 100644
index 0000000000..d5f5676d02
--- /dev/null
+++ b/candle-onnx/src/ops/mod.rs
@@ -0,0 +1,5 @@
+pub mod onnxop;
+pub use onnxop::{OnnxOp, OnnxOpError, OnnxOpRegistry, OpOutput};
+
+pub mod compute_node;
+pub use compute_node::ComputeNode;
diff --git a/candle-onnx/src/ops/onnxop.rs b/candle-onnx/src/ops/onnxop.rs
new file mode 100644
index 0000000000..8725186db0
--- /dev/null
+++ b/candle-onnx/src/ops/onnxop.rs
@@ -0,0 +1,126 @@
+use crate::ops::ComputeNode;
+use std::collections::hash_map::Entry;
+use std::collections::HashMap;
+use std::fmt::{Display, Formatter};
+
+pub type OpOutput = (String, candle::Tensor);
+
+#[derive(Debug, PartialEq, Eq)]
+pub enum OnnxOpError {
+    InvalidInput(String),
+    ComputationFailed(String),
+    UnsupportedOp(String),
+    DuplicateOp(String),
+}
+
+impl From<OnnxOpError> for candle::Error {
+    fn from(e: OnnxOpError) -> Self {
+        candle::Error::Msg(format!("{:?}", e))
+    }
+}
+
+impl Display for OnnxOpError {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            OnnxOpError::InvalidInput(s) => write!(f, "Invalid input: {}", s),
+            OnnxOpError::ComputationFailed(s) => write!(f, "Computation failed: {}", s),
+            OnnxOpError::UnsupportedOp(s) => write!(f, "Unsupported op: {}", s),
+            OnnxOpError::DuplicateOp(s) => write!(f, "Duplicate op: {}", s),
+        }
+    }
+}
+
+pub trait OnnxOp {
+    fn eval(&self, node: &ComputeNode) -> Result<OpOutput, OnnxOpError>;
+}
+
+#[derive(Default)]
+pub struct OnnxOpRegistry {
+    ops: HashMap<String, Box<dyn OnnxOp>>,
+}
+
+impl OnnxOpRegistry {
+    pub fn new() -> Self {
+        Self {
+            ops: HashMap::new(),
+        }
+    }
+    pub fn
insert(&mut self, name: &str, op: Box) -> Result<(), OnnxOpError> { + match self.ops.entry(name.to_string()) { + Entry::Vacant(vacant_entry) => { + vacant_entry.insert(op); + Ok(()) + } + Entry::Occupied(_) => Err(OnnxOpError::DuplicateOp(name.to_string())), + } + } + + pub fn get(&self, name: &str) -> Result<&dyn OnnxOp, OnnxOpError> { + match self.ops.get(name) { + Some(op) => Ok(op.as_ref()), + None => Err(OnnxOpError::UnsupportedOp(name.to_string())), + } + } +} + +#[cfg(test)] +mod onnxop_registry_tests { + use super::*; + use candle::Device; + #[test] + fn nominal_case() { + //Given + let dummy_op = Box::new(DummyOp); + let mut registry = OnnxOpRegistry::new(); + + //When + registry.insert("DummyOp", dummy_op).unwrap(); + let op = registry.get("DummyOp"); + + //Then + assert!(op.is_ok()); + } + + #[test] + fn unsupported_op() { + //Given + let registry = OnnxOpRegistry::new(); + + //When + let op = registry.get("Foo"); + + //Then + match op { + Err(OnnxOpError::UnsupportedOp(_)) => {} + _ => panic!("Expected unsupported op error"), + } + } + + #[test] + fn duplicate_op() { + //Given + let dummy_op = Box::new(DummyOp); + let mut registry = OnnxOpRegistry::new(); + registry.insert("DummyOp", dummy_op).unwrap(); + + //When + let dummy_op = Box::new(DummyOp); + let result = registry.insert("DummyOp", dummy_op); + + //Then + match result { + Err(OnnxOpError::DuplicateOp(_)) => {} + _ => panic!("Expected duplicate op error"), + } + } + + struct DummyOp; + impl OnnxOp for DummyOp { + fn eval(&self, _node: &ComputeNode) -> Result { + Ok(( + "dummy".to_string(), + candle::Tensor::new(vec![1u8, 1], &Device::Cpu).unwrap(), + )) + } + } +} From 4dbdb9d44ce5a50e1f3a736513f211cde0e2c872 Mon Sep 17 00:00:00 2001 From: Ionut Mihalcea Date: Wed, 27 Nov 2024 20:36:51 +0100 Subject: [PATCH 2/5] add sign operation and setup the registry --- candle-onnx/src/ops/math/mod.rs | 1 + candle-onnx/src/ops/math/sign.rs | 19 +++++++++++++++++++ candle-onnx/src/ops/mod.rs | 10 ++++++++++ 3 files changed, 30 insertions(+) create mode 100644 candle-onnx/src/ops/math/mod.rs create mode 100644 candle-onnx/src/ops/math/sign.rs diff --git a/candle-onnx/src/ops/math/mod.rs b/candle-onnx/src/ops/math/mod.rs new file mode 100644 index 0000000000..f8849e4278 --- /dev/null +++ b/candle-onnx/src/ops/math/mod.rs @@ -0,0 +1 @@ +pub(crate) mod sign; \ No newline at end of file diff --git a/candle-onnx/src/ops/math/sign.rs b/candle-onnx/src/ops/math/sign.rs new file mode 100644 index 0000000000..04c37ecfa7 --- /dev/null +++ b/candle-onnx/src/ops/math/sign.rs @@ -0,0 +1,19 @@ +use crate::ops::compute_node::ComputeNode; +use crate::ops::{OnnxOp, OnnxOpError, OpOutput}; +use crate::ops::OnnxOpError::ComputationFailed; + +pub(crate) struct Sign; +impl OnnxOp for Sign { + fn eval(&self, node: &ComputeNode) -> Result { + let input = node.get_input(0) + .ok_or_else(|| ComputationFailed("input 0 not found".to_string()))?; + + let output = input.sign() + .map_err(|err| ComputationFailed(format!("{:?}",err)))?; + + let output_name = node.get_output(0) + .ok_or_else(|| ComputationFailed("output 0 not found".to_string()))?; + + Ok((output_name.clone(), output)) + } +} \ No newline at end of file diff --git a/candle-onnx/src/ops/mod.rs b/candle-onnx/src/ops/mod.rs index d5f5676d02..f9ebeb5f2f 100644 --- a/candle-onnx/src/ops/mod.rs +++ b/candle-onnx/src/ops/mod.rs @@ -3,3 +3,13 @@ pub use onnxop::{OnnxOp, OnnxOpError, OnnxOpRegistry, OpOutput}; pub mod compute_node; pub use compute_node::ComputeNode; + + +mod math; +use math::sign; + +pub fn 
registry() -> Result { + let mut registry = OnnxOpRegistry::new(); + registry.insert("Sign", Box::new(sign::Sign))?; + Ok(registry) +} \ No newline at end of file From 88d9d6f6fa13a263e0c7dc485750491e52039ec1 Mon Sep 17 00:00:00 2001 From: Ionut Mihalcea Date: Wed, 27 Nov 2024 20:44:51 +0100 Subject: [PATCH 3/5] use registry in current evaluation code --- candle-onnx/src/eval.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 2c60ed2f23..8a9dde7231 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -1,6 +1,7 @@ use crate::onnx::attribute_proto::AttributeType; use crate::onnx::tensor_proto::DataType; use crate::onnx::{self, GraphProto}; +use crate::ops::{registry, ComputeNode}; use candle::{bail, DType, Device, Result, Tensor}; use std::collections::{HashMap, HashSet}; @@ -317,6 +318,8 @@ fn simple_eval_( ) } } + + let registry = registry()?; // The nodes are topologically sorted so we can just process them in order. for node in graph.node.iter() { let get = |input_name: &str| match values.get(input_name) { @@ -1950,7 +1953,12 @@ fn simple_eval_( let output = input.sign()?; values.insert(node.output[0].clone(), output); } - op_type => bail!("unsupported op_type {op_type} for op {node:?}"), + op_type => { + let onnx_op = registry.get(op_type)?; + let cn = ComputeNode::new(&node, values); + let (name, value) = onnx_op.eval(&cn)?; + values.insert(name, value); + } } } graph From dc0d2d50dca1b4d09dab8896bb99d4e10ccb53c8 Mon Sep 17 00:00:00 2001 From: Ionut Mihalcea Date: Wed, 27 Nov 2024 20:45:03 +0100 Subject: [PATCH 4/5] cargo fmt --- candle-onnx/src/ops/math/mod.rs | 2 +- candle-onnx/src/ops/math/sign.rs | 15 +++++++++------ candle-onnx/src/ops/mod.rs | 3 +-- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/candle-onnx/src/ops/math/mod.rs b/candle-onnx/src/ops/math/mod.rs index f8849e4278..171184640d 100644 --- a/candle-onnx/src/ops/math/mod.rs +++ b/candle-onnx/src/ops/math/mod.rs @@ -1 +1 @@ -pub(crate) mod sign; \ No newline at end of file +pub(crate) mod sign; diff --git a/candle-onnx/src/ops/math/sign.rs b/candle-onnx/src/ops/math/sign.rs index 04c37ecfa7..1e461d8f99 100644 --- a/candle-onnx/src/ops/math/sign.rs +++ b/candle-onnx/src/ops/math/sign.rs @@ -1,19 +1,22 @@ use crate::ops::compute_node::ComputeNode; -use crate::ops::{OnnxOp, OnnxOpError, OpOutput}; use crate::ops::OnnxOpError::ComputationFailed; +use crate::ops::{OnnxOp, OnnxOpError, OpOutput}; pub(crate) struct Sign; impl OnnxOp for Sign { fn eval(&self, node: &ComputeNode) -> Result { - let input = node.get_input(0) + let input = node + .get_input(0) .ok_or_else(|| ComputationFailed("input 0 not found".to_string()))?; - let output = input.sign() - .map_err(|err| ComputationFailed(format!("{:?}",err)))?; + let output = input + .sign() + .map_err(|err| ComputationFailed(format!("{:?}", err)))?; - let output_name = node.get_output(0) + let output_name = node + .get_output(0) .ok_or_else(|| ComputationFailed("output 0 not found".to_string()))?; Ok((output_name.clone(), output)) } -} \ No newline at end of file +} diff --git a/candle-onnx/src/ops/mod.rs b/candle-onnx/src/ops/mod.rs index f9ebeb5f2f..275fd167bd 100644 --- a/candle-onnx/src/ops/mod.rs +++ b/candle-onnx/src/ops/mod.rs @@ -4,7 +4,6 @@ pub use onnxop::{OnnxOp, OnnxOpError, OnnxOpRegistry, OpOutput}; pub mod compute_node; pub use compute_node::ComputeNode; - mod math; use math::sign; @@ -12,4 +11,4 @@ pub fn registry() -> Result { let mut registry = 
OnnxOpRegistry::new(); registry.insert("Sign", Box::new(sign::Sign))?; Ok(registry) -} \ No newline at end of file +} From 304ed5ff4a0789f79fe19bbf7139dc3f9728d872 Mon Sep 17 00:00:00 2001 From: Ionut Mihalcea Date: Wed, 27 Nov 2024 20:51:55 +0100 Subject: [PATCH 5/5] add tests for sign operation --- candle-onnx/tests/ops.rs | 169 +++++++++++++------------------------ candle-onnx/tests/sign.rs | 43 ++++++++++ candle-onnx/tests/utils.rs | 17 ++++ 3 files changed, 117 insertions(+), 112 deletions(-) create mode 100644 candle-onnx/tests/sign.rs create mode 100644 candle-onnx/tests/utils.rs diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index 3586bfbd68..b444fb64b0 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -8,30 +8,16 @@ use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, Value use candle_onnx::simple_eval; use std::collections::HashMap; +mod utils; + const INPUT_X: &str = "x"; const INPUT_Y: &str = "y"; const INPUT_A: &str = "a"; const OUTPUT_Z: &str = "z"; -fn create_model_proto_with_graph(graph: Option) -> ModelProto { - ModelProto { - metadata_props: vec![], - training_info: vec![], - functions: vec![], - ir_version: 0, - opset_import: vec![], - producer_name: "".to_string(), - producer_version: "".to_string(), - domain: "".to_string(), - model_version: 0, - doc_string: "".to_string(), - graph, - } -} - #[test] fn test_evaluation_fails_without_defined_graph() -> Result<()> { - let manual_graph = create_model_proto_with_graph(None); + let manual_graph = utils::create_model_proto_with_graph(None); let inputs: HashMap = HashMap::new(); match candle_onnx::simple_eval(&manual_graph, inputs) { Err(err) => assert_eq!(err.to_string(), "no graph defined in proto"), @@ -43,7 +29,7 @@ fn test_evaluation_fails_without_defined_graph() -> Result<()> { // "Add" #[test] fn test_add_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Add".to_string(), domain: "".to_string(), @@ -83,7 +69,7 @@ fn test_add_operation() -> Result<()> { // "Sub" #[test] fn test_sub_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Sub".to_string(), domain: "".to_string(), @@ -123,7 +109,7 @@ fn test_sub_operation() -> Result<()> { // "Mul" #[test] fn test_mul_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Mul".to_string(), domain: "".to_string(), @@ -163,7 +149,7 @@ fn test_mul_operation() -> Result<()> { // "Div" #[test] fn test_div_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Div".to_string(), domain: "".to_string(), @@ -203,7 +189,7 @@ fn test_div_operation() -> Result<()> { // "Exp" #[test] fn test_exp_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Exp".to_string(), domain: "".to_string(), @@ -249,7 +235,7 @@ fn test_exp_operation() -> Result<()> { // "Equal" #[test] fn 
test_equal_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Equal".to_string(), domain: "".to_string(), @@ -290,7 +276,7 @@ fn test_equal_operation() -> Result<()> { // "Not" #[test] fn test_not_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Not".to_string(), domain: "".to_string(), @@ -330,7 +316,7 @@ fn test_not_operation() -> Result<()> { // "MatMul" #[test] fn test_matmul_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "MatMul".to_string(), domain: "".to_string(), @@ -387,7 +373,7 @@ fn test_matmul_operation() -> Result<()> { // "Reshape" #[test] fn test_reshape_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Reshape".to_string(), domain: "".to_string(), @@ -454,7 +440,7 @@ fn test_reshape_operation() -> Result<()> { // "LogSoftmax" #[test] fn test_logsoftmax_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "LogSoftmax".to_string(), domain: "".to_string(), @@ -517,7 +503,7 @@ fn test_logsoftmax_operation() -> Result<()> { // "Softmax" #[test] fn test_softmax_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Softmax".to_string(), domain: "".to_string(), @@ -580,7 +566,7 @@ fn test_softmax_operation() -> Result<()> { // "Transpose" #[test] fn test_transpose_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Transpose".to_string(), domain: "".to_string(), @@ -640,7 +626,7 @@ fn test_transpose_operation() -> Result<()> { // "Dropout" #[test] fn test_dropout_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Dropout".to_string(), domain: "".to_string(), @@ -719,7 +705,7 @@ fn test_flatten_operation() -> Result<()> { sparse_tensors: vec![], type_protos: vec![], }; - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Flatten".to_string(), domain: "".to_string(), @@ -774,7 +760,7 @@ fn test_flatten_operation() -> Result<()> { assert_eq!(results, vec![vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]]); att_axis.i = 1; - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Flatten".to_string(), domain: "".to_string(), @@ -922,7 +908,7 @@ fn test_constant_of_shape() -> Result<()> { }) } - let manual_graph = 
create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "ConstantOfShape".to_string(), domain: "".to_string(), @@ -974,7 +960,7 @@ fn test_constant_of_shape() -> Result<()> { // "Unsqueeze" #[test] fn test_unsqueeze() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Unsqueeze".to_string(), domain: "".to_string(), @@ -1112,7 +1098,7 @@ fn test_gather_operation() -> Result<()> { type_protos: vec![], }; - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Gather".to_string(), domain: "".to_string(), @@ -1269,7 +1255,7 @@ fn test_gather_elements() -> Result<()> { type_protos: vec![], }; - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "GatherElements".to_string(), domain: "".to_string(), @@ -1319,7 +1305,7 @@ fn test_gather_elements() -> Result<()> { // "Size" #[test] fn test_size_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Size".to_string(), domain: "".to_string(), @@ -1364,7 +1350,7 @@ fn test_size_operation() -> Result<()> { // "Shape" #[test] fn test_shape_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Shape".to_string(), domain: "".to_string(), @@ -1415,7 +1401,7 @@ fn test_shape_operation() -> Result<()> { // "Abs" #[test] fn test_abs_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Abs".to_string(), domain: "".to_string(), @@ -1473,7 +1459,7 @@ fn test_abs_operation() -> Result<()> { // "Cos" #[test] fn test_cos_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Cos".to_string(), domain: "".to_string(), @@ -1523,7 +1509,7 @@ fn test_cos_operation() -> Result<()> { // "Sin" #[test] fn test_sin_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Sin".to_string(), domain: "".to_string(), @@ -1570,7 +1556,7 @@ fn test_sin_operation() -> Result<()> { // "Neg" #[test] fn test_neg_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Neg".to_string(), domain: "".to_string(), @@ -1627,7 +1613,7 @@ fn test_neg_operation() -> Result<()> { // "Tanh" #[test] fn test_tanh_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: 
"Tanh".to_string(), domain: "".to_string(), @@ -1684,7 +1670,7 @@ fn test_tanh_operation() -> Result<()> { // "Sigmoid" #[test] fn test_sigmoid_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Sigmoid".to_string(), domain: "".to_string(), @@ -1741,7 +1727,7 @@ fn test_sigmoid_operation() -> Result<()> { // "Gelu" #[test] fn test_gelu_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Gelu".to_string(), domain: "".to_string(), @@ -1798,7 +1784,7 @@ fn test_gelu_operation() -> Result<()> { // "Relu" #[test] fn test_relu_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Relu".to_string(), domain: "".to_string(), @@ -2289,7 +2275,7 @@ fn test_reduce_max() -> Result<()> { }); } - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "ReduceMax".to_string(), domain: "".to_string(), @@ -2808,7 +2794,7 @@ fn test_reduce_min() -> Result<()> { }); } - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "ReduceMin".to_string(), domain: "".to_string(), @@ -3016,7 +3002,7 @@ fn test_reduce_mean() -> Result<()> { type_protos: vec![], }; - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "ReduceMean".to_string(), domain: "".to_string(), @@ -3074,7 +3060,7 @@ fn test_sqrt() -> Result<()> { test(&[1., 4., 9.], &[1., 2., 3.])?; fn test(data: impl NdArray, expected: impl NdArray) -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Sqrt".to_string(), domain: "".to_string(), @@ -3220,7 +3206,7 @@ fn test_random_uniform() -> Result<()> { } mut_attrs }; - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "RandomUniform".to_string(), domain: "".to_string(), @@ -3366,7 +3352,7 @@ fn test_random_normal() -> Result<()> { } mut_attrs }; - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "RandomNormal".to_string(), domain: "".to_string(), @@ -3426,7 +3412,7 @@ fn test_range() -> Result<()> { delta: impl NdArray, expected: impl NdArray, ) -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Range".to_string(), domain: "".to_string(), @@ -3492,7 +3478,7 @@ fn test_greater() -> Result<()> { test(&[1., 2., 3.], 2., &[0u8, 0, 1])?; fn test(a: impl NdArray, b: impl NdArray, expected: impl NdArray) -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto 
{ + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Greater".to_string(), domain: "".to_string(), @@ -3553,7 +3539,7 @@ fn test_less() -> Result<()> { test(&[1., 2., 3.], 2., &[1u8, 0, 0])?; fn test(a: impl NdArray, b: impl NdArray, expected: impl NdArray) -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Less".to_string(), domain: "".to_string(), @@ -3611,7 +3597,7 @@ fn test_log() -> Result<()> { test(&[1., 10.], &[0., std::f64::consts::LN_10])?; fn test(data: impl NdArray, expected: impl NdArray) -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Log".to_string(), domain: "".to_string(), @@ -3670,7 +3656,7 @@ fn test_min() -> Result<()> { c: impl NdArray, expected: impl NdArray, ) -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Min".to_string(), domain: "".to_string(), @@ -3748,7 +3734,7 @@ fn test_where() -> Result<()> { y: impl NdArray, expected: impl NdArray, ) -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Where".to_string(), domain: "".to_string(), @@ -3806,7 +3792,7 @@ fn test_where() -> Result<()> { #[test] fn test_floor() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Floor".to_string(), domain: "".to_string(), @@ -3882,7 +3868,7 @@ fn test_floor() -> Result<()> { #[test] fn test_ceil() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Ceil".to_string(), domain: "".to_string(), @@ -4097,7 +4083,7 @@ fn test_argmin() -> Result<()> { } mut_attrs }; - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "ArgMin".to_string(), domain: "".to_string(), @@ -4279,7 +4265,7 @@ fn test_argmax() -> Result<()> { } mut_attrs }; - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "ArgMax".to_string(), domain: "".to_string(), @@ -4354,7 +4340,7 @@ fn test_leakyrelu() -> Result<()> { } mut_attrs }; - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "LeakyRelu".to_string(), domain: "".to_string(), @@ -4465,7 +4451,7 @@ fn test_if() -> Result<()> { }], ..GraphProto::default() }; - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "If".to_string(), attribute: vec![ @@ -4530,7 +4516,7 @@ fn test_pad() -> Result<()> { &Device::Cpu, )?; - let model = create_model_proto_with_graph(Some(GraphProto { + 
let model = utils::create_model_proto_with_graph(Some(GraphProto { input: vec![ ValueInfoProto { name: "data".to_string(), @@ -4572,7 +4558,7 @@ fn test_pad() -> Result<()> { #[test] fn test_slice() -> Result<()> { - let model = create_model_proto_with_graph(Some(GraphProto { + let model = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Slice".to_string(), input: vec![ @@ -4657,7 +4643,7 @@ fn test_slice() -> Result<()> { [2, 3, 4], ] */ - let model = create_model_proto_with_graph(Some(GraphProto { + let model = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Slice".to_string(), input: vec!["data".to_string(), "starts".to_string(), "ends".to_string()], @@ -5074,7 +5060,7 @@ fn test_lstm() -> Result<()> { )?; // end of generated values - let model = create_model_proto_with_graph(Some(GraphProto { + let model = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "LSTM".to_string(), name: "LSTM_test".to_string(), @@ -5172,7 +5158,7 @@ fn test_lstm() -> Result<()> { #[test] fn test_expand_dim_changed() -> Result<()> { // Create a manual graph for the Expand operation - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Expand".to_string(), domain: "".to_string(), @@ -5241,7 +5227,7 @@ fn make_graph_helper( outputs: &[&str], attribs: Vec, ) -> ModelProto { - create_model_proto_with_graph(Some(GraphProto { + utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: op_name.to_string(), domain: "".to_string(), @@ -5805,7 +5791,7 @@ fn test_xor() -> Result<()> { )?; fn test(input: impl NdArray, other: impl NdArray, expected: impl NdArray) -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { + let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Xor".to_string(), domain: "".to_string(), @@ -5869,44 +5855,3 @@ fn test_xor() -> Result<()> { } Ok(()) } - -#[test] -fn test_sign_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { - node: vec![NodeProto { - op_type: "Sign".to_string(), - domain: "".to_string(), - attribute: vec![], - input: vec![INPUT_X.to_string()], - output: vec![OUTPUT_Z.to_string()], - name: "".to_string(), - doc_string: "".to_string(), - }], - name: "".to_string(), - initializer: vec![], - input: vec![], - output: vec![ValueInfoProto { - name: OUTPUT_Z.to_string(), - doc_string: "".to_string(), - r#type: None, - }], - value_info: vec![], - doc_string: "".to_string(), - sparse_initializer: vec![], - quantization_annotation: vec![], - })); - - let mut inputs: HashMap = HashMap::new(); - inputs.insert( - INPUT_X.to_string(), - Tensor::new(vec![-2f32, -1., 0., 1., 2.], &Device::Cpu)?, - ); - let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; - - let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); - assert_eq!( - z.to_dtype(candle::DType::I64)?.to_vec1::()?.to_vec(), - vec![-1, -1, 0, 1, 1] - ); - Ok(()) -} diff --git a/candle-onnx/tests/sign.rs b/candle-onnx/tests/sign.rs new file mode 100644 index 0000000000..04ff4b1cae --- /dev/null +++ b/candle-onnx/tests/sign.rs @@ -0,0 +1,43 @@ +use candle::{Device, Result, Tensor}; +use candle_onnx::onnx::{GraphProto, NodeProto, ValueInfoProto}; +use std::collections::HashMap; +mod utils; +#[test] +fn test_sign_operation() -> Result<()> { 
+    let manual_graph = utils::create_model_proto_with_graph(Some(GraphProto {
+        node: vec![NodeProto {
+            op_type: "Sign".to_string(),
+            domain: "".to_string(),
+            attribute: vec![],
+            input: vec!["X".to_string()],
+            output: vec!["Z".to_string()],
+            name: "".to_string(),
+            doc_string: "".to_string(),
+        }],
+        name: "".to_string(),
+        initializer: vec![],
+        input: vec![],
+        output: vec![ValueInfoProto {
+            name: "Z".to_string(),
+            doc_string: "".to_string(),
+            r#type: None,
+        }],
+        value_info: vec![],
+        doc_string: "".to_string(),
+        sparse_initializer: vec![],
+        quantization_annotation: vec![],
+    }));
+
+    let mut inputs: HashMap<String, Tensor> = HashMap::new();
+    inputs.insert("X".to_string(), Tensor::arange(-2., 3., &Device::Cpu)?);
+
+    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+
+    let z = eval.get("Z").expect("Output 'z' not found");
+    assert_eq!(
+        z.to_dtype(candle::DType::I64)?.to_vec1::<i64>()?.to_vec(),
+        vec![-1, -1, 0, 1, 1]
+    );
+
+    Ok(())
+}
diff --git a/candle-onnx/tests/utils.rs b/candle-onnx/tests/utils.rs
new file mode 100644
index 0000000000..216dfe5871
--- /dev/null
+++ b/candle-onnx/tests/utils.rs
@@ -0,0 +1,17 @@
+use candle_onnx::onnx::{GraphProto, ModelProto};
+
+pub fn create_model_proto_with_graph(graph: Option<GraphProto>) -> ModelProto {
+    ModelProto {
+        metadata_props: vec![],
+        training_info: vec![],
+        functions: vec![],
+        ir_version: 0,
+        opset_import: vec![],
+        producer_name: "".to_string(),
+        producer_version: "".to_string(),
+        domain: "".to_string(),
+        model_version: 0,
+        doc_string: "".to_string(),
+        graph,
+    }
+}
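
Note: the sketch below is not part of the patch series; it only illustrates how the registry introduced in PATCH 1/5 is meant to be extended, mirroring the Sign port above. The choice of op (Neg, which today is still handled by the big match in eval.rs) and the module path math::neg are hypothetical.

// Hypothetical follow-up port, e.g. candle-onnx/src/ops/math/neg.rs in this sketch.
use crate::ops::compute_node::ComputeNode;
use crate::ops::OnnxOpError::ComputationFailed;
use crate::ops::{OnnxOp, OnnxOpError, OpOutput};

pub(crate) struct Neg;

impl OnnxOp for Neg {
    fn eval(&self, node: &ComputeNode) -> Result<OpOutput, OnnxOpError> {
        // Inputs and outputs are resolved positionally through the ComputeNode wrapper.
        let input = node
            .get_input(0)
            .ok_or_else(|| ComputationFailed("input 0 not found".to_string()))?;

        // candle's Tensor exposes neg(); any tensor error is surfaced as ComputationFailed.
        let output = input
            .neg()
            .map_err(|err| ComputationFailed(format!("{:?}", err)))?;

        let output_name = node
            .get_output(0)
            .ok_or_else(|| ComputationFailed("output 0 not found".to_string()))?;

        Ok((output_name.clone(), output))
    }
}

It would then be wired up next to Sign in ops/mod.rs with registry.insert("Neg", Box::new(neg::Neg))?, and, once its explicit arm is removed from the match in simple_eval, the fallback arm added in PATCH 3/5 dispatches to it automatically.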