fix(server): Check for device type correctly when determining initial padding (#16)
AFAIK there is no torch device type called "gpu".
This commit is contained in:
parent
611e21cb13
commit
686cc66717
|
@@ -65,7 +65,7 @@ class CausalLMBatch:
             )
             all_logprobs.append(None)

-        pad_to_multiple_of = 8 if "gpu" in str(device) else None
+        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
|
|
@@ -77,7 +77,7 @@ class Seq2SeqLMBatch:
             decoder_logprobs.append(None)

         # Tokenize batch
-        pad_to_multiple_of = 8 if "gpu" in str(device) else None
+        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
|
Loading…
Reference in New Issue