From 60ed7b535c314b27b1848b16afbe6ec923134f05 Mon Sep 17 00:00:00 2001 From: Thomas Wolf Date: Thu, 23 Feb 2023 09:52:17 +0100 Subject: [PATCH] first tests --- .../models/conf/megatron_gpt_inference.yaml | 36 ++ .../text_generation/models/megatron_nemo.py | 528 ++++++++++++++++++ 2 files changed, 564 insertions(+) create mode 100644 server/text_generation/models/conf/megatron_gpt_inference.yaml create mode 100644 server/text_generation/models/megatron_nemo.py diff --git a/server/text_generation/models/conf/megatron_gpt_inference.yaml b/server/text_generation/models/conf/megatron_gpt_inference.yaml new file mode 100644 index 00000000..661dd719 --- /dev/null +++ b/server/text_generation/models/conf/megatron_gpt_inference.yaml @@ -0,0 +1,36 @@ +inference: + greedy: False # Whether or not to use sampling ; use greedy decoding otherwise + top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. + temperature: 1.0 # sampling temperature + add_BOS: True # add the bos token at the begining of the prompt + tokens_to_generate: 30 # The minimum length of the sequence to be generated. + all_probs: False # whether return the log prob for all the tokens in vocab + repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty. + min_tokens_to_generate: 0 # The minimum length of the sequence to be generated. + compute_logprob: False # a flag used to compute logprob of all the input text, a very special case of running inference, default False + + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + logger: False # logger provided by exp_manager + precision: 16 # 16, 32, or bf16 + +tensor_model_parallel_size: 1 +pipeline_model_parallel_size: 1 +pipeline_model_parallel_split_rank: 0 # used for encoder and decoder model +gpt_model_file: null # GPT nemo file path +checkpoint_dir: null # checkpoint file dir. This is used to load the PTL checkpoint generated during the GPT training +checkpoint_name: null # PTL checkpoint file name, only used for PTL checkpoint loading +hparams_file: null # model configuration file, only used for PTL checkpoint loading +prompts: # prompts for GPT inference + - "Q: How are you?" + - "Q: How big is the universe?" +server: False # whether launch the API server +port: 5555 # the port number for the inference server +web_server: False # whether launch the web inference server +share: False # whether create a public URL +username: test # user name for web client +password: test2 # password for web client \ No newline at end of file diff --git a/server/text_generation/models/megatron_nemo.py b/server/text_generation/models/megatron_nemo.py new file mode 100644 index 00000000..c5400e0b --- /dev/null +++ b/server/text_generation/models/megatron_nemo.py @@ -0,0 +1,528 @@ +# Copyright (c) 2021, NVIDIA CORPORATION and Hugging Face authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import threading + +import torch +import torch.distributed + +from typing import List, Optional, Tuple + +from accelerate import init_empty_weights +from safetensors import safe_open +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + AutoConfig, +) +from transformers.models.gpt_neox.parallel_layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, +) + +from text_generation.models import CausalLM +from text_generation.utils import ( + initialize_torch_distributed, + weight_files, +) + +from omegaconf import OmegaConf, open_dict +from pytorch_lightning.trainer.trainer import Trainer +from torch.utils.data import DataLoader, Dataset + +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel +from nemo.collections.nlp.modules.common.megatron_web_server import get_demo +from nemo.collections.nlp.modules.common.text_generation_server import MegatronServer +from nemo.collections.nlp.modules.common.text_generation_utils import generate +from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector +from nemo.core.config import hydra_runner +from nemo.utils.app_state import AppState +from nemo.utils.model_utils import inject_model_parallel_rank + +try: + from apex.transformer import parallel_state + + HAVE_APEX = True +except (ImportError, ModuleNotFoundError): + HAVE_APEX = False + +""" +This is the script to run GPT text generation. + +Usage: + Assume the model has TP=1, PP=1 in the following use cases. + a. run greedy inference from a nemo file: + python megatron_gpt_eval.py \ + gpt_model_file=PATH_TO_MODEL \ + inference.greedy=True \ + inference.add_BOS=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=1 \ + pipeline_model_parallel_size=1 \ + prompts=[prompt1,prompt2] + + b. run greedy inference from a PTL checkpoint file: + python megatron_gpt_eval.py \ + checkpoint_dir=PATH_TO_CHECKPOINT_FILE \ + checkpoint_name=CHECKPOINT_FILE_NAME \ + hparams_file=HPARAMS_FILE \ + inference.greedy=True \ + inference.add_BOS=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=1 \ + pipeline_model_parallel_size=1 \ + prompts=[prompt1,prompt2] + + c. run top_p inference from a nemo file: + python megatron_gpt_eval.py \ + gpt_model_file=PATH_TO_MODEL \ + inference.greedy=False \ + inference.top_k=0 \ + inference.top_p=0.9 \ + inference.repetition_penalty=1.2 \ + inference.add_BOS=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=1 \ + pipeline_model_parallel_size=1 \ + prompts=[prompt1,prompt2] + + d. If you don't need to generate tokens and need model to compute logprobs: + python megatron_gpt_eval.py \ + gpt_model_file=PATH_TO_MODEL \ + inference.compute_logprob=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=1 \ + pipeline_model_parallel_size=1 \ + prompts=[text to get logprob] + + e. Launch the inference server + python megatron_gpt_eval.py \ + gpt_model_file=PATH_TO_MODEL \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=1 \ + pipeline_model_parallel_size=1 \ + server=True + + To send a request to the server, here is one example code: + ```python + import json + import requests + + batch_size = 8 + port_num = 5555 + headers = {"Content-Type": "application/json"} + + + def request_data(data): + resp = requests.put('http://localhost:{}/generate'.format(port_num), + data=json.dumps(data), + headers=headers) + sentences = resp.json()['sentences'] + return sentences + + + data = { + "sentences": [""] * batch_size, + "tokens_to_generate": 300, + "temperature": 1.0, + "add_BOS": True, + "top_k": 0, + "top_p": 0.9, + "greedy": False, + "all_probs": False, + "repetition_penalty": 1.2, + "min_tokens_to_generate": 2, + } + + sentences = request_data(data) + ``` +""" + +if not torch.cuda.is_available(): + raise EnvironmentError("GPU is needed for the inference") + + +class RequestDataSet(Dataset): + def __init__(self, sentences): + super().__init__() + self.sentences = sentences + + def __len__(self,): + return len(self.sentences) + + def __getitem__(self, idx): + return self.sentences[idx] + + +@hydra_runner(config_path="conf", config_name="megatron_gpt_inference") +def main(cfg) -> None: + + # trainer required for restoring model parallel models + trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer) + assert ( + cfg.trainer.devices * cfg.trainer.num_nodes + == cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size + ), "devices * num_nodes should equal tensor_model_parallel_size * pipeline_model_parallel_size" + + if cfg.gpt_model_file: + save_restore_connector = NLPSaveRestoreConnector() + if os.path.isdir(cfg.gpt_model_file): + save_restore_connector.model_extracted_dir = cfg.gpt_model_file + + pretrained_cfg = MegatronGPTModel.restore_from( + restore_path=cfg.gpt_model_file, + trainer=trainer, + return_config=True, + save_restore_connector=save_restore_connector, + ) + OmegaConf.set_struct(pretrained_cfg, True) + with open_dict(pretrained_cfg): + pretrained_cfg.sequence_parallel = False + pretrained_cfg.activations_checkpoint_granularity = None + pretrained_cfg.activations_checkpoint_method = None + model = MegatronGPTModel.restore_from( + restore_path=cfg.gpt_model_file, + trainer=trainer, + override_config_path=pretrained_cfg, + save_restore_connector=save_restore_connector, + ) + elif cfg.checkpoint_dir: + app_state = AppState() + if cfg.tensor_model_parallel_size > 1 or cfg.pipeline_model_parallel_size > 1: + app_state.model_parallel_size = cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size + app_state.tensor_model_parallel_size = cfg.tensor_model_parallel_size + app_state.pipeline_model_parallel_size = cfg.pipeline_model_parallel_size + ( + app_state.tensor_model_parallel_rank, + app_state.pipeline_model_parallel_rank, + app_state.model_parallel_size, + app_state.data_parallel_size, + app_state.pipeline_model_parallel_split_rank, + app_state.virtual_pipeline_model_parallel_rank, + ) = fake_initialize_model_parallel( + world_size=app_state.model_parallel_size, + rank=trainer.global_rank, + tensor_model_parallel_size_=cfg.tensor_model_parallel_size, + pipeline_model_parallel_size_=cfg.pipeline_model_parallel_size, + pipeline_model_parallel_split_rank_=cfg.pipeline_model_parallel_split_rank, + ) + checkpoint_path = inject_model_parallel_rank(os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name)) + model = MegatronGPTModel.load_from_checkpoint(checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer) + else: + raise ValueError("need at least a nemo file or checkpoint dir") + + model.freeze() + + # Have to turn off activations_checkpoint_method for inference + try: + model.model.language_model.encoder.activations_checkpoint_method = None + except AttributeError: + pass + + length_params: LengthParam = { + "max_length": cfg.inference.tokens_to_generate, + "min_length": cfg.inference.min_tokens_to_generate, + } + + sampling_params: SamplingParam = { + "use_greedy": cfg.inference.greedy, + "temperature": cfg.inference.temperature, + "top_k": cfg.inference.top_k, + "top_p": cfg.inference.top_p, + "repetition_penalty": cfg.inference.repetition_penalty, + "add_BOS": cfg.inference.add_BOS, + "all_probs": cfg.inference.all_probs, + "compute_logprob": cfg.inference.compute_logprob, + } + + # First method of running text generation, call model.generate method + response = model.generate( + inputs=OmegaConf.to_container(cfg.prompts), length_params=length_params, sampling_params=sampling_params + ) + + print("***************************") + print(response) + print("***************************") + + # Second method of running text generation, call trainer.predict + ds = RequestDataSet(OmegaConf.to_container(cfg.prompts)) + request_dl = DataLoader(dataset=ds, batch_size=2) + config = OmegaConf.to_container(cfg.inference) + model.set_inference_config(config) + response = trainer.predict(model, request_dl) + + print("***************************") + print(response) + print("***************************") + + # Third method of running text generation, use inference server + if cfg.server: + if parallel_state.is_pipeline_first_stage() and parallel_state.get_tensor_model_parallel_rank() == 0: + if cfg.web_server: + thread = threading.Thread(target=get_demo, daemon=True, args=(cfg.share, cfg.username, cfg.password)) + thread.start() + server = MegatronServer(model.cuda()) + server.run("0.0.0.0", port=cfg.port) + + while True: + choice = torch.cuda.LongTensor(1) + torch.distributed.broadcast(choice, 0) + if choice[0].item() == 0: + generate(model.cuda()) + + +class MegatronNemo(CausalLM): + def __init__( + self, model_id: str, revision: Optional[str] = None, quantize: bool = False + ): + save_restore_connector = NLPSaveRestoreConnector() + if os.path.isdir(model_id): + save_restore_connector.model_extracted_dir = model_id + + pretrained_cfg = MegatronGPTModel.restore_from( + restore_path=model_id, + trainer=trainer, + return_config=True, + save_restore_connector=save_restore_connector, + ) + OmegaConf.set_struct(pretrained_cfg, True) + with open_dict(pretrained_cfg): + pretrained_cfg.sequence_parallel = False + pretrained_cfg.activations_checkpoint_granularity = None + pretrained_cfg.activations_checkpoint_method = None + model = MegatronGPTModel.restore_from( + restore_path=model_id, + trainer=trainer, + override_config_path=pretrained_cfg, + save_restore_connector=save_restore_connector, + ) + + def forward( + self, input_ids, attention_mask, position_ids, past_key_values: Optional = None + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + """Overwrite forward to ignore position_ids""" + + # Model Forward + outputs = self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=True, + ) + return outputs.logits, outputs.past_key_values + + +class GPTNeoxSharded(GPTNeox): + def __init__( + self, model_id: str, revision: Optional[str] = None, quantize: bool = False + ): + self.process_group, self.rank, self.world_size = initialize_torch_distributed() + self.master = self.rank == 0 + if torch.cuda.is_available(): + device = torch.device(f"cuda:{self.rank}") + dtype = torch.bfloat16 + else: + device = torch.device("cpu") + dtype = torch.float32 + + tokenizer = AutoTokenizer.from_pretrained( + model_id, revision=revision, padding_side="left" + ) + tokenizer.pad_token = tokenizer.eos_token + + config = AutoConfig.from_pretrained( + model_id, revision=revision, tp_parallel=True + ) + + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + + torch.distributed.barrier(group=self.process_group) + self.load_weights( + model, + filenames, + quantize=quantize, + device=device, + rank=self.rank, + world_size=self.world_size, + ) + self.model = model.eval().to(dtype) + torch.distributed.barrier(group=self.process_group) + super(CausalLM, self).__init__( + tokenizer=tokenizer, + device=device, + ) + + @staticmethod + def load_weights( + model, + filenames: List[str], + quantize: bool, + device: torch.device, + rank: int, + world_size: int, + ): + parameters = dict(model.named_parameters()) + for file in filenames: + with safe_open( + file, framework="pt", device=str(device) if not quantize else "cpu" + ) as f: + for name in f.keys(): + module_name, param_name = name.rsplit(".", 1) + module = model.get_submodule(module_name) + + current_parameter_tensor = parameters.get(name, None) + + slice_ = f.get_slice(name) + + if isinstance(module, TensorParallelColumnLinear): + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + elif isinstance(module, TensorParallelRowLinear): + if param_name == "weight": + size = slice_.get_shape()[1] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[:, start:stop] + else: + tensor = slice_[:] + # XXX: Hack for Rowlinear to add the bias only once. + if rank != 0: + tensor = torch.zeros_like(tensor) + elif isinstance(module, TensorParallelEmbedding): + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings: + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + else: + try: + tensor = slice_[:] + except: + tensor = f.get_tensor(name) + + if ( + current_parameter_tensor is not None + and current_parameter_tensor.shape != tensor.shape + ): + raise ValueError( + f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" + ) + + tensor = tensor.contiguous() + + if quantize: + if not HAS_BITS_AND_BYTES: + raise ImportError( + "bitsandbytes is not available on your machine either because it is not installed " + "or you don't have a GPU.\n" + "You can install it with `pip install bitsandbytes`." + ) + + if ( + type(module) + in [TensorParallelRowLinear, TensorParallelColumnLinear] + and param_name == "weight" + ): + tensor = Int8Params( + tensor, + has_fp16_weights=False, + requires_grad=False, + ).to(device) + state = bnb.MatmulLtState() + state.threshold = 6.0 + state.has_fp16_weights = False + state.memory_efficient_backward = False + state.use_pool = True + state.CB = tensor.CB + state.SCB = tensor.SCB + tensor.CB = None + tensor.SCB = None + + def replace_linear(state): + def linear(input, weight, bias): + out = bnb.matmul( + input, + weight, + state=state, + threshold=state.threshold, + bias=bias, + ) + + if state.CB is not None: + # we converted 8-bit row major to turing/ampere format + # in the first inference pass + # we no longer need the row-major weight + del state.CB + weight.data = state.CxB + + return out + + return linear + + module.linear = replace_linear(state) + + else: + tensor = tensor.to(device) + + if current_parameter_tensor is not None: + module._parameters[param_name] = tensor + else: + module._buffers[param_name] = tensor + + def forward( + self, input_ids, attention_mask, position_ids, past_key_values: Optional = None + ): + if self.model.gpt_neox.tp_embeddings: + outputs = self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=True, + ) + + # Logits are sharded, so we need to gather them + logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)] + torch.distributed.all_gather( + logits, outputs.logits, group=self.process_group + ) + logits = torch.cat(logits, dim=2) + + return logits, outputs.past_key_values + # While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard + else: + return super(GPTNeoxSharded, self).forward( + input_ids, attention_mask, position_ids, past_key_values + )