diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 8149c1b0..2e1055b2 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -6,7 +6,7 @@ import numpy as np
 
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase, AutoTokenizer
 from transformers.models.llama import LlamaTokenizerFast
 from typing import Optional, Tuple, Type
 
@@ -317,13 +317,22 @@ class BaseFlashMistral(FlashCausalLM):
         else:
             raise NotImplementedError("FlashMistral is only available on GPU")
 
-        tokenizer = LlamaTokenizerFast.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
+        try:
+            tokenizer = LlamaTokenizerFast.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )
+        except Exception:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )
 
         config = config_cls.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code