diff --git a/scaflow/graph_nodes/nodes/constants/file_node.py b/scaflow/graph_nodes/nodes/constants/file_node.py index 615f9bf..00af834 100644 --- a/scaflow/graph_nodes/nodes/constants/file_node.py +++ b/scaflow/graph_nodes/nodes/constants/file_node.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, List, Type from scaflow.model.node import Node from scaflow.model.output_socket import Output @@ -17,7 +17,7 @@ def __init__(self, name=None): @classmethod def create_node(cls): c = cls() - output = Output("filename", "File") + output = Output("filename", "File", return_type="str") c.add_output(output) c.add_control(FileControl("file_control", "File Path")) return c diff --git a/scaflow/graph_nodes/nodes/cpa_attack.py b/scaflow/graph_nodes/nodes/cpa_attack.py index 32f0a16..3c94303 100644 --- a/scaflow/graph_nodes/nodes/cpa_attack.py +++ b/scaflow/graph_nodes/nodes/cpa_attack.py @@ -1,8 +1,11 @@ import logging -from typing import Dict +from typing import Any, Callable, Dict, Type +from estraces import TraceHeaderSet import numpy as np import scared +from scared import Model +from scared.selection_functions import SelectionFunction from scaflow.model import Input, Node, Output, dispatcher @@ -19,22 +22,26 @@ def __init__(self, name=None): @classmethod def create_node(cls): n = cls() - traces_input = Input("traces", "Traces") - traces_input.add_compatible("peaks") - n.add_input(traces_input) - n.add_input(Input("selection", "Selection function")) - n.add_input(Input("model", "Model")) - n.add_input(Input("discriminant", "Discriminant")) - n.add_output(Output("output", "Data")) + n.add_input(Input("traces", "Traces", accepted_types=["TraceHeaderSet"])) + n.add_input( + Input( + "selection", + "Selection function", + accepted_types=["selection"], + ) + ) + n.add_input(Input("model", "Model", accepted_types=["model"])) + n.add_input( + Input("discriminant", "Discriminant", accepted_types=["discriminant"]) + ) + n.add_output(Output("output", "Data", return_type="bool")) return n def execute(self, kwargs) -> Dict[str, any]: - traces = kwargs.get("traces") + traces: TraceHeaderSet = kwargs.get("traces") selection = kwargs.get("selection") - model = kwargs.get("model") + model: Type[Model] = kwargs.get("model") discriminant = kwargs.get("discriminant") - # logger.debug(kwargs) - # logger.debug(traces, selection, model, discriminant) att = scared.CPAAttack( selection_function=selection(), model=model(), diff --git a/scaflow/graph_nodes/nodes/discriminants/max_abs.py b/scaflow/graph_nodes/nodes/discriminants/max_abs.py index 1ebe7b3..778327f 100644 --- a/scaflow/graph_nodes/nodes/discriminants/max_abs.py +++ b/scaflow/graph_nodes/nodes/discriminants/max_abs.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Any, Callable, Dict import scared @@ -17,8 +17,8 @@ def __init__(self, name=None): @classmethod def create_node(cls): n = cls() - n.add_output(Output("discriminant", "Data")) + n.add_output(Output("discriminant", "Data", return_type="discriminant")) return n - def execute(self, kwargs) -> Dict[str, any]: + def execute(self, kwargs) -> Callable[[Any, Any], Any]: return scared.maxabs diff --git a/scaflow/graph_nodes/nodes/model/hamming_weight.py b/scaflow/graph_nodes/nodes/model/hamming_weight.py index 54f5d2c..dc14a97 100644 --- a/scaflow/graph_nodes/nodes/model/hamming_weight.py +++ b/scaflow/graph_nodes/nodes/model/hamming_weight.py @@ -1,4 +1,7 @@ +from typing import Type + import scared +from scared import Model from scaflow.model.dispatcher import dispatcher from scaflow.model.node import Node @@ -15,8 +18,8 @@ def __init__(self, name=None): @classmethod def create_node(cls): n = cls() - n.add_output(Output("model", "Hamming Weights")) + n.add_output(Output("model", "Hamming Weights", return_type="model")) return n - def execute(self, kwargs): + def execute(self, kwargs) -> Type[Model]: return scared.HammingWeight diff --git a/scaflow/graph_nodes/nodes/selection/first_sub_bytes.py b/scaflow/graph_nodes/nodes/selection/first_sub_bytes.py index d1a7911..37beefd 100644 --- a/scaflow/graph_nodes/nodes/selection/first_sub_bytes.py +++ b/scaflow/graph_nodes/nodes/selection/first_sub_bytes.py @@ -1,4 +1,7 @@ +from typing import Type + import scared +from scared.selection_functions import SelectionFunction from scaflow.model.dispatcher import dispatcher from scaflow.model.node import Node @@ -15,8 +18,10 @@ def __init__(self, name=None): @classmethod def create_node(cls): n = cls() - n.add_output(Output("selection", "Data")) + n.add_output(Output("selection", "Data", return_type="selection")) return n - def execute(self, kwargs): + def execute(self, kwargs) -> Type[SelectionFunction]: + # Disable inspection as the type checker cannot determine that the class is wrapped with SelectionFunction + # noinspection PyTypeChecker return scared.aes.selection_functions.encrypt.FirstSubBytes diff --git a/scaflow/graph_nodes/nodes/trace_nodes/ets_trace_node.py b/scaflow/graph_nodes/nodes/trace_nodes/ets_trace_node.py index 516a4ac..33330f3 100644 --- a/scaflow/graph_nodes/nodes/trace_nodes/ets_trace_node.py +++ b/scaflow/graph_nodes/nodes/trace_nodes/ets_trace_node.py @@ -1,6 +1,8 @@ import logging from pathlib import Path +from typing import Dict, List, Type +from estraces import TraceHeaderSet import scared from scaflow.model.dispatcher import dispatcher @@ -22,8 +24,8 @@ def __init__(self, name="Trace Input"): @classmethod def create_node(cls): n = cls() - n.add_input(Input("filename", "Trace File")) - n.add_output(Output("traces", "Output")) + n.add_input(Input("filename", "Trace File", accepted_types=["str"])) + n.add_output(Output("traces", "Output", return_type="TraceHeaderSet")) return n def execute(self, kwargs): diff --git a/scaflow/graph_nodes/nodes/trace_nodes/npy_trace_node.py b/scaflow/graph_nodes/nodes/trace_nodes/npy_trace_node.py index d4b9114..52bfdb2 100644 --- a/scaflow/graph_nodes/nodes/trace_nodes/npy_trace_node.py +++ b/scaflow/graph_nodes/nodes/trace_nodes/npy_trace_node.py @@ -2,6 +2,7 @@ from pathlib import Path import estraces +from estraces import TraceHeaderSet import numpy as np from scaflow.model.dispatcher import dispatcher @@ -23,10 +24,10 @@ def __init__(self, name="Trace Input"): @classmethod def create_node(cls): n = cls() - n.add_input(Input("filename", "Trace File")) - n.add_input(Input("plaintext", "Plaintext File")) - n.add_input(Input("ciphertext", "Ciphertext File")) - n.add_output(Output("traces", "Output")) + n.add_input(Input("filename", "Trace File", accepted_types=["str"])) + n.add_input(Input("plaintext", "Plaintext File", accepted_types=["str"])) + n.add_input(Input("ciphertext", "Ciphertext File", accepted_types=["str"])) + n.add_output(Output("traces", "Output", return_type="TraceHeaderSet")) return n def execute(self, kwargs): diff --git a/scaflow/model/input_socket.py b/scaflow/model/input_socket.py index 2b7401d..c11c119 100644 --- a/scaflow/model/input_socket.py +++ b/scaflow/model/input_socket.py @@ -1,6 +1,6 @@ -from typing import Optional, TYPE_CHECKING +from typing import List, Optional, TYPE_CHECKING, Type -from scaflow.model import dispatcher +from scaflow.model import Output, dispatcher from .socket import Socket if TYPE_CHECKING: @@ -11,13 +11,21 @@ @dispatcher class Input(Socket): - def __init__(self, key: str, name: str, multi_conns: bool = False) -> None: + def __init__( + self, key: str, name: str, accepted_types: List[str], multi_conns: bool = False + ) -> None: super().__init__(key, name, multi_conns) self._control: Optional["Control"] = None + self._accepted_types = accepted_types def __repr__(self): return f'' + def compatible_with(self, socket: "Socket"): + if isinstance(socket, Output): + return socket.return_type in self._accepted_types + return False + def add_connection(self, conn: "Connection"): if not self.multi_conns and self.has_connection(): raise Exception("Multiple connections not allowed") @@ -30,11 +38,17 @@ def as_dict(self) -> "InputDict": "compatible": self._compatible, "multi_conns": self.multi_conns, "control": self._control, + "accepted_types": self._accepted_types, } @classmethod def from_dict(cls, data: "InputDict"): - c = cls(data["key"], data["name"], data["multi_conns"]) + c = cls( + data["key"], + data["name"], + accepted_types=data["accepted_types"], + multi_conns=data["multi_conns"], + ) c._control = data["control"] c._compatible = data["compatible"] return c diff --git a/scaflow/model/output_socket.py b/scaflow/model/output_socket.py index 93f6890..73368ef 100644 --- a/scaflow/model/output_socket.py +++ b/scaflow/model/output_socket.py @@ -2,7 +2,9 @@ from .socket import Socket from .connection import Connection -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Type + +from .type_hints import OutputDict, SocketDict if TYPE_CHECKING: from .input_socket import Input @@ -10,12 +12,18 @@ @dispatcher class Output(Socket): - def __init__(self, key: str, name: str, multi_conns: bool = True) -> None: + def __init__( + self, key: str, name: str, return_type: str, multi_conns: bool = True + ) -> None: super().__init__(key, name, multi_conns) + self.return_type = return_type def __repr__(self): return f'' + def compatible_with(self, socket: "Socket"): + return socket.compatible_with(self) + def add_connection(self, input_socket: "Input"): if not self.compatible_with(input_socket): raise TypeError("Not compatible with socket") @@ -33,3 +41,23 @@ def add_connection(self, input_socket: "Input"): self.connections.append(connection) input_socket.add_connection(connection) return connection + + def as_dict(self) -> OutputDict: + return { + "key": self.key, + "name": self.display_name, + "compatible": self._compatible, + "multi_conns": self.multi_conns, + "return_type": self.return_type, + } + + @classmethod + def from_dict(cls, data: OutputDict): + c = cls( + data["key"], + data["name"], + return_type=data["return_type"], + multi_conns=data["multi_conns"], + ) + c._compatible = data["compatible"] + return c diff --git a/scaflow/model/socket.py b/scaflow/model/socket.py index d8228ec..653bd55 100644 --- a/scaflow/model/socket.py +++ b/scaflow/model/socket.py @@ -38,10 +38,6 @@ def __eq__(self, other: object) -> bool: ) return NotImplemented - def compatible_with(self, socket: "Socket"): - # logger.debug("%s, %s", self.key, socket.key) - return self.key == socket.key or socket.key in self._compatible - def has_connection(self): return len(self.connections) > 0 @@ -55,19 +51,3 @@ def remove_connections(self): def add_compatible(self, compatible: str): if compatible not in self._compatible: self._compatible.append(compatible) - - def as_dict(self) -> SocketDict: - return { - "key": self.key, - "name": self.display_name, - "compatible": self._compatible, - "multi_conns": self.multi_conns, - } - - @classmethod - def from_dict(cls, data: SocketDict): - socket = cls( - key=data["key"], name=data["name"], multi_conns=data["multi_conns"] - ) - socket._compatible = data.get("compatible", []) - return socket diff --git a/scaflow/model/type_hints.py b/scaflow/model/type_hints.py index cd60455..f8df4a6 100644 --- a/scaflow/model/type_hints.py +++ b/scaflow/model/type_hints.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum -from typing import Dict, List, Optional, TYPE_CHECKING, Tuple, TypedDict +from typing import Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypedDict if TYPE_CHECKING: from scaflow.model import Connection, Control, Input, Node, Output @@ -29,6 +29,11 @@ class SocketDict(TypedDict): class InputDict(SocketDict): control: Optional["Control"] + accepted_types: List[str] + + +class OutputDict(SocketDict): + return_type: str class NodeDict(TypedDict): diff --git a/tests/test_graph_connection.py b/tests/test_graph_connection.py index 23801b1..7ff6d69 100644 --- a/tests/test_graph_connection.py +++ b/tests/test_graph_connection.py @@ -15,8 +15,8 @@ class TestGraphConnection: def test_serialization(self): - o = Output("test", "Output") - i = Input("test", "Input") + o = Output("test", "Output", return_type=str) + i = Input("test", "Input", accepted_types=[str]) c = Connection( output_socket_key=o.key, input_socket_key=i.key, input_node=0, output_node=1 ) @@ -35,8 +35,8 @@ def test_deserialization(self): assert n.output_node == 1 def test_repr(self): - o = Output("test", "Output") - i = Input("test", "Input") + o = Output("test", "Output", return_type=str) + i = Input("test", "Input", accepted_types=[str]) c = Connection( output_socket_key=o.key, input_socket_key=i.key, input_node=0, output_node=1 ) diff --git a/tests/test_graph_model.py b/tests/test_graph_model.py index 4303483..84e41b9 100644 --- a/tests/test_graph_model.py +++ b/tests/test_graph_model.py @@ -38,6 +38,7 @@ def execute(self, kwargs) -> Dict[str, any]: "name": "Output", "compatible": [], "multi_conns": true, + "return_type": "str", "__class__": "Output" } ], @@ -66,6 +67,9 @@ def execute(self, kwargs) -> Dict[str, any]: "compatible": [], "multi_conns": false, "control": null, + "accepted_types": [ + "str" + ], "__class__": "Input" } ], @@ -109,11 +113,13 @@ def test_add_node(self): def test_add_edge(self): g = Graph() n = ExampleNode("Test") - input_socket = Input("compatible", "Input", multi_conns=True) + input_socket = Input( + "compatible", "Input", multi_conns=True, accepted_types=["str"] + ) n.add_input(input_socket) g.add_node(n) n2 = ExampleNode("Test2") - output = Output("compatible", "Output") + output = Output("compatible", "Output", return_type="str") n2.add_output(output) g.add_node(n2) g.add_edge(output_socket=output, input_socket=input_socket) @@ -122,7 +128,7 @@ def test_add_edge(self): assert n.inputs["compatible"].has_connection() with pytest.raises(Exception): - n2.add_output(Output("compatible", "Output")) + n2.add_output(Output("compatible", "Output", return_type="str")) # g.add_edge(output_socket=n2.outputs["output2"], input_socket=n.inputs["input"]) def test_graph_callback(self): @@ -143,10 +149,10 @@ def test_serialization(self): Connection._last_conn_id = 0 g = Graph() n = ExampleNode("Test") - n.add_output(Output("link", "Output")) + n.add_output(Output("link", "Output", return_type="str")) n.add_control(Control("control", ControlType.FilePath, "File")) n2 = ExampleNode("Test") - n2.add_input(Input("link", "Input")) + n2.add_input(Input("link", "Input", accepted_types=["str"])) g.add_node(n) g.add_node(n2) @@ -180,29 +186,3 @@ def test_deserialization(self): assert e.id == 1 assert e.input_socket_key == "link" assert e.output_socket_key == "link" - - -a = """{ - "edges": [ - { - "id": 1, - "input_socket": { - "key": "link", - "name": "Input", - "compatible": [], - "multi_conns": false, - "control": null, - "__class__": "Input" - }, - "output_socket": { - "key": "link", - "name": "Output", - "compatible": [], - "multi_conns": true, - "__class__": "Output" - }, - "__class__": "Connection" - } - ], - "__class__": "Graph" -}""" diff --git a/tests/test_graph_node.py b/tests/test_graph_node.py index 395058e..c4d8b27 100644 --- a/tests/test_graph_node.py +++ b/tests/test_graph_node.py @@ -4,6 +4,7 @@ import pytest from scaflow import model +from scaflow.model import Input from scaflow.model.node import Spacing from scaflow.graph_nodes.controls import FileControl from scaflow.model.dispatcher import dispatcher @@ -29,6 +30,9 @@ def execute(self, kwargs) -> Dict[str, any]: "compatible": [], "multi_conns": false, "control": null, + "accepted_types": [ + "str" + ], "__class__": "Input" } ], @@ -38,6 +42,7 @@ def execute(self, kwargs) -> Dict[str, any]: "name": "", "compatible": [], "multi_conns": true, + "return_type": "str", "__class__": "Output" } ], @@ -89,7 +94,7 @@ def test_spacing_t_r_b_l(self): def test_inputs(self): n = ExampleNode("Test") - i = model.Input("input", "") + i = model.Input("input", "", accepted_types=["str"]) n.add_input(i) assert len(n.inputs) == 1 assert i.key in n.inputs @@ -98,7 +103,7 @@ def test_inputs(self): def test_outputs(self): n = ExampleNode("Test") - o = model.Output("output", "") + o = model.Output("output", "", return_type="str") n.add_output(o) assert len(n.outputs) == 1 assert o.key in n.outputs @@ -107,13 +112,13 @@ def test_outputs(self): def test_duplicate_key_raises_exception(self): n = ExampleNode("Test") - n.add_input(model.Input("input", "")) + n.add_input(model.Input("input", "", accepted_types=["str"])) with pytest.raises(Exception): - n.add_input(model.Input("input", "")) + n.add_input(model.Input("input", "", accepted_types=["str"])) def test_already_assigned_node_raises_error(self): n = ExampleNode("Test") - i = model.Input("input", "") + i = model.Input("input", "", accepted_types=["str"]) n.add_input(i) n2 = ExampleNode("Test2") with pytest.raises(Exception): @@ -137,13 +142,12 @@ def test_controls(self): def test_serialization(self): model.Node._last_node_id = 0 n = ExampleNode("Test") - n.add_input(model.Input("input", "")) + n.add_input(Input("input", "", accepted_types=["str"])) n.add_control(FileControl("file", "File")) - n.add_output(model.Output("output", "")) + n.add_output(model.Output("output", "", return_type="str")) n.position = (100, 100) json_data = json.dumps(n, default=dispatcher.encoder_default, indent=4) - print(json_data) assert json_data == EXAMPLE_JSON def test_deserialization(self): diff --git a/tests/test_graph_serialisation.py b/tests/test_graph_serialisation.py index 14cb0bc..4bd3f88 100644 --- a/tests/test_graph_serialisation.py +++ b/tests/test_graph_serialisation.py @@ -7,7 +7,7 @@ class TestGraphSerialisation: def test_socket_json(self): - output = Output("key", "name") + output = Output("key", "name", return_type="str") json_data = json.dumps(output, default=dispatcher.encoder_default) output2: Output = json.loads(json_data, object_hook=dispatcher.decoder_hook)