Skip to content

Commit 5a2da6d

Browse files
author
Flax Team
committed
Copybara import of the project:
-- 6622efc by vfdev-5 <[email protected]>: Support for python 3.14 Adapted the way to access `__annotations__` object. Flax linen code relies on `cls.__dict__["__annotations__"]`, but in Python 3.14+ `cls.__dict__` does not contain `__annotation__` key anymore. Fixes #5027 PiperOrigin-RevId: 822202664
1 parent 57c2cc2 commit 5a2da6d

File tree

4 files changed

+7
-35
lines changed

4 files changed

+7
-35
lines changed

.github/workflows/flax_test.yml

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ jobs:
6363
runs-on: ubuntu-latest
6464
strategy:
6565
matrix:
66-
python-version: ['3.11', '3.12', '3.13', '3.14']
66+
python-version: ['3.11', '3.12', '3.13']
6767
steps:
6868
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
6969
- name: Set up Python ${{ matrix.python-version }}
@@ -159,30 +159,3 @@ jobs:
159159
"description": "'$status'",
160160
"context": "github-actions/Build"
161161
}'
162-
163-
# This is a temporary workflow to test flax on Python 3.14 and
164-
# skipping deps like tensorstore, tensorflow etc
165-
tests-python314:
166-
name: Run Tests on Python 3.14
167-
needs: [pre-commit, commit-count]
168-
runs-on: ubuntu-24.04-16core
169-
steps:
170-
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
171-
- name: Setup uv
172-
uses: astral-sh/setup-uv@887a942a15af3a7626099df99e897a18d9e5ab3a # v5.1.0
173-
with:
174-
version: "0.9.2"
175-
python-version: "3.14"
176-
activate-environment: true
177-
enable-cache: true
178-
179-
- name: Install dependencies
180-
run: |
181-
uv pip install jax msgpack treescope rich typing_extensions PyYAML optax cloudpickle
182-
uv pip install pytest pytest-custom_exit_code pytest-xdist pytest-cov
183-
uv pip install -e . --no-deps
184-
- name: Test with pytest
185-
run: |
186-
export XLA_FLAGS='--xla_force_host_platform_device_count=4'
187-
find tests/ -name "*.py" | grep -vE 'checkpoint|integr|io|tensorboard' | xargs pytest -n auto
188-

flax/linen/kw_only_dataclasses.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def wrap(cls):
127127

128128
def _process_class(cls: type[M], extra_fields=None, **kwargs):
129129
"""Transforms `cls` into a dataclass that supports kw_only fields."""
130-
if not hasattr(cls, "__annotations__"):
130+
if '__annotations__' not in cls.__dict__:
131131
cls.__annotations__ = {}
132132

133133
# The original __dataclass_fields__ dicts for all base classes. We will
@@ -174,7 +174,7 @@ def _process_class(cls: type[M], extra_fields=None, **kwargs):
174174
for base in reversed(cls.__mro__[1:]):
175175
if not dataclasses.is_dataclass(base):
176176
continue
177-
base_annotations = inspect.get_annotations(base)
177+
base_annotations = base.__dict__.get('__annotations__', {})
178178
base_dataclass_fields[base] = dict(
179179
getattr(base, '__dataclass_fields__', {})
180180
)
@@ -188,8 +188,7 @@ def _process_class(cls: type[M], extra_fields=None, **kwargs):
188188
del base.__dataclass_fields__[field_name]
189189

190190
# Remove any keyword-only fields from this class.
191-
# Note: in Python 3.14+ cls.__dict__ does not contain __annotation__ key but __annotations_cache__
192-
cls_annotations = getattr(cls, "__annotations__")
191+
cls_annotations = cls.__dict__['__annotations__']
193192
for name, annotation in list(cls_annotations.items()):
194193
value = getattr(cls, name, None)
195194
if (

flax/linen/module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1061,7 +1061,7 @@ def _customized_dataclass_transform(cls, kw_only: bool):
10611061
3. Generate a hash function (if not provided by cls).
10621062
"""
10631063
# Check reserved attributes have expected type annotations.
1064-
annotations = inspect.get_annotations(cls)
1064+
annotations = dict(cls.__dict__.get('__annotations__', {}))
10651065
if annotations.get('parent', _ParentType) != _ParentType:
10661066
raise errors.ReservedModuleAttributeError(annotations)
10671067
if annotations.get('name', str) not in ('str', str, Optional[str]):

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ dependencies = [
1717
"jax>=0.7.1",
1818
"msgpack",
1919
"optax",
20-
"orbax-checkpoint; python_version<'3.14'", # temporary skip orbax-checkpoint for py3.14
21-
"tensorstore; python_version<'3.14'", # temporary skip tensorstore for py3.14
20+
"orbax-checkpoint",
21+
"tensorstore",
2222
"rich>=11.1",
2323
"typing_extensions>=4.2",
2424
"PyYAML>=5.4.1",

0 commit comments

Comments
 (0)