Skip to content

Commit

Permalink
Add accepted/returned types to sockets, type check connections
Browse files Browse the repository at this point in the history
  • Loading branch information
furgoose committed Mar 31, 2021
1 parent aafbf83 commit c8cc29f
Show file tree
Hide file tree
Showing 15 changed files with 127 additions and 98 deletions.
4 changes: 2 additions & 2 deletions scaflow/graph_nodes/nodes/constants/file_node.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict
from typing import Dict, List, Type

from scaflow.model.node import Node
from scaflow.model.output_socket import Output
Expand All @@ -17,7 +17,7 @@ def __init__(self, name=None):
@classmethod
def create_node(cls):
c = cls()
output = Output("filename", "File")
output = Output("filename", "File", return_type="str")
c.add_output(output)
c.add_control(FileControl("file_control", "File Path"))
return c
Expand Down
31 changes: 19 additions & 12 deletions scaflow/graph_nodes/nodes/cpa_attack.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import logging
from typing import Dict
from typing import Any, Callable, Dict, Type

from estraces import TraceHeaderSet
import numpy as np
import scared
from scared import Model
from scared.selection_functions import SelectionFunction

from scaflow.model import Input, Node, Output, dispatcher

Expand All @@ -19,22 +22,26 @@ def __init__(self, name=None):
@classmethod
def create_node(cls):
n = cls()
traces_input = Input("traces", "Traces")
traces_input.add_compatible("peaks")
n.add_input(traces_input)
n.add_input(Input("selection", "Selection function"))
n.add_input(Input("model", "Model"))
n.add_input(Input("discriminant", "Discriminant"))
n.add_output(Output("output", "Data"))
n.add_input(Input("traces", "Traces", accepted_types=["TraceHeaderSet"]))
n.add_input(
Input(
"selection",
"Selection function",
accepted_types=["selection"],
)
)
n.add_input(Input("model", "Model", accepted_types=["model"]))
n.add_input(
Input("discriminant", "Discriminant", accepted_types=["discriminant"])
)
n.add_output(Output("output", "Data", return_type="bool"))
return n

