fix(server): Check for device type correctly when determining initial padding (#16)

AFAIK there is no torch device type called "gpu".
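
For reference, a minimal sketch (not part of the commit) of why the old check never fired on GPU: str(device) yields strings like "cuda:0", which never contain "gpu", so pad_to_multiple_of stayed None on CUDA devices.

    import torch

    # Minimal sketch, not part of the commit: torch device types are "cuda",
    # "cpu", "mps", etc. -- never "gpu" -- so the old substring test could not match.
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    print(str(device))   # e.g. "cuda:0" or "cpu" -- never contains "gpu"
    print(device.type)   # "cuda" or "cpu"

    # Old check: always None, even on a GPU
    pad_to_multiple_of = 8 if "gpu" in str(device) else None
    # Fixed check: 8 on CUDA devices, None otherwise
    pad_to_multiple_of = 8 if device.type == "cuda" else None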
Nick Hill 2022-12-30 10:30:42 -08:00 committed by GitHub
parent 611e21cb13
commit 686cc66717
2 changed files with 2 additions and 2 deletions


@@ -65,7 +65,7 @@ class CausalLMBatch:
)
all_logprobs.append(None)
-pad_to_multiple_of = 8 if "gpu" in str(device) else None
+pad_to_multiple_of = 8 if device.type == "cuda" else None
tokenized_inputs = tokenizer(
inputs,
return_tensors="pt",


@@ -77,7 +77,7 @@ class Seq2SeqLMBatch:
decoder_logprobs.append(None)
# Tokenize batch
-pad_to_multiple_of = 8 if "gpu" in str(device) else None
+pad_to_multiple_of = 8 if device.type == "cuda" else None
tokenized_inputs = tokenizer(
inputs,
return_tensors="pt",
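
For context, a hedged sketch of how pad_to_multiple_of typically feeds into the Hugging Face tokenizer call that both hunks truncate; the padding arguments beyond the diff context are assumptions, not part of this commit:

    # Sketch only; arguments other than return_tensors are assumed, since the diff is truncated here.
    tokenized_inputs = tokenizer(
        inputs,
        return_tensors="pt",
        padding=True,
        pad_to_multiple_of=pad_to_multiple_of,  # rounds the padded length up to a multiple of 8 on CUDA
    )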