diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index f38f685970e..a7c7872ecfe 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -570,22 +570,6 @@ def __init__(self, inv_freq: torch.Tensor, scaling_factor: float, sections: list .to(inv_freq.device) ) - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - ): - # rotate half the sequence length - rot = cos.shape[-1] // 2 - q2 = torch.cat([-query[..., rot:], query[..., :rot]], dim=-1) - k2 = torch.cat([-key[..., rot:], key[..., :rot]], dim=-1) - - # apply the rotation - rotary_emb.apply_rotary(query, q2, cos, sin, query, q2, True) - rotary_emb.apply_rotary(key, k2, cos, sin, key, k2, True) - def _update_cos_sin_cache( self, dtype: torch.dtype, device: torch.device, seqlen: int ): @@ -614,7 +598,4 @@ def get_cos_sin( cos = self._cos_cached[position_ids].gather(1, self._sections[:slen]) sin = self._sin_cached[position_ids].gather(1, self._sections[:slen]) - - cos = torch.cat([cos, cos], dim=-1) - sin = torch.cat([sin, sin], dim=-1) return cos, sin