feat(server): batch tokenization for flash causal lm (#411)
This commit is contained in:
parent
895c5f1562
commit
6abec14a7e
|
@ -79,6 +79,16 @@ class FlashCausalLMBatch(Batch):
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> "FlashCausalLMBatch":
|
) -> "FlashCausalLMBatch":
|
||||||
|
batch_inputs = []
|
||||||
|
max_truncation = 0
|
||||||
|
for r in pb.requests:
|
||||||
|
batch_inputs.append(r.inputs)
|
||||||
|
max_truncation = max(max_truncation, r.truncate)
|
||||||
|
|
||||||
|
batch_tokenized_inputs = tokenizer(
|
||||||
|
batch_inputs, truncation=True, max_length=max_truncation
|
||||||
|
)["input_ids"]
|
||||||
|
|
||||||
position_ids = []
|
position_ids = []
|
||||||
cu_seqlens = [0]
|
cu_seqlens = [0]
|
||||||
max_seqlen = 0
|
max_seqlen = 0
|
||||||
|
@ -106,13 +116,13 @@ class FlashCausalLMBatch(Batch):
|
||||||
max_length = 0
|
max_length = 0
|
||||||
|
|
||||||
# Parse batch
|
# Parse batch
|
||||||
for i, r in enumerate(pb.requests):
|
for i, (r, tokenized_input) in enumerate(
|
||||||
|
zip(pb.requests, batch_tokenized_inputs)
|
||||||
|
):
|
||||||
# request id -> idx in list mapping
|
# request id -> idx in list mapping
|
||||||
requests_idx_mapping[r.id] = i
|
requests_idx_mapping[r.id] = i
|
||||||
|
|
||||||
tokenized_input = tokenizer(
|
tokenized_input = tokenized_input[-r.truncate :]
|
||||||
r.inputs, truncation=True, max_length=r.truncate
|
|
||||||
)["input_ids"]
|
|
||||||
|
|
||||||
input_length = len(tokenized_input)
|
input_length = len(tokenized_input)
|
||||||
max_seqlen = max(max_seqlen, input_length)
|
max_seqlen = max(max_seqlen, input_length)
|
||||||
|
|
|
@ -134,7 +134,7 @@ def download_weights(
|
||||||
) -> List[Path]:
|
) -> List[Path]:
|
||||||
"""Download the safetensors files from the hub"""
|
"""Download the safetensors files from the hub"""
|
||||||
|
|
||||||
def download_file(filename, tries=5):
|
def download_file(filename, tries=5, backoff: int = 5):
|
||||||
local_file = try_to_load_from_cache(model_id, revision, filename)
|
local_file = try_to_load_from_cache(model_id, revision, filename)
|
||||||
if local_file is not None:
|
if local_file is not None:
|
||||||
logger.info(f"File {filename} already present in cache.")
|
logger.info(f"File {filename} already present in cache.")
|
||||||
|
@ -158,6 +158,8 @@ def download_weights(
|
||||||
if i + 1 == tries:
|
if i + 1 == tries:
|
||||||
raise e
|
raise e
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
|
logger.info(f"Retrying in {backoff} seconds")
|
||||||
|
time.sleep(backoff)
|
||||||
logger.info(f"Retry {i + 1}/{tries - 1}")
|
logger.info(f"Retry {i + 1}/{tries - 1}")
|
||||||
|
|
||||||
# We do this instead of using tqdm because we want to parse the logs with the launcher
|
# We do this instead of using tqdm because we want to parse the logs with the launcher
|
||||||
|
|
Loading…
Reference in New Issue