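"""CausalLM: a thin wrapper around a Hugging Face causal language model.

Loads the model and tokenizer for a given checkpoint and exposes a
forward() that returns logits together with the key/value cache.
"""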
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Optional, Tuple, List

from text_generation.models import Model


class CausalLM(Model):
    def __init__(self, model_name: str):
        # Prefer bfloat16 on GPUs that support it; otherwise use float32.
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            device = torch.device("cpu")
            dtype = torch.float32

        # Left padding keeps the last token of every sequence aligned, which
        # is what incremental decoding with a KV cache expects.
        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        # NOTE: if "[PAD]" is new to the vocabulary, feeding its id to the
        # model also requires self.model.resize_token_embeddings(len(tokenizer)).
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
        ).eval()

        super().__init__(
            tokenizer=tokenizer,
            num_heads=self.model.config.num_attention_heads,
            device=device,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        # Model forward pass; use_cache=True returns the per-layer key/value
        # states so the next call only needs the newly generated tokens.
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values
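

if __name__ == "__main__":
    # Usage sketch, not part of the original module: greedy incremental
    # decoding with the KV cache. The first forward() call processes the
    # whole prompt; each later call feeds only the newly chosen token plus
    # the cached past_key_values. The "gpt2" checkpoint and greedy (argmax)
    # selection are illustrative assumptions, not taken from this file.
    model = CausalLM("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")

    inputs = tokenizer("Hello, my name is", return_tensors="pt")
    input_ids = inputs["input_ids"].to(model.model.device)
    attention_mask = inputs["attention_mask"].to(model.model.device)

    logits, past = model.forward(input_ids, attention_mask)
    generated = input_ids
    for _ in range(8):
        # Greedy pick of the next token from the last position's logits.
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=-1)
        # The mask must cover past + new tokens; the cache carries the
        # prefix, so only the new token is passed as input_ids.
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
            dim=-1,
        )
        logits, past = model.forward(next_token, attention_mask, past)

    print(tokenizer.decode(generated[0], skip_special_tokens=True))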