hf_text-generation-inference/server/text_generation/models/santacoder.py

import torch
import torch.distributed

from typing import Optional, List, Tuple
from transformers import AutoTokenizer, AutoModelForCausalLM

from text_generation.models import CausalLM

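# Fill-in-the-middle (FIM) special tokens used in SantaCoder prompts; EOD marks the
# end of a document and is also registered as the pad token below.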
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
FIM_PAD = "<fim-pad>"
EOD = "<|endoftext|>"


class SantaCoder(CausalLM):
    def __init__(self, model_name: str, quantize=False):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        tokenizer.add_special_tokens(
            {
                "additional_special_tokens": [
                    EOD,
                    FIM_PREFIX,
                    FIM_MIDDLE,
                    FIM_SUFFIX,
                    FIM_PAD,
                ],
                "pad_token": EOD,
            }
        )

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            load_in_8bit=quantize,
            trust_remote_code=True,  # required
        ).to(device).eval()

        # Skip CausalLM.__init__, which would otherwise handle model loading, and call
        # the grandparent initializer directly, since the model is already loaded above.
        super(CausalLM, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )

    def decode(self, generated_ids: List[int]) -> str:
        # Do not skip special tokens as they are used for custom parsing rules of the generated text
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )
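

# A minimal sketch (not part of the original module) of how the FIM special tokens
# above could be assembled into a fill-in-the-middle prompt. The prefix/suffix/middle
# ordering is an assumption about the format SantaCoder expects; real requests are
# built by clients and served through the text_generation server, not by this file.
if __name__ == "__main__":
    prefix = "def add(a, b):\n    "
    suffix = "\n    return result\n"
    # Assumed layout: the model generates the missing middle after FIM_MIDDLE.
    fim_prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
    print(fim_prompt)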