hf_text-generation-inference/server/text_generation_server/models/types.py
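
"""Data types shared by the text_generation_server model implementations.

Defines the abstract `Batch` plus the `GeneratedText`, `Tokens`, and
`Generation` dataclasses, with conversions to the protobuf messages in
`generate_pb2`.
"""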

import torch

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional

from transformers import PreTrainedTokenizerBase

from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason


class Batch(ABC):
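    """Abstract container for a batch of generation requests.

    Concrete implementations hold the model-specific tensors for a group
    of in-flight requests and convert to and from the `generate_pb2`
    protobuf messages. `filter` keeps only the given request ids (e.g.
    after some requests finish) and `concatenate` merges batches for
    continuous batching.
    """
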
    @abstractmethod
    def to_pb(self) -> generate_pb2.CachedBatch:
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "Batch":
        raise NotImplementedError

    @abstractmethod
    def filter(self, request_ids: List[int]) -> "Batch":
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def concatenate(cls, batches: List["Batch"]) -> "Batch":
        raise NotImplementedError

    @abstractmethod
    def __len__(self):
        raise NotImplementedError


@dataclass
class GeneratedText:
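    """Final decoded text and finish metadata for a completed request."""
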
    text: str
    generated_tokens: int
    finish_reason: FinishReason
    seed: Optional[int]

    def to_pb(self) -> generate_pb2.GeneratedText:
        return generate_pb2.GeneratedText(
            text=self.text,
            generated_tokens=self.generated_tokens,
            finish_reason=self.finish_reason,
            seed=self.seed,
        )


@dataclass
class Tokens:
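    """Ids, logprobs, decoded texts, and special flags for a span of tokens.

    The four lists are index-aligned, one entry per token.
    """
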
    token_ids: List[int]
    logprobs: List[float]
    texts: List[str]
    is_special: List[bool]

    def to_pb(self) -> generate_pb2.Tokens:
        return generate_pb2.Tokens(
            ids=self.token_ids,
            logprobs=self.logprobs,
            texts=self.texts,
            is_special=self.is_special,
        )

    def __len__(self):
        return len(self.token_ids)


@dataclass
class Generation:
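    """Output of a single decoding step for one request.

    `generated_text` is only populated on the step that finishes the
    request; until then it is None.
    """
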
    request_id: int
    prefill_tokens: Optional[Tokens]
    tokens: Tokens
    generated_text: Optional[GeneratedText]
    # Optional for now, since it's not yet supported for every model.
    top_tokens: Optional[List[Tokens]]

    def to_pb(self) -> generate_pb2.Generation:
        return generate_pb2.Generation(
            request_id=self.request_id,
            prefill_tokens=(
                self.prefill_tokens.to_pb() if self.prefill_tokens is not None else None
            ),
            tokens=self.tokens.to_pb(),
            generated_text=(
                self.generated_text.to_pb() if self.generated_text is not None else None
            ),
            # `top_tokens` is a List[Tokens], so each element must be
            # converted individually; calling `.to_pb()` on the list itself
            # would raise an AttributeError.
            top_tokens=(
                [tokens.to_pb() for tokens in self.top_tokens]
                if self.top_tokens is not None
                else None
            ),
        )