2022-11-04 07:22:47 -06:00
|
|
|
|
import torch
|
|
|
|
|
|
2022-11-03 09:07:54 -06:00
|
|
|
|
from abc import ABC, abstractmethod
|
2022-11-04 11:03:04 -06:00
|
|
|
|
from typing import List, Tuple, Optional, TypeVar, Type
|
2023-01-17 01:10:22 -07:00
|
|
|
|
from transformers import PreTrainedTokenizerBase
|
2022-10-28 11:24:00 -06:00
|
|
|
|
|
2023-03-07 10:52:22 -07:00
|
|
|
|
from text_generation_server.models.types import Batch, GeneratedText
|
2023-04-21 07:36:29 -06:00
|
|
|
|
from text_generation_server.pb.generate_pb2 import InfoResponse
|
2022-10-28 11:24:00 -06:00
|
|
|
|
|
2022-11-04 11:03:04 -06:00
|
|
|
|
# Type variable for the concrete Batch subclass a Model implementation works on
# (see Model.batch_type / Model.generate_token).
B = TypeVar("B", bound=Batch)
|
|
|
|
|
|
2022-10-28 11:24:00 -06:00
|
|
|
|
|
2022-11-03 09:07:54 -06:00
|
|
|
|
class Model(ABC):
    """Abstract base class for served text-generation models.

    Stores the tokenizer and the dtype/device metadata shared by all models,
    exposes them through :attr:`info`, and provides incremental token-to-text
    decoding for streaming (:meth:`decode_token`). Subclasses must implement
    :attr:`batch_type` and :meth:`generate_token`.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        requires_padding: bool,
        dtype: torch.dtype,
        device: torch.device,
        decode_buffer: int = 3,
    ):
        """Store the shared model configuration.

        Args:
            tokenizer: Tokenizer used by :meth:`decode_token` to turn ids
                back into text.
            requires_padding: Flag reported through :attr:`info`.
            dtype: Torch dtype of the model, reported through :attr:`info`.
            device: Device of the model; its ``type`` is reported through
                :attr:`info`.
            decode_buffer: Number of trailing tokens decoded together in
                :meth:`decode_token` (a left-context window). ``1`` disables
                the window and decodes only the newest token.

        Raises:
            ValueError: If ``decode_buffer`` is smaller than 1.
        """
        if decode_buffer < 1:
            raise ValueError("decode_buffer must be >= 1")

        self.tokenizer = tokenizer
        # A set gives O(1) membership checks in decode_token.
        self.all_special_ids = set(tokenizer.all_special_ids)
        self.requires_padding = requires_padding
        self.dtype = dtype
        self.device = device
        self.decode_buffer = decode_buffer

    @property
    def info(self) -> InfoResponse:
        """Return the padding requirement, dtype and device type as an InfoResponse."""
        return InfoResponse(
            requires_padding=self.requires_padding,
            dtype=str(self.dtype),
            device_type=self.device.type,
        )

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        """The concrete Batch subclass this model operates on."""
        raise NotImplementedError

    @abstractmethod
    def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
        """Run one generation step on ``batch``.

        Returns:
            The GeneratedText items produced by this step, and the batch to
            continue with or ``None`` (presumably once every request in the
            batch has finished; confirm against the implementations).
        """
        raise NotImplementedError

    def decode_token(
        self,
        all_input_ids: List[int],
        offset: Optional[int] = None,
        token_offset: Optional[int] = None,
    ) -> Tuple[str, Optional[int], Optional[int]]:
        """Decode the newest token of a sequence into text for streaming.

        Hack to hopefully support generate_stream for the maximum number of
        tokenizers. Instead of decoding the last id on its own, the last
        ``decode_buffer`` ids are decoded together. Only the text that appears
        after the preceding tokens of that window is returned.

        Args:
            all_input_ids: All token ids of the sequence so far. The last
                element is the token to decode.
            offset: Character offset into the decoded window where the new
                text starts. Pass the value returned by the previous call, or
                ``None`` on a fresh call.
            token_offset: Index into ``all_input_ids`` where the decode window
                starts. Pass the value returned by the previous call, or
                ``None`` on a fresh call. When given, ``offset`` must be given
                too.

        Returns:
            ``(text, None, None)`` when the new text is complete.

            ``("", offset, token_offset)`` when the new text is empty or ends
            with U+FFFD, the replacement character that decoding typically
            yields for an incomplete UTF-8 byte sequence. Pass both offsets
            back on the next call so the pending bytes are decoded together
            with the following token(s).
        """
        # Special tokens are emitted verbatim. Any pending offsets are dropped.
        if all_input_ids[-1] in self.all_special_ids:
            return (
                self.tokenizer.decode(all_input_ids[-1], skip_special_tokens=False),
                None,
                None,
            )

        if token_offset is None:
            # Clamp to 0 when there are fewer ids than decode_buffer. A
            # negative start would make the slices below take only the last
            # few ids (e.g. -1 selects just the final id) instead of the whole
            # sequence, and that negative index would be returned for reuse.
            token_offset = max(len(all_input_ids) - self.decode_buffer, 0)
            if self.decode_buffer > 1:
                # Decode the window without and with the newest token. The
                # length of the former marks where the new text begins.
                without_last, with_last = self.tokenizer.batch_decode(
                    [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
                    skip_special_tokens=False,
                )
                offset = len(without_last)
                sequence_text = with_last
            else:
                # No left context: decode only the newest token.
                sequence_text = self.tokenizer.decode(
                    all_input_ids[-1], skip_special_tokens=False
                )
                offset = 0
        else:
            # Resuming after incomplete text: re-decode the same window, which
            # now also contains the token(s) generated since the last call.
            assert offset is not None, "offset must be provided with token_offset"
            sequence_text = self.tokenizer.decode(
                all_input_ids[token_offset:],
                skip_special_tokens=False,
            )

        token_text = sequence_text[offset:]

        # Escape the replacement character so the comparison cannot be broken
        # by the source file's encoding.
        if token_text and not token_text.endswith("\ufffd"):
            return token_text, None, None
        return "", offset, token_offset
|