From f11965c11dde7ae35ac438b10fbf2b83b4a43c45 Mon Sep 17 00:00:00 2001 From: Olivier Dehaene Date: Thu, 13 Oct 2022 11:05:44 +0200 Subject: [PATCH] support deepspeed --- server/bloom_inference/model.py | 92 +++++++++++++++++++++++++++----- server/bloom_inference/server.py | 8 +-- server/poetry.lock | 2 +- server/pyproject.toml | 1 - 4 files changed, 84 insertions(+), 19 deletions(-) diff --git a/server/bloom_inference/model.py b/server/bloom_inference/model.py index 21cf1154..ef5df7fd 100644 --- a/server/bloom_inference/model.py +++ b/server/bloom_inference/model.py @@ -1,10 +1,13 @@ import torch import torch.distributed +import io +import json from dataclasses import dataclass from pathlib import Path from typing import List, Tuple, Optional, Dict +from huggingface_hub import snapshot_download from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig from transformers.modeling_utils import no_init_weights @@ -41,7 +44,7 @@ class Batch: @classmethod def from_pb( - cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device + cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device ) -> "Batch": inputs = [] next_token_choosers = [] @@ -130,8 +133,8 @@ class Batch: # We need to slice the attention mask to remove padding from previous steps input_ids["attention_mask"][ - start_index:end_index, -batch.max_sequence_length : - ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :] + start_index:end_index, -batch.max_sequence_length: + ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length:] for j, past in enumerate(batch.input_ids["past_key_values"]): past_keys = past[0] @@ -177,12 +180,12 @@ class Batch: # We slice the past keys and values to remove the padding from previous batches input_ids["past_key_values"][j][0][ - start_index:end_index, :, :, -(batch.max_sequence_length - 1) : - ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :] + start_index:end_index, :, :, -(batch.max_sequence_length - 1): + ] = past_keys[:, :, :, -(batch.max_sequence_length - 1):] input_ids["past_key_values"][j][1][ - start_index:end_index, :, -(batch.max_sequence_length - 1) :, : - ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :] + start_index:end_index, :, -(batch.max_sequence_length - 1):, : + ] = past_values[:, :, -(batch.max_sequence_length - 1):, :] # If we are on the last batch, we need to reshape the tensors if (i + 1) == len(batches): @@ -239,7 +242,7 @@ class BLOOM: ) def generate_token( - self, batch: Batch + self, batch: Batch ) -> Tuple[List[GeneratedText], Optional[Batch]]: with torch.no_grad(): outputs = self.forward(**batch.input_ids) @@ -269,11 +272,11 @@ class BLOOM: # For each member of the batch for i, ( - request, - logits, - next_token_chooser, - stopping_criteria, - all_tokens, + request, + logits, + next_token_chooser, + stopping_criteria, + all_tokens, ) in enumerate(iterator): # Select next token next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1]) @@ -450,3 +453,66 @@ class BLOOMSharded(BLOOM): outputs.logits = logits return outputs + + +class BLOOMDeepSpeed(BLOOM): + def __init__(self, model_name): + super(BLOOM, self).__init__() + + import deepspeed + from deepspeed.comm import init_distributed, get_rank, get_world_size + + init_distributed("nccl") + self.rank = get_rank() + self.world_size = get_world_size() + self.master = self.rank == 0 + + if torch.cuda.is_available(): + self.device = torch.device(f"cuda:{self.rank}") + dtype = torch.float16 + else: + raise ValueError("DeepSpeed only supports CUDA") + + model_root_dir = snapshot_download( + model_name, allow_patterns=["*"], local_files_only=True + ) + + if model_name == "microsoft/bloom-deepspeed-inference-fp16": + if self.world_size != 8: + raise ValueError("microsoft/bloom-deepspeed-inference-fp16 only supports 8 GPUs") + checkpoints_json_path = Path(model_root_dir) / "ds_inference_config.json" + data = json.load(checkpoints_json_path.open("r")) + for key in data["checkpoints"].keys(): + for i, v in enumerate(data["checkpoints"][key]): + data["checkpoints"][key][i] = str(Path(model_root_dir) / v) + else: + file_list = [str(entry) for entry in Path(model_root_dir).rglob("*.[bp][it][n]") if entry.is_file()] + data = {"type": "BLOOM", "checkpoints": file_list, "version": 1.0} + + checkpoints_json = Path("checkpoints.json") + with checkpoints_json.open("w") as f: + json.dump(data, f) + + self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") + config = AutoConfig.from_pretrained(model_name) + + with deepspeed.OnDevice(dtype=dtype, device="meta", enabled=True): + model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16) + + model = model.eval() + + model = deepspeed.init_inference( + model=model, + mp_size=self.world_size, + dtype=dtype, + replace_with_kernel_inject=True, + replace_method="auto", + enable_cuda_graph=False, + checkpoint=str(checkpoints_json), + mpu=None, + args=None, + training_mp_size=1 + ) + + self.model = model.module + self.num_heads = config.n_head // self.world_size diff --git a/server/bloom_inference/server.py b/server/bloom_inference/server.py index 3a509169..bee99f10 100644 --- a/server/bloom_inference/server.py +++ b/server/bloom_inference/server.py @@ -6,7 +6,7 @@ from pathlib import Path from typing import Optional, List from bloom_inference.cache import Cache -from bloom_inference.model import BLOOM, Batch, BLOOMSharded +from bloom_inference.model import BLOOM, Batch, BLOOMDeepSpeed from bloom_inference.pb import generate_pb2_grpc, generate_pb2 @@ -116,9 +116,9 @@ def serve(model_name, sharded, shard_directory): ): unix_socket_template = "unix:///tmp/bloom-inference-{}" if sharded: - if shard_directory is None: - raise ValueError("shard_directory must be set when sharded is True") - model = BLOOMSharded(model_name, shard_directory) + # if shard_directory is None: + # raise ValueError("shard_directory must be set when sharded is True") + model = BLOOMDeepSpeed(model_name) server_urls = [ unix_socket_template.format(rank) for rank in range(model.world_size) ] diff --git a/server/poetry.lock b/server/poetry.lock index ea20ef26..c73a8494 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -197,7 +197,7 @@ python-versions = ">=3.7" [metadata] lock-version = "1.1" python-versions = "^3.9" -content-hash = "cedd0aebeb3731e2bbddf017a2ee6074c285866354272f8dfe930e9606437a25" +content-hash = "ed2061d45e5885616d5424a01e5ac8c57b9a52fd758a509a9c17406656ce5156" [metadata.files] accelerate = [ diff --git a/server/pyproject.toml b/server/pyproject.toml index 9d14ce6c..2958e1a8 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -8,7 +8,6 @@ authors = ["Olivier Dehaene "] python = "^3.9" protobuf = "^4.21.7" grpcio = "^1.49.1" -torch = "^1.12.1" typer = "^0.6.1" grpcio-reflection = "^1.49.1" accelerate = "^0.12.0"