Add AutoCausalLM (#5)

Currently `BLOOMSharded` is a subclass of `CausalLM`, yet it skips `CausalLM`'s constructor by calling `super(CausalLM, self).__init__(...)`. This is surprising behavior that we might want to avoid.

This PR extracts `CausalLM`'s constructor, which detects settings from `model_id`, into a new `AutoCausalLM` subclass, so that subclasses no longer have to skip `CausalLM`'s constructor.
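
Below is a minimal, self-contained sketch of the problem and the fix. The class names mirror the real ones, but the constructors and the loaded objects are toy stand-ins, not the actual `text_generation_server` code:

```python
class Model:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer


# Before: CausalLM's constructor does the heavyweight AutoModel/AutoTokenizer loading,
# so BLOOMSharded has to bypass it. Naming CausalLM explicitly in super() starts the
# MRO lookup *after* CausalLM and resolves straight to Model.__init__.
class CausalLM(Model):
    def __init__(self, model_id):
        super().__init__(model=f"auto({model_id})", tokenizer=f"tok({model_id})")


class BLOOMSharded(CausalLM):
    def __init__(self, model_id):
        # surprising: skips CausalLM.__init__ entirely
        super(CausalLM, self).__init__(model=f"sharded({model_id})", tokenizer="tok")


# After: the model_id-based loading moves into AutoCausalLM, CausalLM keeps only the
# batching/generation logic and inherits Model's plain constructor, and sharded models
# call an ordinary super().__init__.
class NewCausalLM(Model):
    pass


class AutoCausalLM(NewCausalLM):
    def __init__(self, model_id):
        super().__init__(model=f"auto({model_id})", tokenizer=f"tok({model_id})")


class NewBLOOMSharded(NewCausalLM):
    def __init__(self, model_id):
        super().__init__(model=f"sharded({model_id})", tokenizer="tok")  # no skipping


print(BLOOMSharded("bigscience/bloom").model)     # sharded(bigscience/bloom), via the MRO skip
print(NewBLOOMSharded("bigscience/bloom").model)  # sharded(bigscience/bloom), no skip needed
```

In the diff below, `BLOOMSharded`, `GalacticaSharded`, `GPTNeoxSharded`, `MPTSharded`, `OPTSharded`, `RW`, and `SantaCoder` all switch from `super(CausalLM, self).__init__(...)` to a plain `super().__init__(...)`, and callers that previously constructed `CausalLM(model_id, ...)` now construct `AutoCausalLM(model_id, ...)`.
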
Authored by Yang, Bo on 2023-08-02 09:35:40 -07:00; committed by GitHub
parent 9048a80f8f
commit 8a5f80bb61
10 changed files with 72 additions and 70 deletions


@@ -5,12 +5,12 @@ from copy import copy
from transformers import AutoTokenizer
from text_generation_server.pb import generate_pb2
-from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
+from text_generation_server.models.causal_lm import AutoCausalLM, CausalLMBatch
@pytest.fixture(scope="session")
def default_causal_lm():
-    return CausalLM("gpt2")
+    return AutoCausalLM("gpt2")
@pytest.fixture(scope="session")


@@ -7,7 +7,7 @@ from transformers.models.auto import modeling_auto
from typing import Optional
from text_generation_server.models.model import Model
-from text_generation_server.models.causal_lm import CausalLM
+from text_generation_server.models.causal_lm import CausalLM, AutoCausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
@@ -33,6 +33,7 @@ __all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
+    "AutoCausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
@@ -172,7 +173,7 @@ def get_model(
                trust_remote_code=trust_remote_code,
            )
        else:
-            return CausalLM(
+            return AutoCausalLM(
                model_id,
                revision,
                quantize=quantize,
@@ -192,7 +193,7 @@ def get_model(
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
-            return CausalLM(
+            return AutoCausalLM(
                model_id,
                revision,
                quantize=quantize,
@@ -257,7 +258,7 @@ def get_model(
        )
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return CausalLM(
+        return AutoCausalLM(
            model_id,
            revision,
            quantize=quantize,
@@ -276,7 +277,7 @@ def get_model(
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
-            return CausalLM(
+            return AutoCausalLM(
                model_id,
                revision,
                quantize=quantize,


@@ -82,7 +82,7 @@ class BLOOMSharded(CausalLM):
        model = BloomForCausalLM(config, weights)
        torch.distributed.barrier(group=self.process_group)
-        super(CausalLM, self).__init__(
+        super().__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,


@@ -449,62 +449,6 @@ class CausalLMBatch(Batch):
class CausalLM(Model):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-            dtype = torch.float16 if dtype is None else dtype
-        else:
-            if quantize:
-                raise ValueError("quantization is not available on CPU")
-            device = torch.device("cpu")
-            dtype = torch.float32
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            revision=revision,
-            torch_dtype=dtype,
-            device_map="auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None,
-            load_in_8bit=quantize == "bitsandbytes",
-            trust_remote_code=trust_remote_code,
-        )
-        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
-            model = model.cuda()
-        if tokenizer.pad_token_id is None:
-            if model.config.pad_token_id is not None:
-                tokenizer.pad_token_id = model.config.pad_token_id
-            elif model.config.eos_token_id is not None:
-                tokenizer.pad_token_id = model.config.eos_token_id
-            elif tokenizer.eos_token_id is not None:
-                tokenizer.pad_token_id = tokenizer.eos_token_id
-            else:
-                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-        super(CausalLM, self).__init__(
-            model=model,
-            tokenizer=tokenizer,
-            requires_padding=True,
-            dtype=dtype,
-            device=device,
-        )
    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return CausalLMBatch
@@ -676,3 +620,60 @@ class CausalLM(Model):
        batch.past_key_values = past
        return generations, batch
+class AutoCausalLM(CausalLM):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+            dtype = torch.float16 if dtype is None else dtype
+        else:
+            if quantize:
+                raise ValueError("quantization is not available on CPU")
+            device = torch.device("cpu")
+            dtype = torch.float32
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            revision=revision,
+            torch_dtype=dtype,
+            device_map="auto"
+            if torch.cuda.is_available() and torch.cuda.device_count() > 1
+            else None,
+            load_in_8bit=quantize == "bitsandbytes",
+            trust_remote_code=trust_remote_code,
+        )
+        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
+            model = model.cuda()
+        if tokenizer.pad_token_id is None:
+            if model.config.pad_token_id is not None:
+                tokenizer.pad_token_id = model.config.pad_token_id
+            elif model.config.eos_token_id is not None:
+                tokenizer.pad_token_id = model.config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+        super().__init__(
+            model=model,
+            tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
+            device=device,
+        )


@@ -197,7 +197,7 @@ class GalacticaSharded(CausalLM):
        model = OPTForCausalLM(config, weights)
        torch.distributed.barrier(group=self.process_group)
-        super(CausalLM, self).__init__(
+        super().__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,


@@ -62,7 +62,7 @@ class GPTNeoxSharded(CausalLM):
        model = GPTNeoxForCausalLM(config, weights)
        torch.distributed.barrier(group=self.process_group)
-        super(CausalLM, self).__init__(
+        super().__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,


@@ -85,7 +85,7 @@ class MPTSharded(CausalLM):
        model = MPTForCausalLM(config, weights)
        torch.distributed.barrier(group=self.process_group)
-        super(CausalLM, self).__init__(
+        super().__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,


@@ -60,7 +60,7 @@ class OPTSharded(CausalLM):
        model = OPTForCausalLM(config, weights)
        torch.distributed.barrier(group=self.process_group)
-        super(CausalLM, self).__init__(
+        super().__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,


@@ -55,7 +55,7 @@ class RW(CausalLM):
        else:
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-        super(CausalLM, self).__init__(
+        super().__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,


@@ -60,7 +60,7 @@ class SantaCoder(CausalLM):
            trust_remote_code=trust_remote_code,
        )
-        super(CausalLM, self).__init__(
+        super().__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,