diff --git a/README.md b/README.md index c959331..6c3ac12 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ If you use NIR in your work, please cite the [following Zenodo reference](https: month = jul, year = 2023, publisher = {Zenodo}, - version = {0.0.1}, + version = {0.2}, doi = {10.5281/zenodo.8105042}, url = {https://doi.org/10.5281/zenodo.8105042} } diff --git a/docs/source/about.md b/docs/source/about.md index f504c3c..e4678d7 100644 --- a/docs/source/about.md +++ b/docs/source/about.md @@ -37,7 +37,7 @@ If you use NIR in your work, please cite the [following Zenodo reference](https: month = jul, year = 2023, publisher = {Zenodo}, - version = {0.0.1}, + version = {0.2}, doi = {10.5281/zenodo.8105042}, url = {https://doi.org/10.5281/zenodo.8105042} } diff --git a/nir/ir.py b/nir/ir.py index db1b544..30deece 100644 --- a/nir/ir.py +++ b/nir/ir.py @@ -14,14 +14,76 @@ def _parse_shape_argument(x: Types, key: str): + """Parse the shape argument of a NIR node.""" if isinstance(x, np.ndarray): return {key: x} elif isinstance(x, Sequence): return {key: np.array(x)} elif isinstance(x, dict): return x + elif x is None: + return {key: None} + + +def _calculate_conv_output( + input_shape: typing.Union[int, typing.Sequence[int]], + padding: typing.Union[int, str, typing.Sequence[int]], + dilation: typing.Union[int, typing.Sequence[int]], + kernel_size: typing.Union[int, typing.Sequence[int]], + stride: typing.Union[int, typing.Sequence[int]], +) -> typing.Sequence[int]: + """Calculates the output for a single dimension of a convolutional layer. 
+ https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d + + :param input_shape: input shape, either int or (int, int) + :type input_shape: int | typing.Sequence[int] + :param padding: padding, if string must be 'same' or 'valid' + :type padding: int | str | typing.Sequence[int] + :param dilation: dilation + :type dilation: int | typing.Sequence[int] + :param kernel_size: kernel size + :type kernel_size: int | typing.Sequence[int] + :param stride: stride + :type stride: int | typing.Sequence[int] + + :return: output shape, one entry per spatial dimension + :rtype: np.ndarray + """ + if isinstance(input_shape, (int, np.integer)): + ndim = 1 + else: + ndim = len(input_shape) + if isinstance(padding, str) and padding == 'valid': + padding = [0] * ndim + shapes = [] + for i in range(ndim): + if isinstance(padding, str) and padding == 'same': + shape = input_shape[i] + else: + shape = np.floor( + ( + _index_tuple(input_shape, i) + + 2 * _index_tuple(padding, i) + - _index_tuple(dilation, i) * (_index_tuple(kernel_size, i) - 1) + - 1 + ) + / _index_tuple(stride, i) + + 1 + ) + shapes.append(int(shape)) + return np.array(shapes) + + +def _index_tuple( + tuple: typing.Union[int, typing.Sequence[int]], index: int +) -> typing.Union[int, np.ndarray]: + """If the input is a tuple/array, index it. If it is an integer, wrap it in a one-element array.""" + if isinstance(tuple, np.ndarray) or isinstance(tuple, Sequence): + return tuple[index] + elif isinstance(tuple, (int, np.integer)): + return np.array([tuple]) else: - raise ValueError("Unknown shape argument", x) + raise TypeError(f"tuple must be int or np.ndarray, not {type(tuple)}") @dataclass(eq=False) @@ -114,6 +176,136 @@ def __post_init__(self): self.output_type = { node_key: self.nodes[node_key].output_type for node_key in output_node_keys } + + def _check_types(self): + """Check that all nodes in the graph have input and output types. 
Will raise ValueError + if any node has no input or output type, or if the types are inconsistent.""" + for edge in self.edges: + pre_node = self.nodes[edge[0]] + post_node = self.nodes[edge[1]] + + # make sure all types are defined + undef_out_type = pre_node.output_type is None or any( + v is None for v in pre_node.output_type.values() + ) + if undef_out_type: + raise ValueError(f'pre node {edge[0]} has no output type') + undef_in_type = post_node.input_type is None or any( + v is None for v in post_node.input_type.values() + ) + if undef_in_type: + raise ValueError(f'post node {edge[1]} has no input type') + + # make sure the length of types is equal + if len(pre_node.output_type) != len(post_node.input_type): + pre_repr = f'len({edge[0]}.output)={len(pre_node.output_type)}' + post_repr = f'len({edge[1]}.input)={len(post_node.input_type)}' + raise ValueError(f'type length mismatch: {pre_repr} -> {post_repr}') + + # make sure the type values match up + if len(pre_node.output_type.keys()) == 1: + post_input_type = list(post_node.input_type.values())[0] + pre_output_type = list(pre_node.output_type.values())[0] + if not np.array_equal(post_input_type, pre_output_type): + pre_repr = f'{edge[0]}.output: {pre_output_type}' + post_repr = f'{edge[1]}.input: {post_input_type}' + raise ValueError(f'type mismatch: {pre_repr} -> {post_repr}') + else: + raise NotImplementedError('multiple input/output types not supported yet') + return True + + def _forward_type_inference(self, debug=True): + """Infer the types of all nodes in this graph. Will modify the input_type and output_type + of nodes in the graph as needed. Assumes that the input_type of the graph is set. Moves + from the input nodes to the output nodes. Raises ValueError if types are inconsistent. + + Assumes that all input types are of form: {'input': ...} and all output types are of form: + {'output': ...}. + + Currently only supports the inference of output types for Conv1d and Conv2d nodes. 
+ Does not support nested NIR graphs. + """ + ready = [e for e in self.edges if e[0] in self.inputs.keys()] + seen = set([e[0] for e in ready]) + while len(ready) > 0: + pre_key, post_key = ready.pop() + pre_node = self.nodes[pre_key] + post_node = self.nodes[post_key] + + if isinstance(pre_node, NIRGraph) or isinstance(post_node, NIRGraph): + raise NotImplementedError('type inference on nested NIR graphs not supported yet') + + # check if post input_type needs to be defined + undef_post_input_type = post_node.input_type is None or any( + v is None for v in post_node.input_type.values() + ) + type_mismatch = any([ + len(post_node.input_type) != len(pre_node.output_type), + not np.array_equal( + np.array(list(pre_node.output_type.values())), + np.array(list(post_node.input_type.values())) + ) + ]) + if undef_post_input_type: + # define post input_type to be the same as pre output_type + print(f'[warning] {post_key}.input_type undefined, set to {pre_key}.output_type') + post_node.input_type = { + k.replace('output', 'input'): v for k, v in pre_node.output_type.items() + } + elif type_mismatch: + # set post input_type to be the same as pre output_type + pre_repr = f'{pre_key}.output: {np.array(list(pre_node.output_type.values()))}' + post_repr = f'{post_key}.input: {np.array(list(post_node.input_type.values()))}' + print(f'[warning] overwriting {post_repr} with {pre_repr}') + post_node.input_type = { + k.replace('output', 'input'): v for k, v in pre_node.output_type.items() + } + + # check if post output_type needs to be defined + undef_post_output_type = post_node.output_type is None or any( + v is None for v in post_node.output_type.values() + ) + if undef_post_output_type: + # define post output_type + if isinstance(post_node, Conv1d) or isinstance(post_node, Conv2d): + if isinstance(post_node, Conv1d): + post_node.input_shape = post_node.input_type['input'][1] + else: + post_node.input_shape = tuple(post_node.input_type['input'][1:]) + output_shape = 
_calculate_conv_output( + post_node.input_shape, + post_node.padding, + post_node.dilation, + post_node.weight.shape[2], + post_node.stride, + ) + output_type = np.array([post_node.weight.shape[0], *output_shape]) + post_node.output_type = {"output": output_type} + + seen.add(post_key) + ready += [e for e in self.edges if e[0] == post_key and e[1] not in seen] + + def infer_types(self): + """Infer the shapes of all nodes in this graph. Will modify the input_type and + output_type of all nodes in the graph. + + Assumes that either the input type or the output type of the graph is set. + Assumes that if A->B, then A.output_type.values() = B.input_type.values() + """ + undef_input_type = self.input_type is None or any( + v is None for v in self.input_type.values() + ) + undef_output_type = self.output_type is None or any( + v is None for v in self.output_type.values() + ) + if not undef_input_type: + # forward-mode type inferring + self._forward_type_inference() + elif not undef_output_type: + # backward-mode type inferring + raise NotImplementedError('backward-mode type inference not implemented yet') + else: + raise ValueError("Either input_type or output_type must be set") @dataclass(eq=False) @@ -145,53 +337,119 @@ def __post_init__(self): @dataclass(eq=False) class Conv1d(NIRNode): - """Convolutional layer in 1d.""" + """Convolutional layer in 1d. + + Note that the input_shape argument is required to disambiguate the shape, and is used + to infer the exact output shape along with the other parameters. If the input_shape + is None, the output shape will also be None. + + The NIRGraph.infer_types function may be used to automatically infer the input and + output types on the graph level. 
+ + :param input_shape: Shape of spatial input (N,) + :type input_shape: Optional[int] + :param weight: Weight, shape (C_out, C_in, N) + :type weight: np.ndarray + :param stride: Stride + :type stride: int + :param padding: Padding, if string must be 'same' or 'valid' + :type padding: int | str + :param dilation: Dilation + :type dilation: int + :param groups: Groups + :type groups: int + :param bias: Bias array of shape (C_out,) + :type bias: np.ndarray + """ - weight: np.ndarray # Weight C_out * C_in * X + input_shape: typing.Optional[int] # N + weight: np.ndarray # Weight C_out * C_in * N stride: int # Stride - padding: int # Padding + padding: typing.Union[int, str] # Padding dilation: int # Dilation groups: int # Groups bias: np.ndarray # Bias C_out def __post_init__(self): - self.input_type = {"input": np.array(self.weight.shape)[1:]} - self.output_type = {"output": np.array(self.weight.shape)[[0, 2]]} + if isinstance(self.padding, str) and self.padding not in ["same", "valid"]: + raise ValueError(f"padding must be 'same', 'valid', or int, not {self.padding}") + if self.input_shape is None: + # leave input and output types undefined + self.input_type = {"input": None} + self.output_type = {"output": None} + else: + # infer input and output types from input_shape + self.input_type = {"input": np.array([self.weight.shape[1], self.input_shape])} + output_shape = _calculate_conv_output( + self.input_shape, + self.padding, + self.dilation, + self.weight.shape[2], + self.stride, + ) + self.output_type = {"output": np.array([self.weight.shape[0], *output_shape])} @dataclass(eq=False) class Conv2d(NIRNode): - """Convolutional layer in 2d.""" + """Convolutional layer in 2d. + + Note that the input_shape argument is required to disambiguate the shape, and is used + to infer the exact output shape along with the other parameters. If the input_shape + is None, the output shape will also be None. 
+ + The NIRGraph.infer_types function may be used to automatically infer the input and + output types on the graph level. + + :param input_shape: Shape of spatial input (N_x, N_y) + :type input_shape: Optional[tuple[int, int]] + :param weight: Weight, shape (C_out, C_in, W_x, W_y) + :type weight: np.ndarray + :param stride: Stride + :type stride: int | tuple[int, int] + :param padding: Padding, if string must be 'same' or 'valid' + :type padding: int | tuple[int, int] | str + :param dilation: Dilation + :type dilation: int | tuple[int, int] + :param groups: Groups + :type groups: int + :param bias: Bias array of shape (C_out,) + :type bias: np.ndarray + """ - weight: np.ndarray # Weight C_out * C_in * X * Y - stride: int # Stride - padding: int # Padding - dilation: int # Dilation + # Shape of input tensor (input_type is derived from it when set) + input_shape: typing.Optional[typing.Tuple[int, int]] # N_x, N_y + weight: np.ndarray # Weight C_out * C_in * W_x * W_y + stride: typing.Union[int, typing.Tuple[int, int]] # Stride + padding: typing.Union[int, typing.Tuple[int, int], str] # Padding + dilation: typing.Union[int, typing.Tuple[int, int]] # Dilation groups: int # Groups bias: np.ndarray # Bias C_out def __post_init__(self): - if isinstance(self.stride, int): - self.stride = (self.stride, self.stride) + if isinstance(self.padding, str) and self.padding not in ["same", "valid"]: + raise ValueError(f"padding must be 'same', 'valid', or int, not {self.padding}") if isinstance(self.padding, int): self.padding = (self.padding, self.padding) + if isinstance(self.stride, int): + self.stride = (self.stride, self.stride) if isinstance(self.dilation, int): self.dilation = (self.dilation, self.dilation) - self.input_type = {"input": np.array(self.weight.shape)[1:]} - self.output_type = {"output": np.array(self.weight.shape)[[0, 2, 3]]} - - -@dataclass(eq=False) -class SumPool2d(NIRNode): - """Sum pooling layer in 2d.""" - - kernel_size: np.ndarray # (Height, Width) - stride: np.ndarray # (Height, 
width) - padding: np.ndarray # (Height, width) - - def __post_init__(self): - self.input_type = {"input": ()} - self.output_type = {"output": ()} + if self.input_shape is None: + # leave input and output types undefined + self.input_type = {"input": None} + self.output_type = {"output": None} + else: + # infer input and output types from input_shape + self.input_type = {"input": np.array([self.weight.shape[1], *self.input_shape])} + output_shape = _calculate_conv_output( + self.input_shape, + self.padding, + self.dilation, + self.weight.shape[2], + self.stride, + ) + self.output_type = {"output": np.array([self.weight.shape[0], *output_shape])} @dataclass(eq=False) @@ -486,6 +744,19 @@ def __post_init__(self): self.output_type = {"output": np.array(self.scale.shape)} +@dataclass(eq=False) +class SumPool2d(NIRNode): + """Sum pooling layer in 2d.""" + + kernel_size: np.ndarray # (Height, Width) + stride: np.ndarray # (Height, width) + padding: np.ndarray # (Height, width) + + def __post_init__(self): + self.input_type = {"input": ()} + self.output_type = {"output": ()} + + @dataclass(eq=False) class Threshold(NIRNode): r"""Threshold node. 
diff --git a/nir/read.py b/nir/read.py index 5e263b5..cc830b8 100644 --- a/nir/read.py +++ b/nir/read.py @@ -12,6 +12,7 @@ def read_node(node: typing.Any) -> nir.NIRNode: return nir.Affine(weight=node["weight"][()], bias=node["bias"][()]) elif node["type"][()] == b"Conv1d": return nir.Conv1d( + input_shape=node["input_shape"][()], weight=node["weight"][()], stride=node["stride"][()], padding=node["padding"][()], @@ -21,6 +22,7 @@ def read_node(node: typing.Any) -> nir.NIRNode: ) elif node["type"][()] == b"Conv2d": return nir.Conv2d( + input_shape=node["input_shape"][()], weight=node["weight"][()], stride=node["stride"][()], padding=node["padding"][()], diff --git a/nir/write.py b/nir/write.py index 5715d3f..57a0661 100644 --- a/nir/write.py +++ b/nir/write.py @@ -17,6 +17,7 @@ def _convert_node(node: nir.NIRNode) -> dict: elif isinstance(node, nir.Conv1d): return { "type": "Conv1d", + "input_shape": node.input_shape, "weight": node.weight, "stride": node.stride, "padding": node.padding, @@ -27,6 +28,7 @@ def _convert_node(node: nir.NIRNode) -> dict: elif isinstance(node, nir.Conv2d): return { "type": "Conv2d", + "input_shape": node.input_shape, "weight": node.weight, "stride": node.stride, "padding": node.padding, diff --git a/tests/__init__.py b/tests/__init__.py index b15cf84..f0695d8 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -7,6 +7,29 @@ def mock_linear(*shape): return nir.Linear(weight=np.random.randn(*shape).T) +def mock_conv(input_shape, weights): + if len(weights) == 4: + return nir.Conv2d( + input_shape=input_shape, + weight=np.random.randn(*weights), + bias=np.random.randn(weights[0]), + stride=1, + padding=0, + dilation=1, + groups=1, + ) + else: + return nir.Conv1d( + input_shape=input_shape, + weight=np.random.randn(*weights), + bias=np.random.randn(weights[0]), + stride=1, + padding=0, + dilation=1, + groups=1, + ) + + def mock_affine(*shape): return nir.Affine(weight=np.random.randn(*shape).T, bias=np.random.randn(shape[1])) diff 
--git a/tests/test_ir.py b/tests/test_ir.py index bc173b5..30cff5f 100644 --- a/tests/test_ir.py +++ b/tests/test_ir.py @@ -98,6 +98,38 @@ def test_delay(): assert ir.edges == [("in", "d"), ("d", "out")] +def test_conv1d(): + w = np.random.randn(2, 1, 3) + a = nir.Conv1d( + input_shape=100, + weight=w, + stride=2, + dilation=1, + groups=1, + padding=1, + bias=np.ndarray([1]), + ) + assert np.allclose(a.weight, w) + assert np.allclose(a.input_shape, 100) + assert np.allclose(a.output_type["output"], np.array([2, 50])) + + +def test_conv2d(): + w = np.random.randn(3, 1, 3, 3) + a = nir.Conv2d( + input_shape=(100, 100), + weight=w, + padding=(1, 1), + stride=(1, 2), + dilation=(1, 1), + groups=(1, 1), + bias=np.ndarray([1]), + ) + assert np.allclose(a.weight, w) + assert np.allclose(a.input_shape, np.array([100, 100])) + assert np.allclose(a.output_type["output"], np.array([3, 100, 50])) + + def test_cuba_lif(): a = np.random.randn(10, 10) lif = nir.CubaLIF(tau_mem=a, tau_syn=a, r=a, v_leak=a, v_threshold=a) @@ -295,3 +327,108 @@ def test_inputs_outputs_properties(): assert ir.nodes["in2"] in ir2.nodes["inner"].inputs.values() assert ir.nodes["out1"] in ir2.nodes["inner"].outputs.values() assert ir.nodes["out2"] in ir2.nodes["inner"].outputs.values() + + +def test_conv_type_inference(): + graphs = { + 'undef graph output': nir.NIRGraph(nodes={ + 'input': nir.Input(input_type=np.array([1, 64, 64])), + 'conv': nir.Conv2d( + input_shape=(64, 64), + weight=np.zeros((1, 1, 4, 4)), + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None + ), + 'output': nir.Output(output_type=None) + }, edges=[('input', 'conv'), ('conv', 'output')]), + + 'incorrect graph output': nir.NIRGraph(nodes={ + 'input': nir.Input(input_type=np.array([1, 64, 64])), + 'conv': nir.Conv2d( + input_shape=(64, 64), + weight=np.zeros((1, 1, 4, 4)), + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None + ), + 'output': nir.Output(output_type=np.array([1, 61, 1])) + }, edges=[('input', 
'conv'), ('conv', 'output')]), + + 'undef conv.input': nir.NIRGraph(nodes={ + 'input': nir.Input(input_type=np.array([1, 64, 64])), + 'conv': nir.Conv2d( + input_shape=None, + weight=np.zeros((1, 1, 4, 4)), + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None + ), + 'output': nir.Output(output_type=np.array([1, 61, 61])) + }, edges=[('input', 'conv'), ('conv', 'output')]), + + 'undef conv.input and graph output': nir.NIRGraph(nodes={ + 'input': nir.Input(input_type=np.array([1, 64, 64])), + 'conv': nir.Conv2d( + input_shape=None, + weight=np.zeros((1, 1, 4, 4)), + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None + ), + 'output': nir.Output(output_type=None) + }, edges=[('input', 'conv'), ('conv', 'output')]), + + 'Conv1d undef graph output': nir.NIRGraph(nodes={ + 'input': nir.Input(input_type=np.array([1, 64])), + 'conv': nir.Conv1d( + input_shape=64, + weight=np.zeros((1, 1, 4)), + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None + ), + 'output': nir.Output(output_type=None) + }, edges=[('input', 'conv'), ('conv', 'output')]), + + 'Conv1d incorrect graph output': nir.NIRGraph(nodes={ + 'input': nir.Input(input_type=np.array([1, 64])), + 'conv': nir.Conv1d( + input_shape=64, + weight=np.zeros((1, 1, 4)), + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None + ), + 'output': nir.Output(output_type=np.array([1, 3])) + }, edges=[('input', 'conv'), ('conv', 'output')]), + + 'Conv1d undef conv.input and graph output': nir.NIRGraph(nodes={ + 'input': nir.Input(input_type=np.array([1, 64])), + 'conv': nir.Conv1d( + input_shape=None, + weight=np.zeros((1, 1, 4)), + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None + ), + 'output': nir.Output(output_type=None) + }, edges=[('input', 'conv'), ('conv', 'output')]), + } + for name, graph in graphs.items(): + graph.infer_types() + assert graph._check_types(), f'type inference failed for: {name}' diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index 
81c961f..5f7d2d4 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -3,7 +3,7 @@ import numpy as np import nir -from tests import mock_affine +from tests import mock_affine, mock_conv def assert_equivalence(ir: nir.NIRGraph, ir2: nir.NIRGraph): @@ -63,6 +63,22 @@ def test_nested(): factory_test_graph(nested) +def test_conv1d(): + ir = nir.NIRGraph.from_list( + mock_affine(2, 100), + mock_conv(100, (1, 2, 3)), + mock_affine(100, 2), + ) + factory_test_graph(ir) + + +def test_conv1d_2(): + ir = nir.NIRGraph.from_list( + mock_conv((100, 100), (1, 2, 3, 3)), + ) + factory_test_graph(ir) + + def test_integrator(): r = np.array([1, 1, 1]) ir = nir.NIRGraph(