Skip to content

Commit

Permalink
Specify spatial shapes in convolutions (#59)
Browse files Browse the repository at this point in the history
* Bumped version in zenodo ref
* Added shape specification for convolutional layers
* make types optional + 'same' & 'valid' padding
* add type inference to NIRGraph (limited support)
* add tests for type inference
* linting

---------

Co-authored-by: Steve Abreu <[email protected]>
  • Loading branch information
Jegp and stevenabreu7 authored Oct 11, 2023
1 parent 8c63f77 commit 274748d
Show file tree
Hide file tree
Showing 8 changed files with 482 additions and 31 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ If you use NIR in your work, please cite the [following Zenodo reference](https:
month = jul,
year = 2023,
publisher = {Zenodo},
version = {0.0.1},
version = {0.2},
doi = {10.5281/zenodo.8105042},
url = {https://doi.org/10.5281/zenodo.8105042}
}
Expand Down
2 changes: 1 addition & 1 deletion docs/source/about.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ If you use NIR in your work, please cite the [following Zenodo reference](https:
month = jul,
year = 2023,
publisher = {Zenodo},
version = {0.0.1},
version = {0.2},
doi = {10.5281/zenodo.8105042},
url = {https://doi.org/10.5281/zenodo.8105042}
}
Expand Down
327 changes: 299 additions & 28 deletions nir/ir.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions nir/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def read_node(node: typing.Any) -> nir.NIRNode:
return nir.Affine(weight=node["weight"][()], bias=node["bias"][()])
elif node["type"][()] == b"Conv1d":
return nir.Conv1d(
input_shape=node["input_shape"][()],
weight=node["weight"][()],
stride=node["stride"][()],
padding=node["padding"][()],
Expand All @@ -21,6 +22,7 @@ def read_node(node: typing.Any) -> nir.NIRNode:
)
elif node["type"][()] == b"Conv2d":
return nir.Conv2d(
input_shape=node["input_shape"][()],
weight=node["weight"][()],
stride=node["stride"][()],
padding=node["padding"][()],
Expand Down
2 changes: 2 additions & 0 deletions nir/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def _convert_node(node: nir.NIRNode) -> dict:
elif isinstance(node, nir.Conv1d):
return {
"type": "Conv1d",
"input_shape": node.input_shape,
"weight": node.weight,
"stride": node.stride,
"padding": node.padding,
Expand All @@ -27,6 +28,7 @@ def _convert_node(node: nir.NIRNode) -> dict:
elif isinstance(node, nir.Conv2d):
return {
"type": "Conv2d",
"input_shape": node.input_shape,
"weight": node.weight,
"stride": node.stride,
"padding": node.padding,
Expand Down
23 changes: 23 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,29 @@ def mock_linear(*shape):
return nir.Linear(weight=np.random.randn(*shape).T)


def mock_conv(input_shape, weights):
if len(weights) == 4:
return nir.Conv2d(
input_shape=input_shape,
weight=np.random.randn(*weights),
bias=np.random.randn(weights[0]),
stride=1,
padding=0,
dilation=1,
groups=1,
)
else:
return nir.Conv1d(
input_shape=input_shape,
weight=np.random.randn(*weights),
bias=np.random.randn(weights[0]),
stride=1,
padding=0,
dilation=1,
groups=1,
)


def mock_affine(*shape):
return nir.Affine(weight=np.random.randn(*shape).T, bias=np.random.randn(shape[1]))

Expand Down
137 changes: 137 additions & 0 deletions tests/test_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,38 @@ def test_delay():
assert ir.edges == [("in", "d"), ("d", "out")]


def test_conv1d():
w = np.random.randn(2, 1, 3)
a = nir.Conv1d(
input_shape=100,
weight=w,
stride=2,
dilation=1,
groups=1,
padding=1,
bias=np.ndarray([1]),
)
assert np.allclose(a.weight, w)
assert np.allclose(a.input_shape, 100)
assert np.allclose(a.output_type["output"], np.array([2, 50]))


def test_conv2d():
w = np.random.randn(3, 1, 3, 3)
a = nir.Conv2d(
input_shape=(100, 100),
weight=w,
padding=(1, 1),
stride=(1, 2),
dilation=(1, 1),
groups=(1, 1),
bias=np.ndarray([1]),
)
assert np.allclose(a.weight, w)
assert np.allclose(a.input_shape, np.array([100, 100]))
assert np.allclose(a.output_type["output"], np.array([3, 100, 50]))


def test_cuba_lif():
a = np.random.randn(10, 10)
lif = nir.CubaLIF(tau_mem=a, tau_syn=a, r=a, v_leak=a, v_threshold=a)
Expand Down Expand Up @@ -295,3 +327,108 @@ def test_inputs_outputs_properties():
assert ir.nodes["in2"] in ir2.nodes["inner"].inputs.values()
assert ir.nodes["out1"] in ir2.nodes["inner"].outputs.values()
assert ir.nodes["out2"] in ir2.nodes["inner"].outputs.values()


def test_conv_type_inference():
graphs = {
'undef graph output': nir.NIRGraph(nodes={
'input': nir.Input(input_type=np.array([1, 64, 64])),
'conv': nir.Conv2d(
input_shape=(64, 64),
weight=np.zeros((1, 1, 4, 4)),
stride=1,
padding=0,
dilation=1,
groups=1,
bias=None
),
'output': nir.Output(output_type=None)
}, edges=[('input', 'conv'), ('conv', 'output')]),

'incorrect graph output': nir.NIRGraph(nodes={
'input': nir.Input(input_type=np.array([1, 64, 64])),
'conv': nir.Conv2d(
input_shape=(64, 64),
weight=np.zeros((1, 1, 4, 4)),
stride=1,
padding=0,
dilation=1,
groups=1,
bias=None
),
'output': nir.Output(output_type=np.array([1, 61, 1]))
}, edges=[('input', 'conv'), ('conv', 'output')]),

'undef conv.input': nir.NIRGraph(nodes={
'input': nir.Input(input_type=np.array([1, 64, 64])),
'conv': nir.Conv2d(
input_shape=None,
weight=np.zeros((1, 1, 4, 4)),
stride=1,
padding=0,
dilation=1,
groups=1,
bias=None
),
'output': nir.Output(output_type=np.array([1, 61, 61]))
}, edges=[('input', 'conv'), ('conv', 'output')]),

'undef conv.input and graph output': nir.NIRGraph(nodes={
'input': nir.Input(input_type=np.array([1, 64, 64])),
'conv': nir.Conv2d(
input_shape=None,
weight=np.zeros((1, 1, 4, 4)),
stride=1,
padding=0,
dilation=1,
groups=1,
bias=None
),
'output': nir.Output(output_type=None)
}, edges=[('input', 'conv'), ('conv', 'output')]),

'Conv1d undef graph output': nir.NIRGraph(nodes={
'input': nir.Input(input_type=np.array([1, 64])),
'conv': nir.Conv1d(
input_shape=64,
weight=np.zeros((1, 1, 4)),
stride=1,
padding=0,
dilation=1,
groups=1,
bias=None
),
'output': nir.Output(output_type=None)
}, edges=[('input', 'conv'), ('conv', 'output')]),

'Conv1d incorrect graph output': nir.NIRGraph(nodes={
'input': nir.Input(input_type=np.array([1, 64])),
'conv': nir.Conv1d(
input_shape=64,
weight=np.zeros((1, 1, 4)),
stride=1,
padding=0,
dilation=1,
groups=1,
bias=None
),
'output': nir.Output(output_type=np.array([1, 3]))
}, edges=[('input', 'conv'), ('conv', 'output')]),

'Conv1d undef conv.input and graph output': nir.NIRGraph(nodes={
'input': nir.Input(input_type=np.array([1, 64])),
'conv': nir.Conv1d(
input_shape=None,
weight=np.zeros((1, 1, 4)),
stride=1,
padding=0,
dilation=1,
groups=1,
bias=None
),
'output': nir.Output(output_type=None)
}, edges=[('input', 'conv'), ('conv', 'output')]),
}
for name, graph in graphs.items():
graph.infer_types()
assert graph._check_types(), f'type inference failed for: {name}'
18 changes: 17 additions & 1 deletion tests/test_readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

import nir
from tests import mock_affine
from tests import mock_affine, mock_conv


def assert_equivalence(ir: nir.NIRGraph, ir2: nir.NIRGraph):
Expand Down Expand Up @@ -63,6 +63,22 @@ def test_nested():
factory_test_graph(nested)


def test_conv1d():
ir = nir.NIRGraph.from_list(
mock_affine(2, 100),
mock_conv(100, (1, 2, 3)),
mock_affine(100, 2),
)
factory_test_graph(ir)


def test_conv1d_2():
ir = nir.NIRGraph.from_list(
mock_conv((100, 100), (1, 2, 3, 3)),
)
factory_test_graph(ir)


def test_integrator():
r = np.array([1, 1, 1])
ir = nir.NIRGraph(
Expand Down

0 comments on commit 274748d

Please sign in to comment.