fix: LlamaTokenizerFast to AutoTokenizer at flash_mistral.py (#1637)

# What does this PR do?

A few cases where you're using a mistral structure or mixtral structure
but not a llama tokenizer, why not make it to call the AutoTokenizer in
exception handling.

Similar PR #619

@Narsil
This commit is contained in:
SeongBeomLEE 2024-03-23 01:13:13 +09:00 committed by GitHub
parent 08e9181418
commit 66914f7b19
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 17 additions and 8 deletions

View File

@ -6,7 +6,7 @@ import numpy as np
from dataclasses import dataclass from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase, AutoTokenizer
from transformers.models.llama import LlamaTokenizerFast from transformers.models.llama import LlamaTokenizerFast
from typing import Optional, Tuple, Type from typing import Optional, Tuple, Type
@ -317,6 +317,7 @@ class BaseFlashMistral(FlashCausalLM):
else: else:
raise NotImplementedError("FlashMistral is only available on GPU") raise NotImplementedError("FlashMistral is only available on GPU")
try:
tokenizer = LlamaTokenizerFast.from_pretrained( tokenizer = LlamaTokenizerFast.from_pretrained(
model_id, model_id,
revision=revision, revision=revision,
@ -324,6 +325,14 @@ class BaseFlashMistral(FlashCausalLM):
truncation_side="left", truncation_side="left",
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
except Exception:
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = config_cls.from_pretrained( config = config_cls.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code