fix: LlamaTokenizerFast to AutoTokenizer at flash_llama.py (#619)
# What does this PR do? A few tokenizer_config in huggingface use LlamaTokenizer, so I think I would have selected `LlamaTokenizer` before. For a few cases where you're using a llama structure but not a llama tokenizer, why not make it to call the AutoTokenizer in exception handling. In the case of `decapoda-research/llama-7b-hf`, LLamaTokenizer is still being used in config.json, so it should be called through` LlamaTokenizer`. Also, if an exception is thrown by LlamaTokenizer, it will cause `LlamaTokenzierFast` to be called from AutoTokenizer. Fixes # 560 ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [x] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. @Narsil
This commit is contained in:
parent
b5087c4f4e
commit
a072660bf5
|
@ -2,7 +2,8 @@ import torch
|
|||
import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
from transformers.models.llama import LlamaTokenizer
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
|
@ -44,7 +45,7 @@ class FlashLlama(FlashCausalLM):
|
|||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
except Exception:
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
|
|
Loading…
Reference in New Issue