def execute(self, kwargs) -> Dict[str, any]:
traces = kwargs.get("traces")
traces: TraceHeaderSet = kwargs.get("traces")
selection = kwargs.get("selection")
model = kwargs.get("model")
model: Type[Model] = kwargs.get("model")
discriminant = kwargs.get("discriminant")
# logger.debug(kwargs)
# logger.debug(traces, selection, model, discriminant)
att = scared.CPAAttack(
selection_function=selection(),
model=model(),
Expand Down
6 changes: 3 additions & 3 deletions scaflow/graph_nodes/nodes/discriminants/max_abs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict
from typing import Any, Callable, Dict

import scared

Expand All @@ -17,8 +17,8 @@ def __init__(self, name=None):
@classmethod
def create_node(cls):
n = cls()
n.add_output(Output("discriminant", "Data"))
n.add_output(Output("discriminant", "Data", return_type="discriminant"))
return n

def execute(self, kwargs) -> Dict[str, any]:
def execute(self, kwargs) -> Callable[[Any, Any], Any]:
return scared.maxabs
7 changes: 5 additions & 2 deletions scaflow/graph_nodes/nodes/model/hamming_weight.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Type

import scared
from scared import Model

from scaflow.model.dispatcher import dispatcher
from scaflow.model.node import Node
Expand All @@ -15,8 +18,8 @@ def __init__(self, name=None):
@classmethod
def create_node(cls):
n = cls()
n.add_output(Output("model", "Hamming Weights"))
n.add_output(Output("model", "Hamming Weights", return_type="model"))
return n

def execute(self, kwargs):
def execute(self, kwargs) -> Type[Model]:
return scared.HammingWeight
9 changes: 7 additions & 2 deletions scaflow/graph_nodes/nodes/selection/first_sub_bytes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Type

import scared
from scared.selection_functions import SelectionFunction

from scaflow.model.dispatcher import dispatcher
from scaflow.model.node import Node
Expand All @@ -15,8 +18,10 @@ def __init__(self, name=None):
@classmethod
def create_node(cls):
n = cls()
n.add_output(Output("selection", "Data"))
n.add_output(Output("selection", "Data", return_type="selection"))
return n

def execute(self, kwargs):
def execute(self, kwargs) -> Type[SelectionFunction]:
# Disable inspection as the type checker cannot determine that the class is wrapped with SelectionFunction
# noinspection PyTypeChecker
return scared.aes.selection_functions.encrypt.FirstSubBytes
6 changes: 4 additions & 2 deletions scaflow/graph_nodes/nodes/trace_nodes/ets_trace_node.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
from pathlib import Path
from typing import Dict, List, Type

from estraces import TraceHeaderSet
import scared

from scaflow.model.dispatcher import dispatcher
Expand All @@ -22,8 +24,8 @@ def __init__(self, name="Trace Input"):
@classmethod
def create_node(cls):
n = cls()
n.add_input(Input("filename", "Trace File"))
n.add_output(Output("traces", "Output"))
n.add_input(Input("filename", "Trace File", accepted_types=["str"]))
n.add_output(Output("traces", "Output", return_type="TraceHeaderSet"))
return n

def execute(self, kwargs):
Expand Down
9 changes: 5 additions & 4 deletions scaflow/graph_nodes/nodes/trace_nodes/npy_trace_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path

import estraces
from estraces import TraceHeaderSet
import numpy as np

from scaflow.model.dispatcher import dispatcher
Expand All @@ -23,10 +24,10 @@ def __init__(self, name="Trace Input"):
@classmethod
def create_node(cls):
n = cls()
n.add_input(Input("filename", "Trace File"))
n.add_input(Input("plaintext", "Plaintext File"))
n.add_input(Input("ciphertext", "Ciphertext File"))
n.add_output(Output("traces", "Output"))
n.add_input(Input("filename", "Trace File", accepted_types=["str"]))
n.add_input(Input("plaintext", "Plaintext File", accepted_types=["str"]))
n.add_input(Input("ciphertext", "Ciphertext File", accepted_types=["str"]))
n.add_output(Output("traces", "Output", return_type="TraceHeaderSet"))
return n

def execute(self, kwargs):
Expand Down
22 changes: 18 additions & 4 deletions scaflow/model/input_socket.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional, TYPE_CHECKING
from typing import List, Optional, TYPE_CHECKING, Type

from scaflow.model import dispatcher
from scaflow.model import Output, dispatcher
from .socket import Socket

if TYPE_CHECKING:
Expand All @@ -11,13 +11,21 @@

@dispatcher
class Input(Socket):
def __init__(self, key: str, name: str, multi_conns: bool = False) -> None:
def __init__(
self, key: str, name: str, accepted_types: List[str], multi_conns: bool = False
) -> None:
super().__init__(key, name, multi_conns)
self._control: Optional["Control"] = None
self._accepted_types = accepted_types

def __repr__(self):
return f'<Input "{self.key}">'

def compatible_with(self, socket: "Socket"):
if isinstance(socket, Output):
return socket.return_type in self._accepted_types
return False

def add_connection(self, conn: "Connection"):
if not self.multi_conns and self.has_connection():
raise Exception("Multiple connections not allowed")
Expand All @@ -30,11 +38,17 @@ def as_dict(self) -> "InputDict":
"compatible": self._compatible,
"multi_conns": self.multi_conns,
"control": self._control,
"accepted_types": self._accepted_types,
}

