diff --git a/fms_mo/aiu_addons/fp8/fp8_attn.py b/fms_mo/aiu_addons/fp8/fp8_attn.py index 1d31e82..e4e4224 100644 --- a/fms_mo/aiu_addons/fp8/fp8_attn.py +++ b/fms_mo/aiu_addons/fp8/fp8_attn.py @@ -318,6 +318,25 @@ def _spyre_scaled_paged_compute_op( attn_kwargs["block_table"], ) + def __spyre_scaled_paged_validate_attn_kwargs_op( + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_value_states: Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None, + **attn_kwargs, + ): + __spyre_paged_validate_attn_kwargs_op( + input_ids, position_ids, past_key_value_states, **attn_kwargs + ) + + if past_key_value_states is not None: + for k, v in past_key_value_states: + assert isinstance(k, ScaledTensor) + assert isinstance(v, ScaledTensor) + + # assert that for each layer, the scales are per-sequence + assert k._scale.shape[0] == input_ids.shape[0] + assert v._scale.shape[0] == input_ids.shape[0] + register_attention_op( "spyre_paged_attn_fp8", _spyre_scaled_paged_store_op, @@ -325,5 +344,5 @@ def _spyre_scaled_paged_compute_op( is_prefill_op=lambda **attn_kwargs: attn_kwargs.get("block_table", None) is None, compute_decode_op=_spyre_scaled_paged_compute_op, - validate_attn_kwargs_op=__spyre_paged_validate_attn_kwargs_op, + validate_attn_kwargs_op=__spyre_scaled_paged_validate_attn_kwargs_op, )