Skip to content

Commit f536e9e

Browse files
committed
Parameter, check unique dim tags
#48
1 parent 77ee103 commit f536e9e

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

nn/base.py

+2
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,8 @@ def __init__(self, shape: Sequence[Dim], dtype: str = "float32"):
316316
raise TypeError(f"shape {shape} must be a sequence of Dim")
317317
if not all(isinstance(dim.dimension, int) for dim in shape):
318318
raise ValueError(f"shape {shape} must be static")
319+
if sorted(shape) != sorted(set(shape)):
320+
raise ValueError(f"shape {shape} dims must be unique")
319321
# Note: At creation time, we don't know the name yet.
320322
# The name will be inferred by the parent modules and the attribute chain.
321323
name_ctx = NameCtx(name="parameter", parent=None) # this is incomplete and will be configured later

0 commit comments

Comments
 (0)