Skip to content

Commit 1ac99b0

Browse files
Cristian GarciaFlax Authors
authored andcommitted
backfill eager_sharding in set_metadata
PiperOrigin-RevId: 832407389
1 parent b5db513 commit 1ac99b0

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

flax/nnx/variablelib.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1261,6 +1261,13 @@ def set_metadata(self, *args, **kwargs) -> None:
12611261
f'Cannot change `is_mutable` metadata, expected {self.is_mutable}, '
12621262
f'got {metadata["is_mutable"]}'
12631263
)
1264+
if 'eager_sharding' not in metadata:
1265+
metadata['eager_sharding'] = self.eager_sharding
1266+
if metadata['eager_sharding'] != self.eager_sharding:
1267+
raise ValueError(
1268+
f'Cannot change `eager_sharding` metadata, expected '
1269+
f'{self.eager_sharding}, got {metadata["eager_sharding"]}'
1270+
)
12641271
self._var_metadata = metadata
12651272
elif len(args) == 2:
12661273
name, value = args

0 commit comments

Comments
 (0)