diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 796fbd47..9a7dfaee 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -3,7 +3,6 @@ import torch.distributed

 from opentelemetry import trace
 from transformers import AutoConfig, AutoTokenizer, GenerationConfig
-from transformers.models.llama import LlamaTokenizer
 from typing import Optional

 from text_generation_server.models import FlashCausalLM
@@ -41,22 +40,13 @@ class FlashLlama(FlashCausalLM):
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")

-        try:
-            tokenizer = LlamaTokenizer.from_pretrained(
-                model_id,
-                revision=revision,
-                padding_side="left",
-                truncation_side="left",
-                trust_remote_code=trust_remote_code,
-            )
-        except Exception:
-            tokenizer = AutoTokenizer.from_pretrained(
-                model_id,
-                revision=revision,
-                padding_side="left",
-                truncation_side="left",
-                trust_remote_code=trust_remote_code,
-            )
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
         try:
             generation_config = GenerationConfig.from_pretrained(
                 model_id, revision=revision, trust_remote_code=trust_remote_code
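
The change removes the explicit `LlamaTokenizer` attempt (and its `AutoTokenizer` fallback) because `AutoTokenizer.from_pretrained` already dispatches to the correct tokenizer class by reading the checkpoint's config/tokenizer files. A minimal sketch of that behavior, assuming a Llama checkpoint is available locally or via the Hub (the model id below is illustrative, not taken from this diff):

```python
from transformers import AutoTokenizer

# AutoTokenizer inspects the checkpoint metadata and picks the right
# tokenizer class itself; no manual LlamaTokenizer fallback is needed.
# "meta-llama/Llama-2-7b-hf" is a hypothetical example id (a gated repo
# that requires Hub authentication).
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    padding_side="left",     # pad prompts on the left, as the server does
    truncation_side="left",  # truncate from the left as well
)

# For a Llama checkpoint this resolves to the Llama tokenizer implementation:
print(type(tokenizer).__name__)  # e.g. "LlamaTokenizerFast"
```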