fix(server): Only pad to multiple of 8 on GPUs
parent a2985036aa
commit 042180d88f
@@ -71,8 +71,9 @@ class CausalLMBatch:
                 )
             )
 
+        pad_to_multiple_of = 8 if "gpu" in str(device) else None
         tokenized_inputs = tokenizer(
-            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
+            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
         ).to(device)
         all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
 
@@ -83,8 +83,9 @@ class Seq2SeqLMBatch:
             )
 
         # Tokenize batch
+        pad_to_multiple_of = 8 if "gpu" in str(device) else None
         tokenized_inputs = tokenizer(
-            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
+            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
         ).to(device)
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
         decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
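For context, a minimal sketch of the behavior this commit selects between, not taken from this repository: with padding=True a transformers tokenizer pads the batch to its longest sequence, and pad_to_multiple_of=8 additionally rounds that length up to the next multiple of 8 (a GPU tensor-core-friendly shape), which is pure overhead on devices that don't benefit from it. The "gpt2" checkpoint and the CPU device below are arbitrary choices for the demo.

# Minimal sketch, assuming a transformers tokenizer; not part of this commit.
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

device = torch.device("cpu")
# Same guard as in the diff; for this CPU device it evaluates to None.
pad_to_multiple_of = 8 if "gpu" in str(device) else None

inputs = ["a short input", "a noticeably longer example input for this batch"]
for multiple in (None, 8):
    tokenized = tokenizer(
        inputs, return_tensors="pt", padding=True, pad_to_multiple_of=multiple
    )
    # padding=True pads to the longest sequence in the batch;
    # pad_to_multiple_of=8 then rounds that length up to a multiple of 8.
    print(multiple, tuple(tokenized["input_ids"].shape))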