From 6abec14a7eeb6e29a394557d64e2b527af1a89fb Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Mon, 5 Jun 2023 16:09:41 +0200
Subject: [PATCH] feat(server): batch tokenization for flash causal lm (#411)

---
 .../models/flash_causal_lm.py              | 18 ++++++++++++++----
 server/text_generation_server/utils/hub.py |  4 +++-
 2 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 5ff951b3..a2ad2d5e 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -79,6 +79,16 @@ class FlashCausalLMBatch(Batch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
+        batch_inputs = []
+        max_truncation = 0
+        for r in pb.requests:
+            batch_inputs.append(r.inputs)
+            max_truncation = max(max_truncation, r.truncate)
+
+        batch_tokenized_inputs = tokenizer(
+            batch_inputs, truncation=True, max_length=max_truncation
+        )["input_ids"]
+
         position_ids = []
         cu_seqlens = [0]
         max_seqlen = 0
@@ -106,13 +116,13 @@ class FlashCausalLMBatch(Batch):
         max_length = 0
 
         # Parse batch
-        for i, r in enumerate(pb.requests):
+        for i, (r, tokenized_input) in enumerate(
+            zip(pb.requests, batch_tokenized_inputs)
+        ):
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
 
-            tokenized_input = tokenizer(
-                r.inputs, truncation=True, max_length=r.truncate
-            )["input_ids"]
+            tokenized_input = tokenized_input[-r.truncate :]
 
             input_length = len(tokenized_input)
             max_seqlen = max(max_seqlen, input_length)
diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py
index 134ac7cd..2ed7673c 100644
--- a/server/text_generation_server/utils/hub.py
+++ b/server/text_generation_server/utils/hub.py
@@ -134,7 +134,7 @@ def download_weights(
 ) -> List[Path]:
     """Download the safetensors files from the hub"""
 
-    def download_file(filename, tries=5):
+    def download_file(filename, tries=5, backoff: int = 5):
         local_file = try_to_load_from_cache(model_id, revision, filename)
         if local_file is not None:
             logger.info(f"File {filename} already present in cache.")
@@ -158,6 +158,8 @@ def download_weights(
             if i + 1 == tries:
                 raise e
             logger.error(e)
+            logger.info(f"Retrying in {backoff} seconds")
+            time.sleep(backoff)
             logger.info(f"Retry {i + 1}/{tries - 1}")
 
     # We do this instead of using tqdm because we want to parse the logs with the launcher