94 lines
3.1 KiB
Python
94 lines
3.1 KiB
Python
import inspect
|
||
import torch
|
||
|
||
from abc import ABC, abstractmethod
|
||
from typing import List, Tuple, Optional, TypeVar, Type
|
||
from transformers import PreTrainedTokenizerBase
|
||
|
||
from text_generation_server.models.types import Batch, GeneratedText
|
||
from text_generation_server.pb.generate_pb2 import InfoResponse
|
||
|
||
# Type variable standing for a concrete `Batch` subclass. `Model.batch_type`
# and `Model.generate_token` are annotated with it so subclasses can tie the
# batch type they declare to the one they consume and return.
B = TypeVar("B", bound=Batch)
|
||
|
||
|
||
class Model(ABC):
    """Abstract base class for text-generation models.

    Subclasses must provide the concrete batch class (``batch_type``) and the
    single decoding step (``generate_token``). This base class stores the
    model, tokenizer and runtime metadata, reports them through ``info``,
    offers an incremental detokenization helper (``decode_token``), and
    refuses to construct if any weight is still on the ``meta`` device.
    """

    # U+FFFD REPLACEMENT CHARACTER. A decoded string ending with it may be an
    # unfinished multi-byte sequence that later tokens will complete.
    _REPLACEMENT_CHAR = "\ufffd"

    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        requires_padding: bool,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
    ):
        """Store the model, tokenizer and runtime metadata, then validate weights.

        Args:
            model: The network; it is switched to evaluation mode with
                ``model.eval()`` and stored as ``self.model``.
            tokenizer: Tokenizer used by ``decode_token``.
            requires_padding: Reported verbatim by ``info``.
            dtype: Reported by ``info`` as ``str(dtype)``.
            device: Its ``.type`` is reported by ``info``.
            rank: Stored as ``self.rank``. Presumably this process's index in
                a sharded deployment (TODO confirm against callers).
            world_size: Stored as ``self.world_size``. Presumably the total
                number of shards (TODO confirm against callers).

        Raises:
            RuntimeError: If any model parameter is still on the ``meta``
                device (see ``check_initialized``).
        """
        self.model = model.eval()
        self.tokenizer = tokenizer
        # Set copy for O(1) membership tests. NOTE(review): not read in this
        # class; presumably consumed by subclasses.
        self.all_special_ids = set(tokenizer.all_special_ids)
        self.requires_padding = requires_padding
        self.dtype = dtype
        self.device = device
        self.rank = rank
        self.world_size = world_size

        # True when the model's forward() declares a ``position_ids`` parameter.
        self.has_position_ids = (
            "position_ids" in inspect.signature(model.forward).parameters
        )

        self.check_initialized()

    @property
    def info(self) -> InfoResponse:
        """Return an ``InfoResponse`` describing padding, dtype and device type."""
        return InfoResponse(
            requires_padding=self.requires_padding,
            dtype=str(self.dtype),
            device_type=self.device.type,
        )

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        """The concrete ``Batch`` subclass this model operates on."""
        raise NotImplementedError

    @abstractmethod
    def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
        """Run one generation step on ``batch``.

        Returns:
            The generated results for this step, plus a batch (or ``None``),
            as declared by the return annotation. Exact semantics are defined
            by subclasses.
        """
        raise NotImplementedError

    def decode_token(
        self,
        all_input_ids: List[int],
        prefix_offset: int = 0,
        read_offset: int = 0,
    ) -> Tuple[str, int, int]:
        """Hack to hopefully support generate_stream for the maximum number of tokenizers.

        The ids from ``prefix_offset`` onward are decoded together, and the
        text produced by ``all_input_ids[prefix_offset:read_offset]`` is
        stripped from the front. Only the newly completed text remains.

        Args:
            all_input_ids: All token ids of the sequence so far.
            prefix_offset: Start of the context window that is decoded but
                not emitted.
            read_offset: Index of the first id whose text has not been
                returned yet.

        Returns:
            ``(text, new_prefix_offset, new_read_offset)``. When new text is
            available this is ``(text, read_offset, len(all_input_ids))``.
            Otherwise it is ``("", prefix_offset, read_offset)``, so the next
            call retries with more ids.
        """
        # The prefix text is necessary only to defeat cleanup algorithms in
        # the decode, which decide to add a space or not depending on the
        # surrounding ids.
        prefix_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
        )
        new_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:], skip_special_tokens=False
        )

        # A replacement char at the end means a potential unfinished byte
        # sequence from byte-fallback tokenization: hold it back until more
        # ids arrive. If it's in the middle, it's probably a real invalid id
        # generated by the model, so it is passed through.
        if len(new_text) <= len(prefix_text) or new_text.endswith(
            self._REPLACEMENT_CHAR
        ):
            return "", prefix_offset, read_offset

        return new_text[len(prefix_text) :], read_offset, len(all_input_ids)

    def check_initialized(self):
        """Raise if any model parameter was never materialized.

        A parameter whose data lives on the ``meta`` device holds no values.

        Raises:
            RuntimeError: Naming every such parameter.
        """
        meta = torch.device("meta")
        uninitialized_parameters = [
            name
            for name, param in self.model.named_parameters()
            if param.data.device == meta
        ]
        if uninitialized_parameters:
            raise RuntimeError(
                f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
            )
|