1717
1818from flax .nnx import rnglib , variablelib
1919from flax .nnx .module import Module
20- from flax .nnx .nn import initializers
20+ from flax .nnx .nn import initializers , dtypes
2121from flax .nnx .nn .linear import Linear
22- from flax .nnx .nn .dtypes import promote_dtype
23- from flax .typing import Dtype , Initializer
22+ from flax .typing import Dtype , Initializer , PromoteDtypeFn
2423import jax
2524import jax .numpy as jnp
2625
@@ -75,6 +74,11 @@ class LoRA(Module):
7574 b_initializer: initializer function for the fan-out matrices. Defaults to
7675 `zero initializer`.
7776 lora_param_type: the type of the LoRA params.
77+ promote_dtype: function to promote the dtype of all input array arguments
78+ (including Variables accessed through ``self``) to the desired dtype. The
79+ function should accept a tuple of ``(inputs, lora_a, lora_b)`` and a ``dtype``
80+ keyword argument, and return a tuple of arrays with the promoted dtype.
81+ rngs: the ``Rngs`` object whose ``params`` stream supplies the PRNG keys for parameter initialization.
7882 """
7983
8084 def __init__ (
@@ -89,6 +93,7 @@ def __init__(
8993 a_initializer : Initializer = default_a_initializer ,
9094 b_initializer : Initializer = default_b_initializer ,
9195 lora_param_type : tp .Type [variablelib .Variable ] = LoRAParam ,
96+ promote_dtype : PromoteDtypeFn = dtypes .promote_dtype ,
9297 rngs : rnglib .Rngs ,
9398 ):
9499 self .in_features = in_features
@@ -97,6 +102,7 @@ def __init__(
97102 self .param_dtype = param_dtype
98103 self .lora_param_type = lora_param_type
99104 self .base_module = base_module
105+ self .promote_dtype = promote_dtype
100106
101107 self .lora_a = lora_param_type (
102108 a_initializer (rngs .params (), (in_features , lora_rank ), param_dtype )
@@ -106,7 +112,7 @@ def __init__(
106112 )
107113
108114 def __call__ (self , x : jax .Array ):
109- x , lora_a , lora_b = promote_dtype (
115+ x , lora_a , lora_b = self . promote_dtype (
110116 (x , self .lora_a [...], self .lora_b [...]), dtype = self .dtype
111117 )
112118 out = x @ lora_a @ lora_b
@@ -154,33 +160,36 @@ class LoRALinear(Linear):
154160 b_initializer: initializer function for the fan-out matrices. Defaults to
155161 `zero initializer`.
156162 lora_param_type: the type of the LoRA params.
163+ lora_promote_dtype: dtype-promotion function passed to the LoRA submodule as its ``promote_dtype`` argument.
157164 """
158165
159166 def __init__ (
160- self ,
161- in_features : int ,
162- out_features : int ,
163- * ,
164- lora_rank : int ,
165- lora_dtype : tp .Optional [Dtype ] = None ,
166- lora_param_dtype : Dtype = jnp .float32 ,
167- a_initializer : Initializer = default_a_initializer ,
168- b_initializer : Initializer = default_b_initializer ,
169- lora_param_type : tp .Type [variablelib .Variable ] = LoRAParam ,
170- rngs : rnglib .Rngs ,
171- ** kwargs ,
167+ self ,
168+ in_features : int ,
169+ out_features : int ,
170+ * ,
171+ lora_rank : int ,
172+ lora_dtype : tp .Optional [Dtype ] = None ,
173+ lora_param_dtype : Dtype = jnp .float32 ,
174+ a_initializer : Initializer = default_a_initializer ,
175+ b_initializer : Initializer = default_b_initializer ,
176+ lora_param_type : tp .Type [variablelib .Variable ] = LoRAParam ,
177+ lora_promote_dtype : PromoteDtypeFn = dtypes .promote_dtype ,
178+ rngs : rnglib .Rngs ,
179+ ** kwargs ,
172180 ):
173181 super ().__init__ (in_features , out_features , rngs = rngs , ** kwargs )
174182 self .lora = LoRA (
175- in_features ,
176- lora_rank ,
177- out_features ,
178- dtype = lora_dtype ,
179- param_dtype = lora_param_dtype ,
180- a_initializer = a_initializer ,
181- b_initializer = b_initializer ,
182- lora_param_type = lora_param_type ,
183- rngs = rngs ,
183+ in_features ,
184+ lora_rank ,
185+ out_features ,
186+ dtype = lora_dtype ,
187+ param_dtype = lora_param_dtype ,
188+ a_initializer = a_initializer ,
189+ b_initializer = b_initializer ,
190+ lora_param_type = lora_param_type ,
191+ promote_dtype = lora_promote_dtype ,
192+ rngs = rngs ,
184193 )
185194
186195 def __call__ (self , x : jax .Array ):
0 commit comments