hf_text-generation-inference/server/text_generation_server/layers/attention/common.py

53 lines
1.4 KiB
Python

from dataclasses import dataclass
import torch
from typing import Optional
@dataclass
class Seqlen:
input_lengths: torch.Tensor
cache_lengths: torch.Tensor
cu_seqlen_q: Optional[torch.Tensor]
cu_seqlen_k: Optional[torch.Tensor]
max_q: int
max_k: int
def __init__(
self,
input_lengths,
cache_lengths,
cu_seqlen_q=None,
max_q=None,
max_k=None,
):
self.input_lengths = input_lengths
self.cache_lengths = cache_lengths
device = self.input_lengths.device
shape = self.input_lengths.shape
if cu_seqlen_q is None:
cu_seqlen_q = torch.arange(
shape[0] + 1,
device=device,
dtype=torch.int32,
)
max_q = 1
else:
assert max_q is not None
assert max_k is not None
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
# cuda graphs don't like this and this is necessary to clamp within mistral
# Although FA2 might not want the clamping
# cu_seqlen_k[0] = 0
total = self.input_lengths + self.cache_lengths
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
self.cu_seqlen_q = cu_seqlen_q
self.cu_seqlen_k = cu_seqlen_k
self.max_q = max_q
self.max_k = max_k
def clamp(self, max):
# Flash decoding doesn't need to clamp
return self