feat(server): batch tokenization for flash causal lm (#411)

This commit is contained in:
OlivierDehaene 2023-06-05 16:09:41 +02:00 committed by GitHub
parent 895c5f1562
commit 6abec14a7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 5 deletions

View File

@ -79,6 +79,16 @@ class FlashCausalLMBatch(Batch):
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
) -> "FlashCausalLMBatch": ) -> "FlashCausalLMBatch":
batch_inputs = []
max_truncation = 0
for r in pb.requests:
batch_inputs.append(r.inputs)
max_truncation = max(max_truncation, r.truncate)
batch_tokenized_inputs = tokenizer(
batch_inputs, truncation=True, max_length=max_truncation
)["input_ids"]
position_ids = [] position_ids = []
cu_seqlens = [0] cu_seqlens = [0]
max_seqlen = 0 max_seqlen = 0
@ -106,13 +116,13 @@ class FlashCausalLMBatch(Batch):
max_length = 0 max_length = 0
# Parse batch # Parse batch
for i, r in enumerate(pb.requests): for i, (r, tokenized_input) in enumerate(
zip(pb.requests, batch_tokenized_inputs)
):
# request id -> idx in list mapping # request id -> idx in list mapping
requests_idx_mapping[r.id] = i requests_idx_mapping[r.id] = i
tokenized_input = tokenizer( tokenized_input = tokenized_input[-r.truncate :]
r.inputs, truncation=True, max_length=r.truncate
)["input_ids"]
input_length = len(tokenized_input) input_length = len(tokenized_input)
max_seqlen = max(max_seqlen, input_length) max_seqlen = max(max_seqlen, input_length)

View File

@ -134,7 +134,7 @@ def download_weights(
) -> List[Path]: ) -> List[Path]:
"""Download the safetensors files from the hub""" """Download the safetensors files from the hub"""
def download_file(filename, tries=5): def download_file(filename, tries=5, backoff: int = 5):
local_file = try_to_load_from_cache(model_id, revision, filename) local_file = try_to_load_from_cache(model_id, revision, filename)
if local_file is not None: if local_file is not None:
logger.info(f"File {filename} already present in cache.") logger.info(f"File {filename} already present in cache.")
@ -158,6 +158,8 @@ def download_weights(
if i + 1 == tries: if i + 1 == tries:
raise e raise e
logger.error(e) logger.error(e)
logger.info(f"Retrying in {backoff} seconds")
time.sleep(backoff)
logger.info(f"Retry {i + 1}/{tries - 1}") logger.info(f"Retry {i + 1}/{tries - 1}")
# We do this instead of using tqdm because we want to parse the logs with the launcher # We do this instead of using tqdm because we want to parse the logs with the launcher