Skip to content

Commit

Permalink
Feature check types (#124)
Browse files Browse the repository at this point in the history
* Add: Call _check_types in init of NIRGraph, NIRGraph.from_list, and NIRGraph.from_dict
* Added pytest to nix flake
Throws `AttributeError` on call `NIRNode.__init__`.
* Added a vscode devcontainer
---------

Co-authored-by: Ben Kroehs <[email protected]>
Co-authored-by: Ben Kroehs <[email protected]>
  • Loading branch information
3 people authored Feb 13, 2025
1 parent c9af31f commit c40ee70
Show file tree
Hide file tree
Showing 8 changed files with 183 additions and 101 deletions.
4 changes: 4 additions & 0 deletions .devcontainer/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Base image: the devcontainers Python image, tag 1-3.12-bullseye.
FROM mcr.microsoft.com/devcontainers/python:1-3.12-bullseye

# Install numpy, the black and ruff tools, the nir package and pytest.
RUN pip install numpy black ruff nir pytest
# Install torch from the PyTorch CPU wheel index.
RUN pip3 install torch --index-url https://download.pytorch.org/whl/cpu
30 changes: 30 additions & 0 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/python
{
"name": "Python 3",
// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
"build": {
"dockerfile": "Dockerfile"
},
"customizations": {
"vscode": {
"extensions": [
"ms-python.black-formatter",
"ms-azuretools.vscode-docker",
"ms-toolsai.jupyter"
]
}
}

// Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [],

// Use 'postCreateCommand' to run commands after the container is created.
// "postCreateCommand": "pip3 install -e .",

// Configure tool-specific properties.
// "customizations": {}

// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "root"
}
1 change: 1 addition & 0 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
pythonPackages.numpy
pythonPackages.h5py
pythonPackages.black
pythonPackages.pytest
pkgs.ruff
pkgs.autoPatchelfHook
];
Expand Down
109 changes: 59 additions & 50 deletions nir/ir/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ class NIRGraph(NIRNode):
edges.
A graph of computational nodes and identity edges.
Arguments:
nodes: Dictionary of nodes in the graph.
edges: List of edges in the graph.
metadata: Dictionary of metadata for the graph.
type_check: Whether to check that input and output types match for all nodes in the graph.
Will not be stored in the graph as an attribute. Defaults to True.
"""

nodes: Nodes # List of computational nodes
Expand All @@ -31,6 +38,28 @@ class NIRGraph(NIRNode):
output_type: Optional[Dict[str, np.ndarray]] = None
metadata: Dict[str, Any] = field(default_factory=dict)

def __init__(
    self,
    nodes: Nodes,
    edges: Edges,
    input_type: Optional[Dict[str, np.ndarray]] = None,
    output_type: Optional[Dict[str, np.ndarray]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    type_check: bool = True,
):
    """Create a graph and, unless disabled, validate the types along its edges.

    Arguments:
        nodes: Dictionary of nodes in the graph.
        edges: List of edges in the graph.
        input_type: Stored on the instance, then recomputed by
            ``__post_init__``.
        output_type: Stored on the instance, then recomputed by
            ``__post_init__``.
        metadata: Dictionary of metadata for the graph. Any value that is
            not a dictionary (including the default ``None``) is replaced
            by a fresh empty dictionary.
        type_check: Whether to call ``_check_types`` on the graph. It is not
            stored as an attribute. Defaults to True.
    """
    self.nodes = nodes
    self.edges = edges
    # Normalise right away so every instance owns a real dictionary; the
    # previous default was the ``dict`` class itself, which only worked
    # because __post_init__ replaced it afterwards.
    self.metadata = metadata if isinstance(metadata, dict) else {}
    self.input_type = input_type
    self.output_type = output_type

    # Check that all nodes have input and output types, if requested (default)
    if type_check:
        self._check_types()

    # Call post init to set input_type and output_type
    self.__post_init__()

@property
def inputs(self):
return {
Expand All @@ -44,7 +73,7 @@ def outputs(self):
}

@staticmethod
def from_list(*nodes: NIRNode) -> "NIRGraph":
def from_list(*nodes: NIRNode, type_check: bool = True) -> "NIRGraph":
"""Create a sequential graph from a list of nodes by labelling them after
indices."""

Expand Down Expand Up @@ -81,80 +110,58 @@ def unique_node_name(node, counts):
return NIRGraph(
nodes=node_dict,
edges=edges,
type_check=type_check,
)

def __post_init__(self):
    """Derive the graph-level input and output types from its nodes.

    ``input_type`` maps the key of every ``Input`` node to that node's
    ``input_type["input"]`` entry, or is ``None`` when the graph has no
    ``Input`` node. ``output_type`` maps the key of every ``Output`` node to
    that node's ``output_type["output"]`` entry and is an empty dictionary
    when there is no ``Output`` node.
    """
    input_types = {
        node_key: node.input_type["input"]
        for node_key, node in self.nodes.items()
        if isinstance(node, Input)
    }
    # An empty mapping is reported as None for inputs (but not for outputs).
    self.input_type = input_types if input_types else None

    self.output_type = {
        node_key: node.output_type["output"]
        for node_key, node in self.nodes.items()
        if isinstance(node, Output)
    }

    # Assign the metadata attribute if left unset to avoid issues with serialization
    if not isinstance(self.metadata, dict):
        self.metadata = {}

def to_dict(self) -> Dict[str, Any]:
    """Serialize the graph, converting every contained node with its own ``to_dict``."""
    serialized = super().to_dict()
    # Replace the node objects with their dictionary form, keyed as in self.nodes.
    serialized["nodes"] = {
        name: member.to_dict() for name, member in self.nodes.items()
    }
    return serialized

@classmethod
def from_dict(cls, kwargs: Dict[str, Any]) -> "NIRGraph":
    """Build a graph from its dictionary representation.

    The incoming dictionary is not modified; all conversions happen on a
    shallow copy.

    Arguments:
        kwargs: Dictionary with a ``"nodes"`` entry (key -> node dictionary)
            and an ``"edges"`` entry (iterable of pairs). An optional
            ``"type"`` entry must equal ``"NIRGraph"``.

    Raises:
        AssertionError: If ``"nodes"`` or ``"edges"`` is missing, or if
            ``"type"`` is present with a value other than ``"NIRGraph"``.
    """
    from . import dict2NIRNode

    # Assert that we have nodes and edges
    assert "nodes" in kwargs, "The incoming dictionary must have a 'nodes' entry"
    assert "edges" in kwargs, "The incoming dictionary must have an 'edges' entry"
    # Assert that the type is well-formed. The message is parenthesised so
    # that all three parts belong to the assert statement and the offending
    # value (not the ``type`` builtin) is reported.
    if "type" in kwargs:
        assert kwargs["type"] == "NIRGraph", (
            "You are calling NIRGraph.from_dict with a different type "
            f"{kwargs['type']}. Either remove the entry or use "
            "<Specific NIRNode>.from_dict, such as Input.from_dict"
        )

    kwargs_local = kwargs.copy()  # Copy the input to avoid overwriting attributes
    kwargs_local["type"] = "NIRGraph"
    kwargs_local["nodes"] = {
        k: dict2NIRNode(n) for k, n in kwargs_local["nodes"].items()
    }
    # h5py deserializes edges into a numpy array of type bytes and dtype=object,
    # hence using ensure_str here
    kwargs_local["edges"] = [
        (ensure_str(a), ensure_str(b)) for a, b in kwargs_local["edges"]
    ]
    return super().from_dict(kwargs_local)

def _check_types(self):
    """Check that all nodes in the graph have input and output types.

    Every edge is inspected: the source node's output type and the target
    node's input type must both be defined, have the same number of entries,
    and (for the single-entry case) compare equal with ``np.array_equal``.

    Returns:
        True when every edge passes the checks.

    Raises:
        ValueError: If a node on an edge has no input or output type, or if
            the types are inconsistent.
        NotImplementedError: If a source node has more than one output type.
    """
    for edge in self.edges:
        pre_node = self.nodes[edge[0]]
        post_node = self.nodes[edge[1]]

        # make sure all types are defined
        undef_out_type = pre_node.output_type is None or any(
            v is None for v in pre_node.output_type.values()
        )
        if undef_out_type:
            raise ValueError(f"pre node {edge[0]} has no output type")
        undef_in_type = post_node.input_type is None or any(
            v is None for v in post_node.input_type.values()
        )
        if undef_in_type:
            raise ValueError(f"post node {edge[1]} has no input type")

        # make sure the length of types is equal
        if len(pre_node.output_type) != len(post_node.input_type):
            pre_repr = f"len({edge[0]}.output)={len(pre_node.output_type)}"
            post_repr = f"len({edge[1]}.input)={len(post_node.input_type)}"
            raise ValueError(f"type length mismatch: {pre_repr} -> {post_repr}")

        # make sure the type values match up
        if len(pre_node.output_type.keys()) == 1:
            post_input_type = list(post_node.input_type.values())[0]
            pre_output_type = list(pre_node.output_type.values())[0]
            if not np.array_equal(post_input_type, pre_output_type):
                pre_repr = f"{edge[0]}.output: {pre_output_type}"
                post_repr = f"{edge[1]}.input: {post_input_type}"
                raise ValueError(f"type mismatch: {pre_repr} -> {post_repr}")
        else:
            raise NotImplementedError(
                "multiple input/output types not supported yet"
            )
    return True

def _forward_type_inference(self, debug=True):
"""Infer the types of all nodes in this graph. Will modify the input_type and
Expand Down Expand Up @@ -497,12 +504,14 @@ def from_dict(cls, node: Dict[str, Any]) -> "NIRNode":
del node["shape"]
return super().from_dict(node)


