diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index aeecf12..3ad3621 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -65,7 +65,7 @@ class CausalLMBatch:
             )
             all_logprobs.append(None)
 
-        pad_to_multiple_of = 8 if "gpu" in str(device) else None
+        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index fc80c60..4095db9 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -77,7 +77,7 @@ class Seq2SeqLMBatch:
             decoder_logprobs.append(None)
 
         # Tokenize batch
-        pad_to_multiple_of = 8 if "gpu" in str(device) else None
+        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
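
Context for the fix: `torch.device` objects stringify as `"cuda:0"` or `"cpu"`, never anything containing `"gpu"`, so the old substring check could not fire and the tensor-core-friendly padding to a multiple of 8 was silently skipped on GPU. A minimal standalone sketch of the corrected check (not part of the diff, just an illustration):

```python
import torch

# Pick whatever device is available; devices stringify as "cuda:0" / "cpu".
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

print(str(device))   # e.g. "cuda:0" or "cpu" -- "gpu" is never a substring
print(device.type)   # "cuda" or "cpu" -- the reliable field to branch on

# The corrected condition from the diff: pad to a multiple of 8 only on CUDA,
# where it helps tensor-core utilization; leave padding dynamic on CPU.
pad_to_multiple_of = 8 if device.type == "cuda" else None
print(pad_to_multiple_of)  # 8 on a CUDA machine, None otherwise
```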