85 lines
2.7 KiB
Python
85 lines
2.7 KiB
Python
import torch
|
||
|
||
from abc import ABC, abstractmethod
|
||
from typing import List, Tuple, Optional, TypeVar, Type
|
||
from transformers import PreTrainedTokenizerBase
|
||
|
||
from text_generation_server.models.types import Batch, GeneratedText
|
||
|
||
# Type variable bound to ``Batch``; used in annotations that refer to a
# concrete ``Batch`` subclass.
B = TypeVar("B", bound=Batch)
|
||
|
||
|
||
class Model(ABC):
    """Abstract base for generation models that pairs a tokenizer with a device.

    Subclasses provide ``batch_type`` and ``generate_token``; this class
    provides ``decode_token`` for incremental decoding of generated ids.
    """

    def __init__(
        self,
        tokenizer: "PreTrainedTokenizerBase",
        device: torch.device,
        decode_buffer: int = 3,
    ):
        """Store the tokenizer, the device and the decode window size.

        Args:
            tokenizer: Tokenizer used by ``decode_token``; its
                ``all_special_ids`` are cached as a set for fast lookups.
            device: Device handle, stored unchanged on the instance.
            decode_buffer: Number of trailing token ids decoded together
                when extracting the text of the newest token.

        Raises:
            ValueError: If ``decode_buffer`` is smaller than 1.
        """
        if decode_buffer < 1:
            raise ValueError("decode_buffer must be >= 1")

        self.tokenizer = tokenizer
        self.all_special_ids = set(tokenizer.all_special_ids)
        self.device = device
        self.decode_buffer = decode_buffer

    @property
    @abstractmethod
    def batch_type(self) -> "Type[B]":
        """Return the batch class used by the concrete model."""
        raise NotImplementedError

    @abstractmethod
    def generate_token(self, batch: "B") -> "Tuple[List[GeneratedText], Optional[B]]":
        """Run one generation step on ``batch`` (implemented by subclasses)."""
        raise NotImplementedError

    def decode_token(
        self,
        all_input_ids: List[int],
        offset: Optional[int] = None,
        token_offset: Optional[int] = None,
    ) -> Tuple[str, Optional[int], Optional[int]]:
        """Decode the text contributed by the last id of ``all_input_ids``.

        Hack to hopefully support generate_stream for the maximum number of
        tokenizers: the newest token is decoded together with a few
        preceding tokens, and the text of the preceding tokens is sliced off.

        Args:
            all_input_ids: Token ids so far; the last one is the new token.
            offset: Character offset into the decoded window at which the
                new text starts. Required when ``token_offset`` is given.
            token_offset: Index in ``all_input_ids`` where the decoded
                window starts. ``None`` starts a fresh window covering the
                last ``decode_buffer`` ids.

        Returns:
            ``(text, None, None)`` when the new text is non-empty and does
            not end with U+FFFD; otherwise ``("", offset, token_offset)``,
            which are meant to be passed back on the next call so the
            pending text is decoded together with the following token.
        """
        last_id = all_input_ids[-1]

        # Special tokens are emitted verbatim and never buffered.
        if last_id in self.all_special_ids:
            return (
                self.tokenizer.decode(last_id, skip_special_tokens=False),
                None,
                None,
            )

        if token_offset is None:
            # Clamp at 0: a negative start would be an end-relative index
            # that slides forward (dropping pending tokens) once it is handed
            # back to us together with a longer ``all_input_ids``.
            token_offset = max(0, len(all_input_ids) - self.decode_buffer)
            if self.decode_buffer > 1:
                # Decode the window without and with the last token; the
                # length of the former is where the new token's text begins.
                window = all_input_ids[token_offset:]
                raw_texts = self.tokenizer.batch_decode(
                    [window[:-1], window],
                    skip_special_tokens=False,
                )
                offset = len(raw_texts[0])
                sequence_text = raw_texts[1]
            else:
                # No left context: decode the last token on its own.
                sequence_text = self.tokenizer.decode(
                    last_id, skip_special_tokens=False
                )
                offset = 0
        else:
            assert offset is not None
            sequence_text = self.tokenizer.decode(
                all_input_ids[token_offset:],
                skip_special_tokens=False,
            )

        token_text = sequence_text[offset:]

        # A trailing U+FFFD (replacement character) marks an incomplete
        # UTF-8 sequence; hold the text back until more tokens arrive.
        if token_text and token_text[-1] != "\ufffd":
            return token_text, None, None
        return "", offset, token_offset
|