diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index f6339d7c..0ff07417 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -337,7 +337,7 @@ class HeterogeneousSampling: def batch_top_tokens( - top_n_tokens: list[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor + top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor ) -> Tuple[List[List[int]], List[List[float]]]: """Find the top n most likely tokens for a batch of generations.