Skip to content

Commit 4d1cfbb

Browse files
author
Flax Authors
committed
Merge pull request #5032 from google:fix-py314-kw-only-dataclasses
PiperOrigin-RevId: 822149189
2 parents bb3c31a + 6622efc commit 4d1cfbb

File tree

4 files changed

+35
-7
lines changed

4 files changed

+35
-7
lines changed

.github/workflows/flax_test.yml

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ jobs:
6363
runs-on: ubuntu-latest
6464
strategy:
6565
matrix:
66-
python-version: ['3.11', '3.12', '3.13']
66+
python-version: ['3.11', '3.12', '3.13', '3.14']
6767
steps:
6868
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
6969
- name: Set up Python ${{ matrix.python-version }}
@@ -159,3 +159,30 @@ jobs:
159159
"description": "'$status'",
160160
"context": "github-actions/Build"
161161
}'
162+
163+
# This is a temporary workflow to test flax on Python 3.14 and
164+
# skipping deps like tensorstore, tensorflow etc
165+
tests-python314:
166+
name: Run Tests on Python 3.14
167+
needs: [pre-commit, commit-count]
168+
runs-on: ubuntu-24.04-16core
169+
steps:
170+
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
171+
- name: Setup uv
172+
uses: astral-sh/setup-uv@887a942a15af3a7626099df99e897a18d9e5ab3a # v5.1.0
173+
with:
174+
version: "0.9.2"
175+
python-version: "3.14"
176+
activate-environment: true
177+
enable-cache: true
178+
179+
- name: Install dependencies
180+
run: |
181+
uv pip install jax msgpack treescope rich typing_extensions PyYAML optax cloudpickle
182+
uv pip install pytest pytest-custom_exit_code pytest-xdist pytest-cov
183+
uv pip install -e . --no-deps
184+
- name: Test with pytest
185+
run: |
186+
export XLA_FLAGS='--xla_force_host_platform_device_count=4'
187+
find tests/ -name "*.py" | grep -vE 'checkpoint|integr|io|tensorboard' | xargs pytest -n auto
188+

flax/linen/kw_only_dataclasses.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def wrap(cls):
127127

128128
def _process_class(cls: type[M], extra_fields=None, **kwargs):
129129
"""Transforms `cls` into a dataclass that supports kw_only fields."""
130-
if '__annotations__' not in cls.__dict__:
130+
if not hasattr(cls, "__annotations__"):
131131
cls.__annotations__ = {}
132132

133133
# The original __dataclass_fields__ dicts for all base classes. We will
@@ -174,7 +174,7 @@ def _process_class(cls: type[M], extra_fields=None, **kwargs):
174174
for base in reversed(cls.__mro__[1:]):
175175
if not dataclasses.is_dataclass(base):
176176
continue
177-
base_annotations = base.__dict__.get('__annotations__', {})
177+
base_annotations = inspect.get_annotations(base)
178178
base_dataclass_fields[base] = dict(
179179
getattr(base, '__dataclass_fields__', {})
180180
)
@@ -188,7 +188,8 @@ def _process_class(cls: type[M], extra_fields=None, **kwargs):
188188
del base.__dataclass_fields__[field_name]
189189

190190
# Remove any keyword-only fields from this class.
191-
cls_annotations = cls.__dict__['__annotations__']
191+
# Note: in Python 3.14+ cls.__dict__ does not contain __annotation__ key but __annotations_cache__
192+
cls_annotations = getattr(cls, "__annotations__")
192193
for name, annotation in list(cls_annotations.items()):
193194
value = getattr(cls, name, None)
194195
if (

flax/linen/module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1061,7 +1061,7 @@ def _customized_dataclass_transform(cls, kw_only: bool):
10611061
3. Generate a hash function (if not provided by cls).
10621062
"""
10631063
# Check reserved attributes have expected type annotations.
1064-
annotations = dict(cls.__dict__.get('__annotations__', {}))
1064+
annotations = inspect.get_annotations(cls)
10651065
if annotations.get('parent', _ParentType) != _ParentType:
10661066
raise errors.ReservedModuleAttributeError(annotations)
10671067
if annotations.get('name', str) not in ('str', str, Optional[str]):

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ dependencies = [
1717
"jax>=0.7.1",
1818
"msgpack",
1919
"optax",
20-
"orbax-checkpoint",
21-
"tensorstore",
20+
"orbax-checkpoint; python_version<'3.14'", # temporary skip orbax-checkpoint for py3.14
21+
"tensorstore; python_version<'3.14'", # temporary skip tensorstore for py3.14
2222
"rich>=11.1",
2323
"typing_extensions>=4.2",
2424
"PyYAML>=5.4.1",

0 commit comments

Comments
 (0)