feat(server): optimize decode for sane tokenizers (#170)
This commit is contained in:
parent
6f0f1d70f6
commit
5fa8ae041c
|
@ -853,7 +853,7 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
|
|||
|
||||
[[package]]
|
||||
name = "grpc-metadata"
|
||||
version = "0.4.1"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"opentelemetry",
|
||||
"tonic",
|
||||
|
@ -2140,7 +2140,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "text-generation-client"
|
||||
version = "0.4.3"
|
||||
version = "0.5.0"
|
||||
dependencies = [
|
||||
"futures",
|
||||
"grpc-metadata",
|
||||
|
|
|
@ -49,6 +49,11 @@ class BloomCausalLMBatch(CausalLMBatch):
|
|||
|
||||
|
||||
class BLOOM(CausalLM):
|
||||
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
|
||||
super(BLOOM, self).__init__(
|
||||
model_id=model_id, revision=revision, quantize=quantize, decode_buffer=1
|
||||
)
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[CausalLMBatch]:
|
||||
return BloomCausalLMBatch
|
||||
|
@ -94,8 +99,7 @@ class BLOOMSharded(BLOOM):
|
|||
self.model = model.eval().to(dtype)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
tokenizer=tokenizer,
|
||||
device=device,
|
||||
tokenizer=tokenizer, device=device, decode_buffer=1
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -291,7 +291,13 @@ class CausalLMBatch(Batch):
|
|||
|
||||
|
||||
class CausalLM(Model):
|
||||
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: bool = False,
|
||||
decode_buffer: int = 3,
|
||||
):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
|
||||
|
@ -319,8 +325,7 @@ class CausalLM(Model):
|
|||
)
|
||||
|
||||
super(CausalLM, self).__init__(
|
||||
tokenizer=tokenizer,
|
||||
device=device,
|
||||
tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
|
@ -212,7 +212,8 @@ class FlashCausalLM(Model):
|
|||
model_cls: Type[PreTrainedModel],
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize=False,
|
||||
quantize: bool = False,
|
||||
decode_buffer: int = 3,
|
||||
):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
|
@ -237,8 +238,7 @@ class FlashCausalLM(Model):
|
|||
)
|
||||
|
||||
super(FlashCausalLM, self).__init__(
|
||||
tokenizer=tokenizer,
|
||||
device=device,
|
||||
tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
|
@ -62,8 +62,7 @@ class FlashSantacoder(FlashCausalLM):
|
|||
self.model = model.eval().to(device).to(dtype)
|
||||
|
||||
super(FlashCausalLM, self).__init__(
|
||||
tokenizer=tokenizer,
|
||||
device=device,
|
||||
tokenizer=tokenizer, device=device, decode_buffer=1
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -10,10 +10,19 @@ B = TypeVar("B", bound=Batch)
|
|||
|
||||
|
||||
class Model(ABC):
|
||||
def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
device: torch.device,
|
||||
decode_buffer: int = 3,
|
||||
):
|
||||
if decode_buffer < 1:
|
||||
raise ValueError("decode_buffer must be >= 1")
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.all_special_ids = set(tokenizer.all_special_ids)
|
||||
self.device = device
|
||||
self.decode_buffer = decode_buffer
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
|
@ -39,23 +48,37 @@ class Model(ABC):
|
|||
)
|
||||
|
||||
if token_offset is None:
|
||||
token_offset = len(all_input_ids) - 3
|
||||
token_offset = len(all_input_ids) - self.decode_buffer
|
||||
# left token buffer
|
||||
if self.decode_buffer > 1:
|
||||
# Decode token_offset token minus last one and token_offset tokens
|
||||
raw_texts = self.tokenizer.batch_decode(
|
||||
[all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
|
||||
# Decode token_offset token minus last one and token_offset tokens
|
||||
results = self.tokenizer.batch_decode(
|
||||
[all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
|
||||
# default offset is only the last token
|
||||
if offset is None:
|
||||
offset = len(results[0])
|
||||
# default offset is only the last token
|
||||
offset = len(raw_texts[0])
|
||||
sequence_text = raw_texts[1]
|
||||
else:
|
||||
# Only decode the last token without using a token buffer
|
||||
sequence_text = self.tokenizer.decode(
|
||||
all_input_ids[-1], skip_special_tokens=False
|
||||
)
|
||||
# no offset in this case
|
||||
offset = 0
|
||||
else:
|
||||
assert offset is not None
|
||||
sequence_text = self.tokenizer.decode(
|
||||
all_input_ids[token_offset:],
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
|
||||
# get text
|
||||
text = results[1][offset:]
|
||||
token_text = sequence_text[offset:]
|
||||
|
||||
# if text is utf-8
|
||||
if text and text[-1] != "<EFBFBD>":
|
||||
return text, None, None
|
||||
if token_text and token_text[-1] != "<EFBFBD>":
|
||||
return token_text, None, None
|
||||
else:
|
||||
return "", offset, token_offset
|
||||
|
|
|
@ -54,8 +54,7 @@ class SantaCoder(CausalLM):
|
|||
)
|
||||
|
||||
super(CausalLM, self).__init__(
|
||||
tokenizer=tokenizer,
|
||||
device=device,
|
||||
tokenizer=tokenizer, device=device, decode_buffer=1
|
||||
)
|
||||
|
||||
def decode(self, generated_ids: List[int]) -> str:
|
||||
|
|
|
@ -330,7 +330,13 @@ class Seq2SeqLMBatch(Batch):
|
|||
|
||||
|
||||
class Seq2SeqLM(Model):
|
||||
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: bool = False,
|
||||
decode_buffer: int = 3,
|
||||
):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
|
||||
|
@ -354,8 +360,7 @@ class Seq2SeqLM(Model):
|
|||
tokenizer.bos_token_id = self.model.config.decoder_start_token_id
|
||||
|
||||
super(Seq2SeqLM, self).__init__(
|
||||
tokenizer=tokenizer,
|
||||
device=device,
|
||||
tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
|
||||
)
|
||||
|
||||
@property
|
||||
|
@ -496,7 +501,7 @@ class Seq2SeqLM(Model):
|
|||
if stop:
|
||||
# Slice with decoder_input_length to remove padding
|
||||
# Decode all tokens
|
||||
output_text = self.decode(decoder_input_ids[-new_decoder_input_length:])
|
||||
output_text = self.decode(decoder_input_ids[-decoder_input_length:])
|
||||
|
||||
# Get seed
|
||||
if isinstance(next_token_chooser.choice, Sampling):
|
||||
|
|
Loading…
Reference in New Issue