fix: LlamaTokenizerFast to AutoTokenizer at flash_llama.py (#619)

# What does this PR do?

A few tokenizer_config in huggingface use LlamaTokenizer, so I think I
would have selected `LlamaTokenizer` before.

For a few cases where you're using a llama structure but not a llama
tokenizer, why not make it to call the AutoTokenizer in exception
handling.

In the case of `decapoda-research/llama-7b-hf`, LLamaTokenizer is still
being used in config.json, so it should be called through`
LlamaTokenizer`.
Also, if an exception is thrown by LlamaTokenizer, it will cause
`LlamaTokenzierFast` to be called from AutoTokenizer.


Fixes # 560


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [x] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

@Narsil
This commit is contained in:
Dong Shin 2023-08-14 21:20:18 +09:00 committed by GitHub
parent b5087c4f4e
commit a072660bf5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 2 deletions

View File

@ -2,7 +2,8 @@ import torch
import torch.distributed
from opentelemetry import trace
from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
from transformers import AutoConfig, AutoTokenizer
from transformers.models.llama import LlamaTokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
@ -44,7 +45,7 @@ class FlashLlama(FlashCausalLM):
trust_remote_code=trust_remote_code,
)
except Exception:
tokenizer = LlamaTokenizerFast.from_pretrained(
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",