From 66914f7b19ff55ea29114aa229c6b94ffc9e6a35 Mon Sep 17 00:00:00 2001
From: SeongBeomLEE <2712qwer@naver.com>
Date: Sat, 23 Mar 2024 01:13:13 +0900
Subject: [PATCH] fix: LlamaTokenizerFast to AutoTokenizer at flash_mistral.py
 (#1637)

# What does this PR do?

There are a few cases where a model uses the Mistral or Mixtral architecture but does not ship a Llama tokenizer, so `LlamaTokenizerFast.from_pretrained` fails. This PR wraps that call in exception handling and falls back to `AutoTokenizer`.

Similar PR #619

@Narsil
---
 .../models/flash_mistral.py | 25 +++++++++++++------
 1 file changed, 17 insertions(+), 8 deletions(-)

diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 8149c1b0..2e1055b2 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -6,7 +6,7 @@ import numpy as np
 
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase, AutoTokenizer
 from transformers.models.llama import LlamaTokenizerFast
 from typing import Optional, Tuple, Type
 
@@ -317,13 +317,22 @@ class BaseFlashMistral(FlashCausalLM):
         else:
             raise NotImplementedError("FlashMistral is only available on GPU")
 
-        tokenizer = LlamaTokenizerFast.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
+        try:
+            tokenizer = LlamaTokenizerFast.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )
+        except Exception:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )
 
         config = config_cls.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
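
For context, a minimal standalone sketch of the fallback pattern the second hunk introduces. The `load_tokenizer` helper and the example model id are illustrative only, not part of the patch:

```python
from transformers import AutoTokenizer
from transformers.models.llama import LlamaTokenizerFast


def load_tokenizer(model_id, revision=None, trust_remote_code=False):
    # Shared kwargs mirroring the ones used in flash_mistral.py.
    kwargs = dict(
        revision=revision,
        padding_side="left",
        truncation_side="left",
        trust_remote_code=trust_remote_code,
    )
    try:
        # Most Mistral/Mixtral checkpoints ship a Llama-style tokenizer.
        return LlamaTokenizerFast.from_pretrained(model_id, **kwargs)
    except Exception:
        # Fallback: let AutoTokenizer resolve whatever tokenizer class
        # the checkpoint's tokenizer_config.json declares.
        return AutoTokenizer.from_pretrained(model_id, **kwargs)


# Illustrative call (any Mistral/Mixtral-architecture checkpoint):
# tokenizer = load_tokenizer("mistralai/Mistral-7B-v0.1")
```

Catching a broad `Exception` mirrors the patch itself: any failure to load the Llama tokenizer routes through `AutoTokenizer` rather than crashing model startup.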