@dataclass(eq=False)
class Identity(NIRNode):
"""Identity Node.
This is a virtual node, which allows for the identity operation.
"""

input_type: Types

def __post_init__(self):
Expand All @@ -515,4 +524,4 @@ def to_dict(self) -> Dict[str, Any]:

@classmethod
def from_dict(cls, node: Dict[str, Any]) -> "NIRNode":
    """Build the node from its dictionary form by delegating to the parent class."""
    return super().from_dict(node)
8 changes: 4 additions & 4 deletions nir/ir/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def to_dict(self) -> Dict[str, Any]:
return ret

@classmethod
def from_dict(cls, kwargs: Dict[str, Any]) -> "NIRNode":
    """Instantiate ``cls`` from a dictionary of constructor arguments.

    Arguments:
        kwargs: Dictionary whose ``"type"`` entry equals the class name; all
            remaining entries are passed to the constructor as keyword
            arguments. The dictionary itself is left untouched.

    Raises:
        KeyError: If ``kwargs`` has no ``"type"`` entry.
        AssertionError: If ``kwargs["type"]`` differs from ``cls.__name__``.
    """
    assert kwargs["type"] == cls.__name__
    # Drop the type tag on a copy so the caller's dictionary stays reusable.
    init_kwargs = {key: value for key, value in kwargs.items() if key != "type"}

    return cls(**init_kwargs)
Loading

0 comments on commit c40ee70

Please sign in to comment.