import torch

from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type

from transformers import PreTrainedTokenizerBase

from text_generation.models.types import Batch, GeneratedText

B = TypeVar("B", bound=Batch)


class Model(ABC):
    """Abstract base class that every text generation model implements."""

    def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
        self.tokenizer = tokenizer
        self.device = device

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        """The `Batch` subclass this model consumes."""
        raise NotImplementedError

    @abstractmethod
    def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
        """Generate one token for each request in `batch`.

        Returns the generations that finished during this step together with
        the batch of requests that still need more tokens, or `None` when the
        whole batch is done.
        """
        raise NotImplementedError
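

# --- Illustrative sketch, not part of the original module ---
# Concrete models are expected to subclass `Model`, point `batch_type` at the
# `Batch` subclass they consume, and implement `generate_token`. `StubModel`
# below is a hypothetical placeholder that only demonstrates the shape of the
# contract; its name and trivial bodies are assumptions, not library code.
class StubModel(Model):
    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        # Run on GPU when available; the device choice here is only an example.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        super().__init__(tokenizer, device)

    @property
    def batch_type(self) -> Type[Batch]:
        return Batch

    def generate_token(
        self, batch: Batch
    ) -> Tuple[List[GeneratedText], Optional[Batch]]:
        # A real implementation would run a forward pass over `batch`, pick the
        # next token for each request, and return the requests that finished as
        # `GeneratedText` objects plus the still-active batch; this stub simply
        # reports that everything is done.
        return [], None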