Commit 49ef97c

Committed by: Flax Authors
Merge pull request #5080 from google:improve-promote-dtype-support
PiperOrigin-RevId: 831005376
2 parents: b1f573b + 3a630ff

6 files changed (+199, -47 lines)

flax/nnx/nn/activations.py

Lines changed: 23 additions & 3 deletions

@@ -42,7 +42,8 @@
 from jax.numpy import tanh

 from flax import nnx
-from flax.typing import Array, Dtype
+from flax.nnx.nn import dtypes
+from flax.typing import Array, Dtype, PromoteDtypeFn


 __all__ = [

@@ -97,21 +98,40 @@ class PReLU(nnx.Module):

   Args:
     negative_slope_init: the value to initialize the negative slope (default 0.01).
+    dtype: the dtype of the computation (default: infer from input and params).
     param_dtype: the dtype passed to parameter initializers (default: float32).
+    promote_dtype: function to promote the dtype of all input array arguments
+      (including Variables accessed through ``self``) to the desired dtype. The
+      function should accept a tuple of ``(inputs, negative_slope)`` and a ``dtype``
+      keyword argument, and return a tuple of arrays with the promoted dtype.
   """
   def __init__(
     self,
     negative_slope_init: float = 0.01,
-    param_dtype: Dtype = jnp.float32
+    *,
+    dtype: Dtype | None = None,
+    param_dtype: Dtype = jnp.float32,
+    promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
   ):
     self.negative_slope = nnx.Param(
       jnp.asarray(negative_slope_init, dtype=param_dtype)
     )
+    self.dtype = dtype
     self.param_dtype = param_dtype
+    self.promote_dtype = promote_dtype

   def __call__(self, inputs: Array) -> Array:
+    negative_slope = self.negative_slope[...]
+    if self.dtype is not None:
+      inputs, negative_slope = self.promote_dtype(
+        (inputs, negative_slope), dtype=self.dtype
+      )
+    else:
+      # Match Linen behavior: cast parameter to input dtype
+      negative_slope = jnp.asarray(negative_slope, inputs.dtype)
+
     return jnp.where(
       inputs >= 0,
       inputs,
-      jnp.asarray(self.negative_slope[...], inputs.dtype) * inputs,
+      negative_slope * inputs,
     )

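For context, a minimal usage sketch of the new PReLU arguments (not part of the diff; the input shape and dtypes are illustrative, and it assumes a Flax version that includes this change):

import jax.numpy as jnp
from flax.nnx.nn import dtypes
from flax.nnx.nn.activations import PReLU

# All three arguments below are keyword-only after this change.
layer = PReLU(
  negative_slope_init=0.01,
  dtype=jnp.float32,                   # computation dtype; None keeps the old cast-to-input-dtype behavior
  param_dtype=jnp.float32,             # storage dtype of the negative_slope parameter
  promote_dtype=dtypes.promote_dtype,  # the default PromoteDtypeFn, spelled out explicitly
)

x = jnp.ones((4,), dtype=jnp.bfloat16)
y = layer(x)  # (inputs, negative_slope) are promoted to float32 before jnp.where
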
flax/nnx/nn/attention.py

Lines changed: 19 additions & 0 deletions

@@ -301,6 +301,15 @@ class MultiHeadAttention(Module):
       num_heads, value_channels]``
     decode: whether to prepare and use an autoregressive cache.
     normalize_qk: should QK normalization be applied (arxiv.org/abs/2302.05442).
+    qkv_promote_dtype: function to promote the dtype of all input array arguments
+      (including Variables accessed through ``self``) to the desired dtype for the
+      query, key, and value LinearGeneral submodules.
+    out_promote_dtype: function to promote the dtype of all input array arguments
+      (including Variables accessed through ``self``) to the desired dtype for the
+      output LinearGeneral submodule.
+    ln_promote_dtype: function to promote the dtype of all input array arguments
+      (including Variables accessed through ``self``) to the desired dtype for the
+      LayerNorm submodules (query_ln and key_ln) when normalize_qk=True.
     rngs: rng key.
     keep_rngs: whether to store the input rngs as attribute (i.e. `self.rngs = rngs`)
       (default: True). If rngs is stored, we should split the module as

@@ -330,6 +339,9 @@ def __init__(
     attention_fn: Callable[..., Array] = dot_product_attention,
     decode: bool | None = None,
     normalize_qk: bool = False,
+    qkv_promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
+    out_promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
+    ln_promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
     # Deprecated, will be removed.
     qkv_dot_general: DotGeneralT | None = None,
     out_dot_general: DotGeneralT | None = None,

@@ -359,6 +371,9 @@ def __init__(
     self.attention_fn = attention_fn
     self.decode = decode
     self.normalize_qk = normalize_qk
+    self.qkv_promote_dtype = qkv_promote_dtype
+    self.out_promote_dtype = out_promote_dtype
+    self.ln_promote_dtype = ln_promote_dtype
     self.qkv_dot_general = qkv_dot_general
     self.out_dot_general = out_dot_general
     self.qkv_dot_general_cls = qkv_dot_general_cls

@@ -381,6 +396,7 @@ def __init__(
       bias_init=bias_init,
       use_bias=self.use_bias,
       precision=self.precision,
+      promote_dtype=self.qkv_promote_dtype,
       dot_general=self.qkv_dot_general,
       dot_general_cls=self.qkv_dot_general_cls,
     )

@@ -400,13 +416,15 @@ def __init__(
         use_bias=False,
         dtype=self.dtype,
         param_dtype=self.param_dtype,
+        promote_dtype=self.ln_promote_dtype,
         rngs=rngs,
       )
       self.key_ln = LayerNorm(
         self.head_dim,
         use_bias=False,
         dtype=self.dtype,
         param_dtype=self.param_dtype,
+        promote_dtype=self.ln_promote_dtype,
         rngs=rngs,
       )
     else:

@@ -423,6 +441,7 @@ def __init__(
       dtype=self.dtype,
       param_dtype=self.param_dtype,
       precision=self.precision,
+      promote_dtype=self.out_promote_dtype,
       dot_general=self.out_dot_general,
       dot_general_cls=self.out_dot_general_cls,
       rngs=rngs,

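As an illustration only (not from the diff), a custom PromoteDtypeFn can be routed to just the q/k/v projections; the f32_qkv_promote helper and all shapes below are hypothetical:

import jax.numpy as jnp
from flax import nnx
from flax.nnx.nn import dtypes

def f32_qkv_promote(args, dtype=None):
  # Hypothetical PromoteDtypeFn: ignore the requested dtype and keep the
  # query/key/value projections in float32.
  return dtypes.promote_dtype(args, dtype=jnp.float32)

attn = nnx.MultiHeadAttention(
  num_heads=4,
  in_features=16,
  dtype=jnp.bfloat16,                 # module-wide computation dtype
  qkv_promote_dtype=f32_qkv_promote,  # applied only to the q/k/v LinearGeneral submodules
  decode=False,
  rngs=nnx.Rngs(0),
)

x = jnp.ones((2, 3, 16), dtype=jnp.bfloat16)
y = attn(x)  # self-attention; out_promote_dtype and ln_promote_dtype keep their defaults
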
flax/nnx/nn/lora.py

Lines changed: 34 additions & 25 deletions

@@ -17,10 +17,9 @@

 from flax.nnx import rnglib, variablelib
 from flax.nnx.module import Module
-from flax.nnx.nn import initializers
+from flax.nnx.nn import initializers, dtypes
 from flax.nnx.nn.linear import Linear
-from flax.nnx.nn.dtypes import promote_dtype
-from flax.typing import Dtype, Initializer
+from flax.typing import Dtype, Initializer, PromoteDtypeFn
 import jax
 import jax.numpy as jnp


@@ -75,6 +74,11 @@ class LoRA(Module):
     b_initializer: initializer function for the fan-out matrices. Default to
       `zero initializer`.
     lora_param_type: the type of the LoRA params.
+    promote_dtype: function to promote the dtype of all input array arguments
+      (including Variables accessed through ``self``) to the desired dtype. The
+      function should accept a tuple of ``(inputs, lora_a, lora_b)`` and a ``dtype``
+      keyword argument, and return a tuple of arrays with the promoted dtype.
+    rngs: rng key.
   """

   def __init__(

@@ -89,6 +93,7 @@ def __init__(
     a_initializer: Initializer = default_a_initializer,
     b_initializer: Initializer = default_b_initializer,
     lora_param_type: tp.Type[variablelib.Variable] = LoRAParam,
+    promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
     rngs: rnglib.Rngs,
   ):
     self.in_features = in_features

@@ -97,6 +102,7 @@ def __init__(
     self.param_dtype = param_dtype
     self.lora_param_type = lora_param_type
     self.base_module = base_module
+    self.promote_dtype = promote_dtype

     self.lora_a = lora_param_type(
       a_initializer(rngs.params(), (in_features, lora_rank), param_dtype)

@@ -106,7 +112,7 @@ def __init__(
     )

   def __call__(self, x: jax.Array):
-    x, lora_a, lora_b = promote_dtype(
+    x, lora_a, lora_b = self.promote_dtype(
       (x, self.lora_a[...], self.lora_b[...]), dtype=self.dtype
     )
     out = x @ lora_a @ lora_b

@@ -154,33 +160,36 @@ class LoRALinear(Linear):
     b_initializer: initializer function for the fan-out matrices. Default to
       `zero initializer`.
     lora_param_type: the type of the LoRA params.
+    lora_promote_dtype: function to promote the dtype for the LoRA submodule.
   """

   def __init__(
-      self,
-      in_features: int,
-      out_features: int,
-      *,
-      lora_rank: int,
-      lora_dtype: tp.Optional[Dtype] = None,
-      lora_param_dtype: Dtype = jnp.float32,
-      a_initializer: Initializer = default_a_initializer,
-      b_initializer: Initializer = default_b_initializer,
-      lora_param_type: tp.Type[variablelib.Variable] = LoRAParam,
-      rngs: rnglib.Rngs,
-      **kwargs,
+    self,
+    in_features: int,
+    out_features: int,
+    *,
+    lora_rank: int,
+    lora_dtype: tp.Optional[Dtype] = None,
+    lora_param_dtype: Dtype = jnp.float32,
+    a_initializer: Initializer = default_a_initializer,
+    b_initializer: Initializer = default_b_initializer,
+    lora_param_type: tp.Type[variablelib.Variable] = LoRAParam,
+    lora_promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
+    rngs: rnglib.Rngs,
+    **kwargs,
   ):
     super().__init__(in_features, out_features, rngs=rngs, **kwargs)
     self.lora = LoRA(
-        in_features,
-        lora_rank,
-        out_features,
-        dtype=lora_dtype,
-        param_dtype=lora_param_dtype,
-        a_initializer=a_initializer,
-        b_initializer=b_initializer,
-        lora_param_type=lora_param_type,
-        rngs=rngs,
+      in_features,
+      lora_rank,
+      out_features,
+      dtype=lora_dtype,
+      param_dtype=lora_param_dtype,
+      a_initializer=a_initializer,
+      b_initializer=b_initializer,
+      lora_param_type=lora_param_type,
+      promote_dtype=lora_promote_dtype,
+      rngs=rngs,
     )

   def __call__(self, x: jax.Array):

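Again purely as a sketch (not part of the diff), the new argument threads through LoRALinear to its LoRA submodule; the passthrough_promote helper and the feature sizes are made up:

import jax.numpy as jnp
from flax import nnx
from flax.nnx.nn import dtypes

def passthrough_promote(args, dtype=None):
  # Hypothetical PromoteDtypeFn: receives (x, lora_a, lora_b) plus the dtype
  # keyword and simply delegates to the default promotion rule.
  return dtypes.promote_dtype(args, dtype=dtype)

linear = nnx.LoRALinear(
  in_features=8,
  out_features=4,
  lora_rank=2,
  lora_dtype=jnp.float32,
  lora_promote_dtype=passthrough_promote,  # forwarded to the LoRA submodule as promote_dtype
  rngs=nnx.Rngs(0),
)

y = linear(jnp.ones((1, 8)))  # base Linear output plus x @ lora_a @ lora_b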