From 65e2f1624ef37441a6343a427720c31a13088e63 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Fri, 24 Feb 2023 17:20:00 +0100 Subject: [PATCH] fix(server): fix token_is_special (#87) --- server/text_generation/models/causal_lm.py | 2 +- server/text_generation/models/seq2seq_lm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py index d15197d..30aff87 100644 --- a/server/text_generation/models/causal_lm.py +++ b/server/text_generation/models/causal_lm.py @@ -445,7 +445,7 @@ class CausalLM(Model): next_token_id_squeezed, next_token_logprob, next_token_text, - next_token_id_squeezed in self.all_special_ids, + next_token_id_squeezed.item() in self.all_special_ids, generated_text, ) diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py index 3a4108a..3738d7a 100644 --- a/server/text_generation/models/seq2seq_lm.py +++ b/server/text_generation/models/seq2seq_lm.py @@ -509,7 +509,7 @@ class Seq2SeqLM(Model): next_token_id_squeezed, next_token_logprob, next_token_text, - next_token_id_squeezed in self.all_special_ids, + next_token_id_squeezed.item() in self.all_special_ids, generated_text, )