diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py index d15197d..30aff87 100644 --- a/server/text_generation/models/causal_lm.py +++ b/server/text_generation/models/causal_lm.py @@ -445,7 +445,7 @@ class CausalLM(Model): next_token_id_squeezed, next_token_logprob, next_token_text, - next_token_id_squeezed in self.all_special_ids, + next_token_id_squeezed.item() in self.all_special_ids, generated_text, ) diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py index 3a4108a..3738d7a 100644 --- a/server/text_generation/models/seq2seq_lm.py +++ b/server/text_generation/models/seq2seq_lm.py @@ -509,7 +509,7 @@ class Seq2SeqLM(Model): next_token_id_squeezed, next_token_logprob, next_token_text, - next_token_id_squeezed in self.all_special_ids, + next_token_id_squeezed.item() in self.all_special_ids, generated_text, )