fix(server): Only pad to multiple of 8 on GPUs

commit 042180d88f (parent a2985036aa)
Author: OlivierDehaene
Date:   2022-12-08 19:37:37 +01:00

2 changed files with 4 additions and 2 deletions
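The change makes `pad_to_multiple_of` conditional on the device: sequence lengths aligned to a multiple of 8 help GPU matmul kernels (e.g. fp16 tensor cores), while on CPU the extra padding tokens are pure overhead. A minimal sketch of the resulting pattern outside the server, assuming a Hugging Face tokenizer (the model name, inputs, and the `device.type == "cuda"` test are illustrative stand-ins; the commit itself checks `"gpu" in str(device)`):

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Pad to a multiple of 8 only on GPU; None disables the extra padding.
    pad_to_multiple_of = 8 if device.type == "cuda" else None

    batch = tokenizer(
        ["short", "a somewhat longer input sequence"],
        return_tensors="pt",
        padding=True,
        pad_to_multiple_of=pad_to_multiple_of,
    ).to(device)

    print(batch["input_ids"].shape)  # sequence dim padded up to a multiple of 8 on GPU

The same conditional is applied in both batch classes below.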


@@ -71,8 +71,9 @@ class CausalLMBatch:
             )
         )
+        pad_to_multiple_of = 8 if "gpu" in str(device) else None
         tokenized_inputs = tokenizer(
-            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
+            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
         ).to(device)
         all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)


@@ -83,8 +83,9 @@ class Seq2SeqLMBatch:
         )
         # Tokenize batch
+        pad_to_multiple_of = 8 if "gpu" in str(device) else None
         tokenized_inputs = tokenizer(
-            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
         ).to(device)
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
         decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
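For the trailing context line, a quick shape check of the `unsqueeze(-1)` idiom (the id values are made up, not from the commit):

    import torch

    decoder_input_ids = [0, 0, 0]  # one start-token id per request (illustrative)
    t = torch.tensor(decoder_input_ids).unsqueeze(-1)
    print(t.shape)  # torch.Size([3, 1]), i.e. [batch_size, 1]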