support deepspeed
This commit is contained in:
parent
39df4d9975
commit
f11965c11d
|
@ -1,10 +1,13 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
import io
|
||||
import json
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Optional, Dict
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
||||
from transformers.modeling_utils import no_init_weights
|
||||
|
||||
|
@ -450,3 +453,66 @@ class BLOOMSharded(BLOOM):
|
|||
|
||||
outputs.logits = logits
|
||||
return outputs
|
||||
|
||||
|
||||
class BLOOMDeepSpeed(BLOOM):
|
||||
def __init__(self, model_name):
|
||||
super(BLOOM, self).__init__()
|
||||
|
||||
import deepspeed
|
||||
from deepspeed.comm import init_distributed, get_rank, get_world_size
|
||||
|
||||
init_distributed("nccl")
|
||||
self.rank = get_rank()
|
||||
self.world_size = get_world_size()
|
||||
self.master = self.rank == 0
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self.device = torch.device(f"cuda:{self.rank}")
|
||||
dtype = torch.float16
|
||||
else:
|
||||
raise ValueError("DeepSpeed only supports CUDA")
|
||||
|
||||
model_root_dir = snapshot_download(
|
||||
model_name, allow_patterns=["*"], local_files_only=True
|
||||
)
|
||||
|
||||
if model_name == "microsoft/bloom-deepspeed-inference-fp16":
|
||||
if self.world_size != 8:
|
||||
raise ValueError("microsoft/bloom-deepspeed-inference-fp16 only supports 8 GPUs")
|
||||
checkpoints_json_path = Path(model_root_dir) / "ds_inference_config.json"
|
||||
data = json.load(checkpoints_json_path.open("r"))
|
||||
for key in data["checkpoints"].keys():
|
||||
for i, v in enumerate(data["checkpoints"][key]):
|
||||
data["checkpoints"][key][i] = str(Path(model_root_dir) / v)
|
||||
else:
|
||||
file_list = [str(entry) for entry in Path(model_root_dir).rglob("*.[bp][it][n]") if entry.is_file()]
|
||||
data = {"type": "BLOOM", "checkpoints": file_list, "version": 1.0}
|
||||
|
||||
checkpoints_json = Path("checkpoints.json")
|
||||
with checkpoints_json.open("w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
|
||||
with deepspeed.OnDevice(dtype=dtype, device="meta", enabled=True):
|
||||
model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
|
||||
|
||||
model = model.eval()
|
||||
|
||||
model = deepspeed.init_inference(
|
||||
model=model,
|
||||
mp_size=self.world_size,
|
||||
dtype=dtype,
|
||||
replace_with_kernel_inject=True,
|
||||
replace_method="auto",
|
||||
enable_cuda_graph=False,
|
||||
checkpoint=str(checkpoints_json),
|
||||
mpu=None,
|
||||
args=None,
|
||||
training_mp_size=1
|
||||
)
|
||||
|
||||
self.model = model.module
|
||||
self.num_heads = config.n_head // self.world_size
|
||||
|
|
|
@ -6,7 +6,7 @@ from pathlib import Path
|
|||
from typing import Optional, List
|
||||
|
||||
from bloom_inference.cache import Cache
|
||||
from bloom_inference.model import BLOOM, Batch, BLOOMSharded
|
||||
from bloom_inference.model import BLOOM, Batch, BLOOMDeepSpeed
|
||||
from bloom_inference.pb import generate_pb2_grpc, generate_pb2
|
||||
|
||||
|
||||
|
@ -116,9 +116,9 @@ def serve(model_name, sharded, shard_directory):
|
|||
):
|
||||
unix_socket_template = "unix:///tmp/bloom-inference-{}"
|
||||
if sharded:
|
||||
if shard_directory is None:
|
||||
raise ValueError("shard_directory must be set when sharded is True")
|
||||
model = BLOOMSharded(model_name, shard_directory)
|
||||
# if shard_directory is None:
|
||||
# raise ValueError("shard_directory must be set when sharded is True")
|
||||
model = BLOOMDeepSpeed(model_name)
|
||||
server_urls = [
|
||||
unix_socket_template.format(rank) for rank in range(model.world_size)
|
||||
]
|
||||
|
|
|
@ -197,7 +197,7 @@ python-versions = ">=3.7"
|
|||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = "^3.9"
|
||||
content-hash = "cedd0aebeb3731e2bbddf017a2ee6074c285866354272f8dfe930e9606437a25"
|
||||
content-hash = "ed2061d45e5885616d5424a01e5ac8c57b9a52fd758a509a9c17406656ce5156"
|
||||
|
||||
[metadata.files]
|
||||
accelerate = [
|
||||
|
|
|
@ -8,7 +8,6 @@ authors = ["Olivier Dehaene <olivier@huggingface.co>"]
|
|||
python = "^3.9"
|
||||
protobuf = "^4.21.7"
|
||||
grpcio = "^1.49.1"
|
||||
torch = "^1.12.1"
|
||||
typer = "^0.6.1"
|
||||
grpcio-reflection = "^1.49.1"
|
||||
accelerate = "^0.12.0"
|
||||
|
|
Loading…
Reference in New Issue