Skip to content

Commit cf96547

Browse files
authored
Add nnx flax update (#114)
1 parent fe52c50 commit cf96547

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

src/metrax/nnx/nnx_wrapper.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,6 @@ def update(self, **kwargs) -> None:
3232

3333
def compute(self):
3434
return self.clu_metric.compute()
35+
36+
def __init_subclass__(cls, **kwargs):
37+
super().__init_subclass__(pytree=False, **kwargs)

0 commit comments

Comments
 (0)