@@ -167,7 +167,6 @@ def __call__(
167167 if cache is None :
168168 cache = [None ] * (2 * len (self .layers ))
169169
170- offset = cache [0 ].offset if cache [0 ] is not None else 0
171170 mask = create_attention_mask (h , cache [0 ])
172171 window_mask = create_attention_mask (
173172 h , cache [len (self .layers )], window_size = self .loop_window_size
@@ -176,6 +175,7 @@ def __call__(
176175 loop1_kv = []
177176 for layer , c in zip (self .layers , cache ):
178177 h_norm = layer .input_layernorm (h )
178+ offset = c .offset if c is not None else 0
179179 q1 , k1 , v1 = layer .self_attn .get_qkv (h_norm , offset )
180180
181181 if c is not None :
@@ -192,6 +192,7 @@ def __call__(
192192 self .layers , self .gate_projections , cache [len (self .layers ) :], loop1_kv
193193 ):
194194 h_norm = layer .input_layernorm (h )
195+ offset = c .offset if c is not None else 0
195196 q2 , k2 , v2 = layer .self_attn .get_qkv (h_norm , offset )
196197 gate = gate_proj (q2 )
197198 attn_global = layer .self_attn .attention (q2 , k1 , v1 , mask , cache = c )
0 commit comments