@classmethod
def from_dict(cls, data: "InputDict"):
c = cls(data["key"], data["name"], data["multi_conns"])
c = cls(
data["key"],
data["name"],
accepted_types=data["accepted_types"],
multi_conns=data["multi_conns"],
)
c._control = data["control"]
c._compatible = data["compatible"]
return c
32 changes: 30 additions & 2 deletions scaflow/model/output_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,28 @@
from .socket import Socket
from .connection import Connection

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Type

from .type_hints import OutputDict, SocketDict

if TYPE_CHECKING:
from .input_socket import Input


@dispatcher
class Output(Socket):
def __init__(self, key: str, name: str, multi_conns: bool = True) -> None:
def __init__(
self, key: str, name: str, return_type: str, multi_conns: bool = True
) -> None:
super().__init__(key, name, multi_conns)
self.return_type = return_type

def __repr__(self):
return f'<Output "{self.key}">'

def compatible_with(self, socket: "Socket"):
return socket.compatible_with(self)

def add_connection(self, input_socket: "Input"):
if not self.compatible_with(input_socket):
raise TypeError("Not compatible with socket")
Expand All @@ -33,3 +41,23 @@ def add_connection(self, input_socket: "Input"):
self.connections.append(connection)
input_socket.add_connection(connection)
return connection

def as_dict(self) -> OutputDict:
return {
"key": self.key,
"name": self.display_name,
"compatible": self._compatible,
"multi_conns": self.multi_conns,
"return_type": self.return_type,
}

@classmethod
def from_dict(cls, data: OutputDict):
c = cls(
data["key"],
data["name"],
return_type=data["return_type"],
multi_conns=data["multi_conns"],
)
c._compatible = data["compatible"]
return c
20 changes: 0 additions & 20 deletions scaflow/model/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@ def __eq__(self, other: object) -> bool:
)
return NotImplemented

def compatible_with(self, socket: "Socket"):
# logger.debug("%s, %s", self.key, socket.key)
return self.key == socket.key or socket.key in self._compatible

def has_connection(self):
return len(self.connections) > 0

Expand All @@ -55,19 +51,3 @@ def remove_connections(self):
def add_compatible(self, compatible: str):
if compatible not in self._compatible:
self._compatible.append(compatible)

def as_dict(self) -> SocketDict:
return {
"key": self.key,
"name": self.display_name,
"compatible": self._compatible,
"multi_conns": self.multi_conns,
}

@classmethod
def from_dict(cls, data: SocketDict):
socket = cls(
key=data["key"], name=data["name"], multi_conns=data["multi_conns"]
)
socket._compatible = data.get("compatible", [])
return socket
7 changes: 6 additions & 1 deletion scaflow/model/type_hints.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from enum import Enum
from typing import Dict, List, Optional, TYPE_CHECKING, Tuple, TypedDict
from typing import Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypedDict

if TYPE_CHECKING:
from scaflow.model import Connection, Control, Input, Node, Output
Expand Down Expand Up @@ -29,6 +29,11 @@ class SocketDict(TypedDict):

class InputDict(SocketDict):
control: Optional["Control"]
accepted_types: List[str]


class OutputDict(SocketDict):
return_type: str


class NodeDict(TypedDict):
Expand Down
8 changes: 4 additions & 4 deletions tests/test_graph_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

class TestGraphConnection:
def test_serialization(self):
o = Output("test", "Output")
i = Input("test", "Input")
o = Output("test", "Output", return_type=str)
i = Input("test", "Input", accepted_types=[str])
c = Connection(
output_socket_key=o.key, input_socket_key=i.key, input_node=0, output_node=1
)
Expand All @@ -35,8 +35,8 @@ def test_deserialization(self):
assert n.output_node == 1

def test_repr(self):
o = Output("test", "Output")
i = Input("test", "Input")
o = Output("test", "Output", return_type=str)
i = Input("test", "Input", accepted_types=[str])
c = Connection(
output_socket_key=o.key, input_socket_key=i.key, input_node=0, output_node=1
)
Expand Down
Loading

0 comments on commit c8cc29f

Please sign in to comment.