From 31b23f98ff49008a4ac0517e268b1863d97c65ac Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 10 Jan 2024 09:42:26 -0500 Subject: [PATCH] feat: boilerplate phi2 model integration --- .../text_generation_server/models/__init__.py | 10 +++ server/text_generation_server/models/phi2.py | 76 +++++++++++++++++++ 2 files changed, 86 insertions(+) create mode 100644 server/text_generation_server/models/phi2.py diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 39d1d58e..f7dc9ddc 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -18,6 +18,7 @@ from text_generation_server.models.galactica import GalacticaSharded from text_generation_server.models.santacoder import SantaCoder from text_generation_server.models.t5 import T5Sharded from text_generation_server.models.gpt_neox import GPTNeoxSharded +from text_generation_server.models.phi2 import Phi2 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. @@ -40,6 +41,7 @@ __all__ = [ "OPTSharded", "T5Sharded", "get_model", + "Phi2", ] FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models." @@ -201,6 +203,14 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) + elif model_type == "phi-msft": + return Phi2( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) elif model_type == "gpt_neox": if FLASH_ATTENTION: diff --git a/server/text_generation_server/models/phi2.py b/server/text_generation_server/models/phi2.py new file mode 100644 index 00000000..347fda80 --- /dev/null +++ b/server/text_generation_server/models/phi2.py @@ -0,0 +1,76 @@ +import torch +import torch.nn as nn +import torch.distributed + +from typing import Optional, List, Tuple, Type +from text_generation_server.models.types import Generation, Tokens +from text_generation_server.models.causal_lm import CausalLMBatch +from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase, AutoModelForCausalLM +from text_generation_server.models import CausalLM +from text_generation_server.models.causal_lm import CausalLMBatch +from text_generation_server.pb import generate_pb2 + +class Phi2(CausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.float16 if dtype is None else dtype + else: + if quantize: + raise ValueError("quantization is not available on CPU") + + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + ) + tokenizer.pad_token = tokenizer.eos_token + with device: + model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=dtype, + load_in_8bit=quantize == "bitsandbytes", + trust_remote_code=trust_remote_code + ) + + # debug show the model + print(model) + + super(CausalLM, self).__init__( + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + ) + + def decode(self, generated_ids: List[int]) -> str: + print("🔍 Decoding", generated_ids.shape) + # Do not skip special tokens as they are used for custom parsing rules of the generated text + return self.tokenizer.decode( + generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False + ) + + + def forward( + self, input_ids, attention_mask, position_ids, past_key_values: Optional = None + ): + print("🔥 Forwarding", input_ids.shape) + default = super().forward(input_ids, attention_mask, position_ids, past_key_values) + return default + + + def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], CausalLMBatch | None, Tuple[int, int]]: + print("🛥️ Generating Tokens") + default = super().generate_token(batch) + return default \ No newline at end of file