diff --git a/scaflow/graph_nodes/nodes/constants/file_node.py b/scaflow/graph_nodes/nodes/constants/file_node.py
index 615f9bf..00af834 100644
--- a/scaflow/graph_nodes/nodes/constants/file_node.py
+++ b/scaflow/graph_nodes/nodes/constants/file_node.py
@@ -1,4 +1,4 @@
-from typing import Dict
+from typing import Dict, List, Type
from scaflow.model.node import Node
from scaflow.model.output_socket import Output
@@ -17,7 +17,7 @@ def __init__(self, name=None):
@classmethod
def create_node(cls):
c = cls()
- output = Output("filename", "File")
+ output = Output("filename", "File", return_type="str")
c.add_output(output)
c.add_control(FileControl("file_control", "File Path"))
return c
diff --git a/scaflow/graph_nodes/nodes/cpa_attack.py b/scaflow/graph_nodes/nodes/cpa_attack.py
index 32f0a16..3c94303 100644
--- a/scaflow/graph_nodes/nodes/cpa_attack.py
+++ b/scaflow/graph_nodes/nodes/cpa_attack.py
@@ -1,8 +1,11 @@
import logging
-from typing import Dict
+from typing import Any, Callable, Dict, Type
+from estraces import TraceHeaderSet
import numpy as np
import scared
+from scared import Model
+from scared.selection_functions import SelectionFunction
from scaflow.model import Input, Node, Output, dispatcher
@@ -19,22 +22,26 @@ def __init__(self, name=None):
@classmethod
def create_node(cls):
n = cls()
- traces_input = Input("traces", "Traces")
- traces_input.add_compatible("peaks")
- n.add_input(traces_input)
- n.add_input(Input("selection", "Selection function"))
- n.add_input(Input("model", "Model"))
- n.add_input(Input("discriminant", "Discriminant"))
- n.add_output(Output("output", "Data"))
+ n.add_input(Input("traces", "Traces", accepted_types=["TraceHeaderSet"]))
+ n.add_input(
+ Input(
+ "selection",
+ "Selection function",
+ accepted_types=["selection"],
+ )
+ )
+ n.add_input(Input("model", "Model", accepted_types=["model"]))
+ n.add_input(
+ Input("discriminant", "Discriminant", accepted_types=["discriminant"])
+ )
+ n.add_output(Output("output", "Data", return_type="bool"))
return n
def execute(self, kwargs) -> Dict[str, any]:
- traces = kwargs.get("traces")
+ traces: TraceHeaderSet = kwargs.get("traces")
selection = kwargs.get("selection")
- model = kwargs.get("model")
+ model: Type[Model] = kwargs.get("model")
discriminant = kwargs.get("discriminant")
- # logger.debug(kwargs)
- # logger.debug(traces, selection, model, discriminant)
att = scared.CPAAttack(
selection_function=selection(),
model=model(),
diff --git a/scaflow/graph_nodes/nodes/discriminants/max_abs.py b/scaflow/graph_nodes/nodes/discriminants/max_abs.py
index 1ebe7b3..778327f 100644
--- a/scaflow/graph_nodes/nodes/discriminants/max_abs.py
+++ b/scaflow/graph_nodes/nodes/discriminants/max_abs.py
@@ -1,4 +1,4 @@
-from typing import Dict
+from typing import Any, Callable, Dict
import scared
@@ -17,8 +17,8 @@ def __init__(self, name=None):
@classmethod
def create_node(cls):
n = cls()
- n.add_output(Output("discriminant", "Data"))
+ n.add_output(Output("discriminant", "Data", return_type="discriminant"))
return n
- def execute(self, kwargs) -> Dict[str, any]:
+ def execute(self, kwargs) -> Callable[[Any, Any], Any]:
return scared.maxabs
diff --git a/scaflow/graph_nodes/nodes/model/hamming_weight.py b/scaflow/graph_nodes/nodes/model/hamming_weight.py
index 54f5d2c..dc14a97 100644
--- a/scaflow/graph_nodes/nodes/model/hamming_weight.py
+++ b/scaflow/graph_nodes/nodes/model/hamming_weight.py
@@ -1,4 +1,7 @@
+from typing import Type
+
import scared
+from scared import Model
from scaflow.model.dispatcher import dispatcher
from scaflow.model.node import Node
@@ -15,8 +18,8 @@ def __init__(self, name=None):
@classmethod
def create_node(cls):
n = cls()
- n.add_output(Output("model", "Hamming Weights"))
+ n.add_output(Output("model", "Hamming Weights", return_type="model"))
return n
- def execute(self, kwargs):
+ def execute(self, kwargs) -> Type[Model]:
return scared.HammingWeight
diff --git a/scaflow/graph_nodes/nodes/selection/first_sub_bytes.py b/scaflow/graph_nodes/nodes/selection/first_sub_bytes.py
index d1a7911..37beefd 100644
--- a/scaflow/graph_nodes/nodes/selection/first_sub_bytes.py
+++ b/scaflow/graph_nodes/nodes/selection/first_sub_bytes.py
@@ -1,4 +1,7 @@
+from typing import Type
+
import scared
+from scared.selection_functions import SelectionFunction
from scaflow.model.dispatcher import dispatcher
from scaflow.model.node import Node
@@ -15,8 +18,10 @@ def __init__(self, name=None):
@classmethod
def create_node(cls):
n = cls()
- n.add_output(Output("selection", "Data"))
+ n.add_output(Output("selection", "Data", return_type="selection"))
return n
- def execute(self, kwargs):
+ def execute(self, kwargs) -> Type[SelectionFunction]:
+ # Disable inspection as the type checker cannot determine that the class is wrapped with SelectionFunction
+ # noinspection PyTypeChecker
return scared.aes.selection_functions.encrypt.FirstSubBytes
diff --git a/scaflow/graph_nodes/nodes/trace_nodes/ets_trace_node.py b/scaflow/graph_nodes/nodes/trace_nodes/ets_trace_node.py
index 516a4ac..33330f3 100644
--- a/scaflow/graph_nodes/nodes/trace_nodes/ets_trace_node.py
+++ b/scaflow/graph_nodes/nodes/trace_nodes/ets_trace_node.py
@@ -1,6 +1,8 @@
import logging
from pathlib import Path
+from typing import Dict, List, Type
+from estraces import TraceHeaderSet
import scared
from scaflow.model.dispatcher import dispatcher
@@ -22,8 +24,8 @@ def __init__(self, name="Trace Input"):
@classmethod
def create_node(cls):
n = cls()
- n.add_input(Input("filename", "Trace File"))
- n.add_output(Output("traces", "Output"))
+ n.add_input(Input("filename", "Trace File", accepted_types=["str"]))
+ n.add_output(Output("traces", "Output", return_type="TraceHeaderSet"))
return n
def execute(self, kwargs):
diff --git a/scaflow/graph_nodes/nodes/trace_nodes/npy_trace_node.py b/scaflow/graph_nodes/nodes/trace_nodes/npy_trace_node.py
index d4b9114..52bfdb2 100644
--- a/scaflow/graph_nodes/nodes/trace_nodes/npy_trace_node.py
+++ b/scaflow/graph_nodes/nodes/trace_nodes/npy_trace_node.py
@@ -2,6 +2,7 @@
from pathlib import Path
import estraces
+from estraces import TraceHeaderSet
import numpy as np
from scaflow.model.dispatcher import dispatcher
@@ -23,10 +24,10 @@ def __init__(self, name="Trace Input"):
@classmethod
def create_node(cls):
n = cls()
- n.add_input(Input("filename", "Trace File"))
- n.add_input(Input("plaintext", "Plaintext File"))
- n.add_input(Input("ciphertext", "Ciphertext File"))
- n.add_output(Output("traces", "Output"))
+ n.add_input(Input("filename", "Trace File", accepted_types=["str"]))
+ n.add_input(Input("plaintext", "Plaintext File", accepted_types=["str"]))
+ n.add_input(Input("ciphertext", "Ciphertext File", accepted_types=["str"]))
+ n.add_output(Output("traces", "Output", return_type="TraceHeaderSet"))
return n
def execute(self, kwargs):
diff --git a/scaflow/model/input_socket.py b/scaflow/model/input_socket.py
index 2b7401d..c11c119 100644
--- a/scaflow/model/input_socket.py
+++ b/scaflow/model/input_socket.py
@@ -1,6 +1,6 @@
-from typing import Optional, TYPE_CHECKING
+from typing import List, Optional, TYPE_CHECKING, Type
-from scaflow.model import dispatcher
+from scaflow.model import Output, dispatcher
from .socket import Socket
if TYPE_CHECKING:
@@ -11,13 +11,21 @@
@dispatcher
class Input(Socket):
- def __init__(self, key: str, name: str, multi_conns: bool = False) -> None:
+ def __init__(
+ self, key: str, name: str, accepted_types: List[str], multi_conns: bool = False
+ ) -> None:
super().__init__(key, name, multi_conns)
self._control: Optional["Control"] = None
+ self._accepted_types = accepted_types
def __repr__(self):
return f''
+ def compatible_with(self, socket: "Socket"):
+ if isinstance(socket, Output):
+ return socket.return_type in self._accepted_types
+ return False
+
def add_connection(self, conn: "Connection"):
if not self.multi_conns and self.has_connection():
raise Exception("Multiple connections not allowed")
@@ -30,11 +38,17 @@ def as_dict(self) -> "InputDict":
"compatible": self._compatible,
"multi_conns": self.multi_conns,
"control": self._control,
+ "accepted_types": self._accepted_types,
}
@classmethod
def from_dict(cls, data: "InputDict"):
- c = cls(data["key"], data["name"], data["multi_conns"])
+ c = cls(
+ data["key"],
+ data["name"],
+ accepted_types=data["accepted_types"],
+ multi_conns=data["multi_conns"],
+ )
c._control = data["control"]
c._compatible = data["compatible"]
return c
diff --git a/scaflow/model/output_socket.py b/scaflow/model/output_socket.py
index 93f6890..73368ef 100644
--- a/scaflow/model/output_socket.py
+++ b/scaflow/model/output_socket.py
@@ -2,7 +2,9 @@
from .socket import Socket
from .connection import Connection
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Type
+
+from .type_hints import OutputDict, SocketDict
if TYPE_CHECKING:
from .input_socket import Input
@@ -10,12 +12,18 @@
@dispatcher
class Output(Socket):
- def __init__(self, key: str, name: str, multi_conns: bool = True) -> None:
+ def __init__(
+ self, key: str, name: str, return_type: str, multi_conns: bool = True
+ ) -> None:
super().__init__(key, name, multi_conns)
+ self.return_type = return_type
def __repr__(self):
return f'