fix(server): Only pad to multiple of 8 on GPUs
This commit is contained in:
parent
a2985036aa
commit
042180d88f
|
@ -71,8 +71,9 @@ class CausalLMBatch:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
pad_to_multiple_of = 8 if "gpu" in str(device) else None
|
||||||
tokenized_inputs = tokenizer(
|
tokenized_inputs = tokenizer(
|
||||||
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
|
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
|
||||||
).to(device)
|
).to(device)
|
||||||
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
|
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
|
||||||
|
|
||||||
|
|
|
@ -83,8 +83,9 @@ class Seq2SeqLMBatch:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Tokenize batch
|
# Tokenize batch
|
||||||
|
pad_to_multiple_of = 8 if "gpu" in str(device) else None
|
||||||
tokenized_inputs = tokenizer(
|
tokenized_inputs = tokenizer(
|
||||||
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
|
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
|
||||||
).to(device)
|
).to(device)
|
||||||
# Convert decoder_input_ids to torch tensor of size [batch_size, 1]
|
# Convert decoder_input_ids to torch tensor of size [batch_size, 1]
|
||||||
decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
|
decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
|
||||||
|
|
Loading…
Reference in New Issue