from dataclasses import dataclass
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
import torch
from typing import Optional


if ATTENTION in {"flashinfer", "flashdecoding"}:
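    # Metadata for the flashinfer / flash-decoding kernels: per-request input
    # lengths, cached prefix lengths, and cumulative (ragged / varlen)
    # sequence-length tensors for queries and keys.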
    @dataclass
    class Seqlen:
        input_lengths: torch.Tensor
        prefix_lengths: torch.Tensor
        cu_seqlen_q: Optional[torch.Tensor]
        cu_seqlen_k: Optional[torch.Tensor]
        max_q: int
        max_k: int

        def __init__(
            self,
            input_lengths,
            prefix_lengths,
            cu_seqlen_q=None,
            max_q=None,
            max_k=None,
        ):
            self.input_lengths = input_lengths
            self.prefix_lengths = prefix_lengths
            device = self.input_lengths.device
            shape = self.input_lengths.shape
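            # When no query offsets are supplied this is a decode step: each
            # request contributes exactly one query token, so the cumulative
            # offsets are simply 0..batch_size and max_q is 1.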
            if cu_seqlen_q is None:
                cu_seqlen_q = torch.arange(
                    shape[0] + 1,
                    device=device,
                    dtype=torch.int32,
                )
                max_q = 1
            else:
                assert max_q is not None
                assert max_k is not None

            cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
            # cuda graphs don't like this and this is necessary to clamp within mistral
            # Although FA2 might not want the clamping
            # cu_seqlen_k[0] = 0
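            # Key lengths cover both the new input tokens and the cached prefix;
            # their running sum fills cu_seqlen_k[1:], while cu_seqlen_k[0]
            # stays 0 from the torch.zeros above.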
            total = self.input_lengths + self.prefix_lengths
            torch.cumsum(total, -1, out=cu_seqlen_k[1:])

            self.cu_seqlen_q = cu_seqlen_q
            self.cu_seqlen_k = cu_seqlen_k
            self.max_q = max_q
            self.max_k = max_k

        def clamp(self, max):
            # Flash decoding doesn't need to clamp
            return self
else:
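    # Fallback for the remaining attention backends: a plain dataclass without
    # the cu_seqlen_k bookkeeping above; clamp() caps the per-request input
    # lengths, e.g. for sliding-window models such as Mistral.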
    @dataclass
    class Seqlen:
        input_lengths: torch.Tensor
        prefix_lengths: torch.Tensor
        cu_seqlen_q: torch.Tensor
        max_q: int
        max_k: int

        def clamp(self, max):
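            # ROCm keeps the original lengths unchanged; on other systems the
            # input lengths are clamped to the requested maximum.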
            if SYSTEM == "rocm":
                return self
            self.input_lengths = torch.clamp(self.input_lengths, max=max)
            return self