Fixing cohere tokenizer. (#1697)
parent 5062fda4ff
commit f9958ee191
@@ -3,7 +3,7 @@ import torch.distributed
 
 from opentelemetry import trace
 from typing import Optional
-from transformers.models.llama import LlamaTokenizerFast
+from transformers import AutoTokenizer
 
 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
@@ -36,7 +36,7 @@ class FlashCohere(FlashCausalLM):
         else:
             raise NotImplementedError("FlashCohere is only available on GPU")
 
-        tokenizer = LlamaTokenizerFast.from_pretrained(
+        tokenizer = AutoTokenizer.from_pretrained(
             model_id,
             revision=revision,
             padding_side="left",
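For context, a minimal sketch of what the swap buys (the checkpoint id below is the public Command-R repo, used only for illustration; the exact tokenizer class reported depends on the installed transformers version): AutoTokenizer instantiates whatever tokenizer class the checkpoint's tokenizer_config.json declares, whereas the old code forced LlamaTokenizerFast onto a non-Llama checkpoint.

from transformers import AutoTokenizer

# Public Cohere Command-R checkpoint, used here only for illustration.
model_id = "CohereForAI/c4ai-command-r-v01"

# AutoTokenizer dispatches on the checkpoint's tokenizer_config.json, so the
# tokenizer class the model was actually trained with is loaded, along with
# its own special tokens and vocabulary.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    padding_side="left",  # same setting FlashCohere passes above
)

# A hard-coded LlamaTokenizerFast would report a Llama class even on a
# Cohere checkpoint; via AutoTokenizer the checkpoint's own class is used.
print(type(tokenizer).__name__)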