feat(server): optimize decode for sane tokenizers (#170)
commit 5fa8ae041c
parent 6f0f1d70f6
@@ -853,7 +853,7 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
 
 [[package]]
 name = "grpc-metadata"
-version = "0.4.1"
+version = "0.1.0"
 dependencies = [
  "opentelemetry",
  "tonic",
@@ -2140,7 +2140,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-client"
-version = "0.4.3"
+version = "0.5.0"
 dependencies = [
  "futures",
  "grpc-metadata",
@@ -49,6 +49,11 @@ class BloomCausalLMBatch(CausalLMBatch):
 
 
 class BLOOM(CausalLM):
+    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+        super(BLOOM, self).__init__(
+            model_id=model_id, revision=revision, quantize=quantize, decode_buffer=1
+        )
+
     @property
     def batch_type(self) -> Type[CausalLMBatch]:
         return BloomCausalLMBatch
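Note: as the commit title suggests, the BLOOM tokenizer is one of the "sane" ones (a byte-level BPE whose single-token decode does not depend on neighbouring tokens), so the new BLOOM.__init__ above opts into a one-token buffer (decode_buffer=1). A quick way to check that property, assuming the transformers package is available and using bigscience/bloom-560m purely as an example checkpoint:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
ids = tok("Hello world, how are you?").input_ids
# decoding each token on its own should reproduce the full decode for this tokenizer
piecewise = "".join(tok.decode([i]) for i in ids)
print(piecewise == tok.decode(ids))  # expected: True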
@@ -94,8 +99,7 @@ class BLOOMSharded(BLOOM):
         self.model = model.eval().to(dtype)
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=1
         )
 
     @staticmethod
@@ -291,7 +291,13 @@ class CausalLMBatch(Batch):
 
 
 class CausalLM(Model):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: bool = False,
+        decode_buffer: int = 3,
+    ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -319,8 +325,7 @@ class CausalLM(Model):
         )
 
         super(CausalLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
         )
 
     @property
@@ -212,7 +212,8 @@ class FlashCausalLM(Model):
         model_cls: Type[PreTrainedModel],
         model_id: str,
         revision: Optional[str] = None,
-        quantize=False,
+        quantize: bool = False,
+        decode_buffer: int = 3,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -237,8 +238,7 @@ class FlashCausalLM(Model):
         )
 
         super(FlashCausalLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
         )
 
     @property
@@ -62,8 +62,7 @@ class FlashSantacoder(FlashCausalLM):
         self.model = model.eval().to(device).to(dtype)
 
         super(FlashCausalLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=1
         )
 
     @staticmethod
@@ -10,10 +10,19 @@ B = TypeVar("B", bound=Batch)
 
 
 class Model(ABC):
-    def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
+    def __init__(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
+        decode_buffer: int = 3,
+    ):
+        if decode_buffer < 1:
+            raise ValueError("decode_buffer must be >= 1")
+
         self.tokenizer = tokenizer
         self.all_special_ids = set(tokenizer.all_special_ids)
         self.device = device
+        self.decode_buffer = decode_buffer
 
     @property
     @abstractmethod
@@ -39,23 +48,37 @@ class Model(ABC):
             )
 
         if token_offset is None:
-            token_offset = len(all_input_ids) - 3
-
-        # Decode token_offset token minus last one and token_offset tokens
-        results = self.tokenizer.batch_decode(
-            [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
-            skip_special_tokens=False,
-        )
-
-        # default offset is only the last token
-        if offset is None:
-            offset = len(results[0])
+            token_offset = len(all_input_ids) - self.decode_buffer
+            # left token buffer
+            if self.decode_buffer > 1:
+                # Decode token_offset token minus last one and token_offset tokens
+                raw_texts = self.tokenizer.batch_decode(
+                    [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
+                    skip_special_tokens=False,
+                )
+
+                # default offset is only the last token
+                offset = len(raw_texts[0])
+                sequence_text = raw_texts[1]
+            else:
+                # Only decode the last token without using a token buffer
+                sequence_text = self.tokenizer.decode(
+                    all_input_ids[-1], skip_special_tokens=False
+                )
+                # no offset in this case
+                offset = 0
+        else:
+            assert offset is not None
+            sequence_text = self.tokenizer.decode(
+                all_input_ids[token_offset:],
+                skip_special_tokens=False,
+            )
 
         # get text
-        text = results[1][offset:]
+        token_text = sequence_text[offset:]
 
         # if text is utf-8
-        if text and text[-1] != "�":
-            return text, None, None
+        if token_text and token_text[-1] != "�":
+            return token_text, None, None
         else:
             return "", offset, token_offset
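Note: the rewritten decode_token above keeps two paths: a buffered path (decode_buffer > 1) that decodes the trailing window twice, with and without the newest token, and takes the suffix difference, and a fast path (decode_buffer == 1) that decodes only the last token. The sketch below is a self-contained toy, not the server's code; ToyTokenizer mimics a SentencePiece-style vocabulary whose pieces lose their leading space when decoded in isolation, which is exactly the case the buffer protects against:

class ToyTokenizer:
    # "▁" marks a leading space, as in SentencePiece-style vocabularies
    vocab = {0: "▁Hello", 1: "▁world", 2: "!"}

    def decode(self, ids, skip_special_tokens=False):
        if isinstance(ids, int):
            ids = [ids]
        text = "".join(self.vocab[i] for i in ids)
        # "▁" becomes a space, but a leading space is dropped, as sentencepiece does
        return text.replace("▁", " ").lstrip(" ")

    def batch_decode(self, batches, skip_special_tokens=False):
        return [self.decode(ids) for ids in batches]


def decode_token(tokenizer, all_input_ids, decode_buffer=3):
    """Return only the text contributed by the newest token (simplified)."""
    token_offset = max(len(all_input_ids) - decode_buffer, 0)
    if decode_buffer > 1:
        # decode the trailing buffer with and without the newest token and
        # take the suffix difference: robust even for context-sensitive pieces
        without_last, with_last = tokenizer.batch_decode(
            [all_input_ids[token_offset:-1], all_input_ids[token_offset:]]
        )
        return with_last[len(without_last):]
    # decode_buffer == 1: trust the tokenizer to decode a lone token correctly
    return tokenizer.decode(all_input_ids[-1])


tok = ToyTokenizer()
print(repr(decode_token(tok, [0, 1], decode_buffer=3)))  # ' world' (space kept)
print(repr(decode_token(tok, [0, 1], decode_buffer=1)))  # 'world'  (space lost)

With the real models, BLOOM and SantaCoder opt into the decode_buffer=1 fast path because their tokenizers do not have this problem, while the default of 3 keeps the buffered path for everything else.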
@@ -54,8 +54,7 @@ class SantaCoder(CausalLM):
         )
 
         super(CausalLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=1
         )
 
     def decode(self, generated_ids: List[int]) -> str:
@@ -330,7 +330,13 @@ class Seq2SeqLMBatch(Batch):
 
 
 class Seq2SeqLM(Model):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: bool = False,
+        decode_buffer: int = 3,
+    ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -354,8 +360,7 @@ class Seq2SeqLM(Model):
         tokenizer.bos_token_id = self.model.config.decoder_start_token_id
 
         super(Seq2SeqLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
         )
 
     @property
@@ -496,7 +501,7 @@ class Seq2SeqLM(Model):
             if stop:
                 # Slice with decoder_input_length to remove padding
                 # Decode all tokens
-                output_text = self.decode(decoder_input_ids[-new_decoder_input_length:])
+                output_text = self.decode(decoder_input_ids[-decoder_input_length:])
 
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
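Note: the stop path for Seq2SeqLM now slices with decoder_input_length before decoding the full output. A toy illustration of what such a negative slice keeps (the values are made up and assume left padding, which is what a tail slice removes):

pad_token_id = 0
decoder_input_ids = [pad_token_id, pad_token_id, 42, 7, 99]  # 2 pads + 3 generated ids
decoder_input_length = 3
print(decoder_input_ids[-decoder_input_length:])  # [42, 7, 99], padding is dropped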