fix: LlamaTokenizerFast to AutoTokenizer at flash_mistral.py (#1637)
# What does this PR do? A few cases where you're using a mistral structure or mixtral structure but not a llama tokenizer, why not make it to call the AutoTokenizer in exception handling. Similar PR #619 @Narsil
This commit is contained in:
parent
08e9181418
commit
66914f7b19
|
@ -6,7 +6,7 @@ import numpy as np
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from opentelemetry import trace
|
from opentelemetry import trace
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase, AutoTokenizer
|
||||||
from transformers.models.llama import LlamaTokenizerFast
|
from transformers.models.llama import LlamaTokenizerFast
|
||||||
from typing import Optional, Tuple, Type
|
from typing import Optional, Tuple, Type
|
||||||
|
|
||||||
|
@ -317,13 +317,22 @@ class BaseFlashMistral(FlashCausalLM):
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("FlashMistral is only available on GPU")
|
raise NotImplementedError("FlashMistral is only available on GPU")
|
||||||
|
|
||||||
tokenizer = LlamaTokenizerFast.from_pretrained(
|
try:
|
||||||
model_id,
|
tokenizer = LlamaTokenizerFast.from_pretrained(
|
||||||
revision=revision,
|
model_id,
|
||||||
padding_side="left",
|
revision=revision,
|
||||||
truncation_side="left",
|
padding_side="left",
|
||||||
trust_remote_code=trust_remote_code,
|
truncation_side="left",
|
||||||
)
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
revision=revision,
|
||||||
|
padding_side="left",
|
||||||
|
truncation_side="left",
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
config = config_cls.from_pretrained(
|
config = config_cls.from_pretrained(
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
|
|
Loading…
Reference in New Issue