support deepspeed

Olivier Dehaene 2022-10-13 11:05:44 +02:00
parent 39df4d9975
commit f11965c11d
4 changed files with 84 additions and 19 deletions


@@ -1,10 +1,13 @@
 import torch
 import torch.distributed
+import io
+import json

 from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Tuple, Optional, Dict

+from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
 from transformers.modeling_utils import no_init_weights
@@ -41,7 +44,7 @@ class Batch:
     @classmethod
     def from_pb(
         cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
     ) -> "Batch":
         inputs = []
         next_token_choosers = []
@ -130,8 +133,8 @@ class Batch:
# We need to slice the attention mask to remove padding from previous steps # We need to slice the attention mask to remove padding from previous steps
input_ids["attention_mask"][ input_ids["attention_mask"][
start_index:end_index, -batch.max_sequence_length : start_index:end_index, -batch.max_sequence_length:
] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :] ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length:]
for j, past in enumerate(batch.input_ids["past_key_values"]): for j, past in enumerate(batch.input_ids["past_key_values"]):
past_keys = past[0] past_keys = past[0]
@@ -177,12 +180,12 @@ class Batch:
                 # We slice the past keys and values to remove the padding from previous batches
                 input_ids["past_key_values"][j][0][
-                    start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
-                ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
+                    start_index:end_index, :, :, -(batch.max_sequence_length - 1):
+                ] = past_keys[:, :, :, -(batch.max_sequence_length - 1):]
                 input_ids["past_key_values"][j][1][
-                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
-                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
+                    start_index:end_index, :, -(batch.max_sequence_length - 1):, :
+                ] = past_values[:, :, -(batch.max_sequence_length - 1):, :]

             # If we are on the last batch, we need to reshape the tensors
             if (i + 1) == len(batches):
@@ -239,7 +242,7 @@ class BLOOM:
         )

     def generate_token(
         self, batch: Batch
     ) -> Tuple[List[GeneratedText], Optional[Batch]]:
         with torch.no_grad():
             outputs = self.forward(**batch.input_ids)
@@ -269,11 +272,11 @@ class BLOOM:
             # For each member of the batch
             for i, (
                 request,
                 logits,
                 next_token_chooser,
                 stopping_criteria,
                 all_tokens,
             ) in enumerate(iterator):
                 # Select next token
                 next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
@@ -450,3 +453,66 @@ class BLOOMSharded(BLOOM):
         outputs.logits = logits
         return outputs
+
+
+class BLOOMDeepSpeed(BLOOM):
+    def __init__(self, model_name):
+        super(BLOOM, self).__init__()
+        import deepspeed
+        from deepspeed.comm import init_distributed, get_rank, get_world_size
+
+        init_distributed("nccl")
+        self.rank = get_rank()
+        self.world_size = get_world_size()
+        self.master = self.rank == 0
+
+        if torch.cuda.is_available():
+            self.device = torch.device(f"cuda:{self.rank}")
+            dtype = torch.float16
+        else:
+            raise ValueError("DeepSpeed only supports CUDA")
+
+        model_root_dir = snapshot_download(
+            model_name, allow_patterns=["*"], local_files_only=True
+        )
+
+        if model_name == "microsoft/bloom-deepspeed-inference-fp16":
+            if self.world_size != 8:
+                raise ValueError("microsoft/bloom-deepspeed-inference-fp16 only supports 8 GPUs")
+            checkpoints_json_path = Path(model_root_dir) / "ds_inference_config.json"
+            data = json.load(checkpoints_json_path.open("r"))
+            for key in data["checkpoints"].keys():
+                for i, v in enumerate(data["checkpoints"][key]):
+                    data["checkpoints"][key][i] = str(Path(model_root_dir) / v)
+        else:
+            file_list = [str(entry) for entry in Path(model_root_dir).rglob("*.[bp][it][n]") if entry.is_file()]
+            data = {"type": "BLOOM", "checkpoints": file_list, "version": 1.0}
+
+        checkpoints_json = Path("checkpoints.json")
+        with checkpoints_json.open("w") as f:
+            json.dump(data, f)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+
+        config = AutoConfig.from_pretrained(model_name)
+        with deepspeed.OnDevice(dtype=dtype, device="meta", enabled=True):
+            model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
+        model = model.eval()
+        model = deepspeed.init_inference(
+            model=model,
+            mp_size=self.world_size,
+            dtype=dtype,
+            replace_with_kernel_inject=True,
+            replace_method="auto",
+            enable_cuda_graph=False,
+            checkpoint=str(checkpoints_json),
+            mpu=None,
+            args=None,
+            training_mp_size=1
+        )
+        self.model = model.module
+        self.num_heads = config.n_head // self.world_size
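For orientation, the snippet below is a minimal, self-contained sketch of the DeepSpeed inference pattern the new class relies on; it is not part of this commit. The model name, script layout, and generation call are illustrative assumptions, and only the deepspeed.init_inference call with kernel injection mirrors the code above. Each rank started by the deepspeed launcher builds the same injected model and runs identical forward passes.

# Hypothetical standalone example (not in this commit), run e.g. as:
#   deepspeed --num_gpus 2 sketch.py
import os
import torch
import deepspeed
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "bigscience/bloom-560m"  # illustrative; the server targets full BLOOM
local_rank = int(os.getenv("LOCAL_RANK", "0"))  # set by the deepspeed launcher
world_size = int(os.getenv("WORLD_SIZE", "1"))

tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

# Shard the model across ranks and swap in fused inference kernels,
# analogous to what BLOOMDeepSpeed does above.
engine = deepspeed.init_inference(
    model=model.eval(),
    mp_size=world_size,
    dtype=torch.float16,
    replace_with_kernel_inject=True,
)
model = engine.module

inputs = tokenizer("DeepSpeed is", return_tensors="pt").to(f"cuda:{local_rank}")
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=20)
if local_rank == 0:
    print(tokenizer.decode(output[0], skip_special_tokens=True))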


@@ -6,7 +6,7 @@ from pathlib import Path
 from typing import Optional, List

 from bloom_inference.cache import Cache
-from bloom_inference.model import BLOOM, Batch, BLOOMSharded
+from bloom_inference.model import BLOOM, Batch, BLOOMDeepSpeed
 from bloom_inference.pb import generate_pb2_grpc, generate_pb2
@ -116,9 +116,9 @@ def serve(model_name, sharded, shard_directory):
): ):
unix_socket_template = "unix:///tmp/bloom-inference-{}" unix_socket_template = "unix:///tmp/bloom-inference-{}"
if sharded: if sharded:
if shard_directory is None: # if shard_directory is None:
raise ValueError("shard_directory must be set when sharded is True") # raise ValueError("shard_directory must be set when sharded is True")
model = BLOOMSharded(model_name, shard_directory) model = BLOOMDeepSpeed(model_name)
server_urls = [ server_urls = [
unix_socket_template.format(rank) for rank in range(model.world_size) unix_socket_template.format(rank) for rank in range(model.world_size)
] ]
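Because BLOOMDeepSpeed initializes a distributed DeepSpeed engine, the server process has to be started once per GPU (for example via the deepspeed launcher), and each rank binds the Unix socket matching its rank. The sketch below is illustrative only and not part of this commit; the environment variable names are the ones the launcher sets, everything else is hypothetical.

# Illustrative sketch (not in this commit): per-rank socket URLs vs. launcher ranks.
import os

unix_socket_template = "unix:///tmp/bloom-inference-{}"
rank = int(os.getenv("RANK", "0"))            # set by the deepspeed/torch launchers
world_size = int(os.getenv("WORLD_SIZE", "1"))

server_urls = [unix_socket_template.format(r) for r in range(world_size)]
local_url = server_urls[rank]  # the socket this rank's gRPC server listens on
print(f"rank {rank}/{world_size} listening on {local_url}")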

server/poetry.lock (generated)

@@ -197,7 +197,7 @@ python-versions = ">=3.7"
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.9"
-content-hash = "cedd0aebeb3731e2bbddf017a2ee6074c285866354272f8dfe930e9606437a25"
+content-hash = "ed2061d45e5885616d5424a01e5ac8c57b9a52fd758a509a9c17406656ce5156"

 [metadata.files]
 accelerate = [


@@ -8,7 +8,6 @@ authors = ["Olivier Dehaene <olivier@huggingface.co>"]
 python = "^3.9"
 protobuf = "^4.21.7"
 grpcio = "^1.49.1"
-torch = "^1.12.1"
 typer = "^0.6.1"
 grpcio-reflection = "^1.49.1"
 accelerate = "^0.12.0"