fix(server): Check for device type correctly when determining initial padding (#16)
AFAIK there is no torch device type called "gpu".
This commit is contained in:
parent
611e21cb13
commit
686cc66717
|
@ -65,7 +65,7 @@ class CausalLMBatch:
|
||||||
)
|
)
|
||||||
all_logprobs.append(None)
|
all_logprobs.append(None)
|
||||||
|
|
||||||
pad_to_multiple_of = 8 if "gpu" in str(device) else None
|
pad_to_multiple_of = 8 if device.type == "cuda" else None
|
||||||
tokenized_inputs = tokenizer(
|
tokenized_inputs = tokenizer(
|
||||||
inputs,
|
inputs,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
|
|
|
@ -77,7 +77,7 @@ class Seq2SeqLMBatch:
|
||||||
decoder_logprobs.append(None)
|
decoder_logprobs.append(None)
|
||||||
|
|
||||||
# Tokenize batch
|
# Tokenize batch
|
||||||
pad_to_multiple_of = 8 if "gpu" in str(device) else None
|
pad_to_multiple_of = 8 if device.type == "cuda" else None
|
||||||
tokenized_inputs = tokenizer(
|
tokenized_inputs = tokenizer(
|
||||||
inputs,
|
inputs,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
|
|
Loading…
Reference in New Issue