feat(server): optimize decode for sane tokenizers (#170)

This commit is contained in:
OlivierDehaene 2023-04-12 12:03:10 +02:00 committed by GitHub
parent 6f0f1d70f6
commit 5fa8ae041c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 67 additions and 32 deletions

4
benchmark/Cargo.lock generated
View File

@ -853,7 +853,7 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
[[package]] [[package]]
name = "grpc-metadata" name = "grpc-metadata"
version = "0.4.1" version = "0.1.0"
dependencies = [ dependencies = [
"opentelemetry", "opentelemetry",
"tonic", "tonic",
@ -2140,7 +2140,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-client" name = "text-generation-client"
version = "0.4.3" version = "0.5.0"
dependencies = [ dependencies = [
"futures", "futures",
"grpc-metadata", "grpc-metadata",

View File

@ -49,6 +49,11 @@ class BloomCausalLMBatch(CausalLMBatch):
class BLOOM(CausalLM): class BLOOM(CausalLM):
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
super(BLOOM, self).__init__(
model_id=model_id, revision=revision, quantize=quantize, decode_buffer=1
)
@property @property
def batch_type(self) -> Type[CausalLMBatch]: def batch_type(self) -> Type[CausalLMBatch]:
return BloomCausalLMBatch return BloomCausalLMBatch
@ -94,8 +99,7 @@ class BLOOMSharded(BLOOM):
self.model = model.eval().to(dtype) self.model = model.eval().to(dtype)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer, device=device, decode_buffer=1
device=device,
) )
@staticmethod @staticmethod

View File

@ -291,7 +291,13 @@ class CausalLMBatch(Batch):
class CausalLM(Model): class CausalLM(Model):
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False): def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: bool = False,
decode_buffer: int = 3,
):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@ -319,8 +325,7 @@ class CausalLM(Model):
) )
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
device=device,
) )
@property @property

View File

@ -212,7 +212,8 @@ class FlashCausalLM(Model):
model_cls: Type[PreTrainedModel], model_cls: Type[PreTrainedModel],
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize=False, quantize: bool = False,
decode_buffer: int = 3,
): ):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
@ -237,8 +238,7 @@ class FlashCausalLM(Model):
) )
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
device=device,
) )
@property @property

View File

@ -62,8 +62,7 @@ class FlashSantacoder(FlashCausalLM):
self.model = model.eval().to(device).to(dtype) self.model = model.eval().to(device).to(dtype)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer, device=device, decode_buffer=1
device=device,
) )
@staticmethod @staticmethod

View File

@ -10,10 +10,19 @@ B = TypeVar("B", bound=Batch)
class Model(ABC): class Model(ABC):
def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device): def __init__(
self,
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
decode_buffer: int = 3,
):
if decode_buffer < 1:
raise ValueError("decode_buffer must be >= 1")
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.all_special_ids = set(tokenizer.all_special_ids) self.all_special_ids = set(tokenizer.all_special_ids)
self.device = device self.device = device
self.decode_buffer = decode_buffer
@property @property
@abstractmethod @abstractmethod
@ -39,23 +48,37 @@ class Model(ABC):
) )
if token_offset is None: if token_offset is None:
token_offset = len(all_input_ids) - 3 token_offset = len(all_input_ids) - self.decode_buffer
# left token buffer
if self.decode_buffer > 1:
# Decode token_offset token minus last one and token_offset tokens # Decode token_offset token minus last one and token_offset tokens
results = self.tokenizer.batch_decode( raw_texts = self.tokenizer.batch_decode(
[all_input_ids[token_offset:-1], all_input_ids[token_offset:]], [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
skip_special_tokens=False, skip_special_tokens=False,
) )
# default offset is only the last token # default offset is only the last token
if offset is None: offset = len(raw_texts[0])
offset = len(results[0]) sequence_text = raw_texts[1]
else:
# Only decode the last token without using a token buffer
sequence_text = self.tokenizer.decode(
all_input_ids[-1], skip_special_tokens=False
)
# no offset in this case
offset = 0
else:
assert offset is not None
sequence_text = self.tokenizer.decode(
all_input_ids[token_offset:],
skip_special_tokens=False,
)
# get text # get text
text = results[1][offset:] token_text = sequence_text[offset:]
# if text is utf-8 # if text is utf-8
if text and text[-1] != "<EFBFBD>": if token_text and token_text[-1] != "<EFBFBD>":
return text, None, None return token_text, None, None
else: else:
return "", offset, token_offset return "", offset, token_offset

View File

@ -54,8 +54,7 @@ class SantaCoder(CausalLM):
) )
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer, device=device, decode_buffer=1
device=device,
) )
def decode(self, generated_ids: List[int]) -> str: def decode(self, generated_ids: List[int]) -> str:

View File

@ -330,7 +330,13 @@ class Seq2SeqLMBatch(Batch):
class Seq2SeqLM(Model): class Seq2SeqLM(Model):
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False): def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: bool = False,
decode_buffer: int = 3,
):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@ -354,8 +360,7 @@ class Seq2SeqLM(Model):
tokenizer.bos_token_id = self.model.config.decoder_start_token_id tokenizer.bos_token_id = self.model.config.decoder_start_token_id
super(Seq2SeqLM, self).__init__( super(Seq2SeqLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
device=device,
) )
@property @property
@ -496,7 +501,7 @@ class Seq2SeqLM(Model):
if stop: if stop:
# Slice with decoder_input_length to remove padding # Slice with decoder_input_length to remove padding
# Decode all tokens # Decode all tokens
output_text = self.decode(decoder_input_ids[-new_decoder_input_length:]) output_text = self.decode(decoder_input_ids[-decoder_input_length:])
# Get seed # Get seed
if isinstance(next_token_chooser.choice, Sampling): if isinstance(next_token_chooser.choice, Sampling):