Fixing cohere tokenizer. (#1697)

This commit is contained in:
Nicolas Patry 2024-04-05 16:44:19 +02:00 committed by GitHub
parent 5062fda4ff
commit f9958ee191
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 2 additions and 2 deletions

View File

@@ -3,7 +3,7 @@ import torch.distributed
from opentelemetry import trace
from typing import Optional
-from transformers.models.llama import LlamaTokenizerFast
+from transformers import AutoTokenizer
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
@@ -36,7 +36,7 @@ class FlashCohere(FlashCausalLM):
else:
raise NotImplementedError("FlashCohere is only available on GPU")
-        tokenizer = LlamaTokenizerFast.from_pretrained(
+        tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",