import torch

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional

from transformers import PreTrainedTokenizerBase

from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason


class Batch(ABC):
    """Abstract interface that every model-specific batch implementation must satisfy."""

    @abstractmethod
    def to_pb(self) -> generate_pb2.CachedBatch:
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "Batch":
        raise NotImplementedError

    @abstractmethod
    def filter(self, request_ids: List[int]) -> "Batch":
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def concatenate(cls, batches: List["Batch"]) -> "Batch":
        raise NotImplementedError

    @abstractmethod
    def __len__(self):
        raise NotImplementedError

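
# Illustrative sketch, not part of the library: a minimal concrete Batch that
# tracks only request ids, to show how the abstract methods fit together. The
# protobuf field accesses below (Batch.id, Batch.requests, Request.id, and the
# CachedBatch constructor arguments) are assumptions about the generate_pb2
# schema, and DummyBatch is a hypothetical name used only for this example.
@dataclass
class DummyBatch(Batch):
    batch_id: int
    request_ids: List[int]

    def to_pb(self) -> generate_pb2.CachedBatch:
        # Serialize back to the cached-batch view kept on the router side.
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=self.request_ids,
            size=len(self),
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "DummyBatch":
        # A real implementation would tokenize each request's inputs here and
        # move the resulting tensors to `device` with the requested `dtype`.
        return cls(batch_id=pb.id, request_ids=[r.id for r in pb.requests])

    def filter(self, request_ids: List[int]) -> "DummyBatch":
        # Keep only the requests that are still in flight.
        keep = set(request_ids)
        return DummyBatch(
            batch_id=self.batch_id,
            request_ids=[rid for rid in self.request_ids if rid in keep],
        )

    @classmethod
    def concatenate(cls, batches: List["DummyBatch"]) -> "DummyBatch":
        # Merge several batches into one, as continuous batching requires.
        return cls(
            batch_id=batches[0].batch_id,
            request_ids=[rid for b in batches for rid in b.request_ids],
        )

    def __len__(self):
        return len(self.request_ids)
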

@dataclass
class GeneratedText:
    text: str
    generated_tokens: int
    finish_reason: FinishReason
    seed: Optional[int]

    def to_pb(self) -> generate_pb2.GeneratedText:
        return generate_pb2.GeneratedText(
            text=self.text,
            generated_tokens=self.generated_tokens,
            finish_reason=self.finish_reason,
            seed=self.seed,
        )


@dataclass
class Tokens:
    token_ids: List[int]
    logprobs: List[float]
    texts: List[str]
    is_special: List[bool]

    def to_pb(self) -> generate_pb2.Tokens:
        return generate_pb2.Tokens(
            ids=self.token_ids,
            logprobs=self.logprobs,
            texts=self.texts,
            is_special=self.is_special,
        )

    def __len__(self):
        return len(self.token_ids)

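
# Illustrative sketch, not part of the library: how a model backend might
# assemble a Tokens record for a single decoding step. The ids, logprobs, and
# texts are made-up values, and _example_tokens is a hypothetical helper; the
# point is that the four lists are parallel, one entry per token.
def _example_tokens() -> Tokens:
    step = Tokens(
        token_ids=[464, 3290],
        logprobs=[-0.12, -1.87],
        texts=["The", " dog"],
        is_special=[False, False],
    )
    assert len(step) == 2  # __len__ counts token ids
    return step
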

@dataclass
class Generation:
    request_id: int
    prefill_tokens: Optional[Tokens]
    tokens: Tokens
    generated_text: Optional[GeneratedText]
    # Optional for now, since it's not yet supported for every model.
    top_tokens: Optional[List[Tokens]]

    def to_pb(self) -> generate_pb2.Generation:
        return generate_pb2.Generation(
            request_id=self.request_id,
            prefill_tokens=(
                self.prefill_tokens.to_pb() if self.prefill_tokens is not None else None
            ),
            tokens=self.tokens.to_pb(),
            generated_text=(
                self.generated_text.to_pb() if self.generated_text is not None else None
            ),
            # top_tokens is a list of Tokens, so serialize each element rather
            # than calling to_pb() on the list itself.
            top_tokens=(
                [top_tokens.to_pb() for top_tokens in self.top_tokens]
                if self.top_tokens is not None
                else None
            ),
        )
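

# Illustrative sketch, not part of the library: building a complete Generation
# and serializing it to protobuf. FinishReason.FINISH_REASON_EOS_TOKEN is an
# assumption about the enum values exposed by generate_pb2, and
# _example_generation is a hypothetical helper; everything else uses only the
# dataclasses defined above.
def _example_generation() -> generate_pb2.Generation:
    tokens = Tokens(
        token_ids=[13],
        logprobs=[-0.05],
        texts=["."],
        is_special=[False],
    )
    generation = Generation(
        request_id=0,
        prefill_tokens=None,  # prefill logprobs were not requested
        tokens=tokens,
        generated_text=GeneratedText(
            text="Hello.",
            generated_tokens=3,
            finish_reason=FinishReason.FINISH_REASON_EOS_TOKEN,
            seed=None,  # only set when sampling was used
        ),
        top_tokens=None,  # not yet supported for every model
    )
    return generation.to_pb()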