Add AutoCausalLM (#5)
Currently `BLOOMSharded` is a subclass of `CausalLM`, while it skips `CausalLM`'s constructor. This is a supprising behavior that we might want to avoid. This PR extract `CausalLM`'s constructor to `AutoCausalLM` to detect settings from `model_id`, so that we don't have to skip `CausalLM`'s constructor.
This commit is contained in:
parent
9048a80f8f
commit
8a5f80bb61
|
@ -5,12 +5,12 @@ from copy import copy
|
|||
from transformers import AutoTokenizer
|
||||
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
|
||||
from text_generation_server.models.causal_lm import AutoCausalLM, CausalLMBatch
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def default_causal_lm():
|
||||
return CausalLM("gpt2")
|
||||
return AutoCausalLM("gpt2")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
|
|
@ -7,7 +7,7 @@ from transformers.models.auto import modeling_auto
|
|||
from typing import Optional
|
||||
|
||||
from text_generation_server.models.model import Model
|
||||
from text_generation_server.models.causal_lm import CausalLM
|
||||
from text_generation_server.models.causal_lm import CausalLM, AutoCausalLM
|
||||
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
||||
from text_generation_server.models.bloom import BLOOMSharded
|
||||
from text_generation_server.models.mpt import MPTSharded
|
||||
|
@ -33,6 +33,7 @@ __all__ = [
|
|||
"Model",
|
||||
"BLOOMSharded",
|
||||
"CausalLM",
|
||||
"AutoCausalLM",
|
||||
"FlashCausalLM",
|
||||
"GalacticaSharded",
|
||||
"Seq2SeqLM",
|
||||
|
@ -172,7 +173,7 @@ def get_model(
|
|||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
else:
|
||||
return CausalLM(
|
||||
return AutoCausalLM(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -192,7 +193,7 @@ def get_model(
|
|||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
|
||||
else:
|
||||
return CausalLM(
|
||||
return AutoCausalLM(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -257,7 +258,7 @@ def get_model(
|
|||
)
|
||||
|
||||
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||
return CausalLM(
|
||||
return AutoCausalLM(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -276,7 +277,7 @@ def get_model(
|
|||
auto_map = config_dict.get("auto_map", None)
|
||||
if trust_remote_code and auto_map is not None:
|
||||
if "AutoModelForCausalLM" in auto_map.keys():
|
||||
return CausalLM(
|
||||
return AutoCausalLM(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
|
|
@ -82,7 +82,7 @@ class BLOOMSharded(CausalLM):
|
|||
model = BloomForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
super().__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
|
|
|
@ -449,62 +449,6 @@ class CausalLMBatch(Batch):
|
|||
|
||||
|
||||
class CausalLM(Model):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
torch_dtype=dtype,
|
||||
device_map="auto"
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() > 1
|
||||
else None,
|
||||
load_in_8bit=quantize == "bitsandbytes",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
||||
model = model.cuda()
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
if model.config.pad_token_id is not None:
|
||||
tokenizer.pad_token_id = model.config.pad_token_id
|
||||
elif model.config.eos_token_id is not None:
|
||||
tokenizer.pad_token_id = model.config.eos_token_id
|
||||
elif tokenizer.eos_token_id is not None:
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
else:
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
|
||||
super(CausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[CausalLMBatch]:
|
||||
return CausalLMBatch
|
||||
|
@ -676,3 +620,60 @@ class CausalLM(Model):
|
|||
batch.past_key_values = past
|
||||
|
||||
return generations, batch
|
||||
|
||||
class AutoCausalLM(CausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
torch_dtype=dtype,
|
||||
device_map="auto"
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() > 1
|
||||
else None,
|
||||
load_in_8bit=quantize == "bitsandbytes",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
||||
model = model.cuda()
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
if model.config.pad_token_id is not None:
|
||||
tokenizer.pad_token_id = model.config.pad_token_id
|
||||
elif model.config.eos_token_id is not None:
|
||||
tokenizer.pad_token_id = model.config.eos_token_id
|
||||
elif tokenizer.eos_token_id is not None:
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
else:
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
|
|
@ -197,7 +197,7 @@ class GalacticaSharded(CausalLM):
|
|||
model = OPTForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
super().__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
|
|
|
@ -62,7 +62,7 @@ class GPTNeoxSharded(CausalLM):
|
|||
model = GPTNeoxForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
super().__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
|
|
|
@ -85,7 +85,7 @@ class MPTSharded(CausalLM):
|
|||
model = MPTForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
super().__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=False,
|
||||
|
|
|
@ -60,7 +60,7 @@ class OPTSharded(CausalLM):
|
|||
model = OPTForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
super().__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
|
|
|
@ -55,7 +55,7 @@ class RW(CausalLM):
|
|||
else:
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
|
||||
super(CausalLM, self).__init__(
|
||||
super().__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
|
|
|
@ -60,7 +60,7 @@ class SantaCoder(CausalLM):
|
|||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
super(CausalLM, self).__init__(
|
||||
super().__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
|
|
Loading…
Reference in New Issue