support deepspeed
This commit is contained in:
parent
39df4d9975
commit
f11965c11d
|
@ -1,10 +1,13 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Tuple, Optional, Dict
|
from typing import List, Tuple, Optional, Dict
|
||||||
|
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
||||||
from transformers.modeling_utils import no_init_weights
|
from transformers.modeling_utils import no_init_weights
|
||||||
|
|
||||||
|
@ -41,7 +44,7 @@ class Batch:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pb(
|
def from_pb(
|
||||||
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
|
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
|
||||||
) -> "Batch":
|
) -> "Batch":
|
||||||
inputs = []
|
inputs = []
|
||||||
next_token_choosers = []
|
next_token_choosers = []
|
||||||
|
@ -130,8 +133,8 @@ class Batch:
|
||||||
|
|
||||||
# We need to slice the attention mask to remove padding from previous steps
|
# We need to slice the attention mask to remove padding from previous steps
|
||||||
input_ids["attention_mask"][
|
input_ids["attention_mask"][
|
||||||
start_index:end_index, -batch.max_sequence_length :
|
start_index:end_index, -batch.max_sequence_length:
|
||||||
] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :]
|
] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length:]
|
||||||
|
|
||||||
for j, past in enumerate(batch.input_ids["past_key_values"]):
|
for j, past in enumerate(batch.input_ids["past_key_values"]):
|
||||||
past_keys = past[0]
|
past_keys = past[0]
|
||||||
|
@ -177,12 +180,12 @@ class Batch:
|
||||||
|
|
||||||
# We slice the past keys and values to remove the padding from previous batches
|
# We slice the past keys and values to remove the padding from previous batches
|
||||||
input_ids["past_key_values"][j][0][
|
input_ids["past_key_values"][j][0][
|
||||||
start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
|
start_index:end_index, :, :, -(batch.max_sequence_length - 1):
|
||||||
] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
|
] = past_keys[:, :, :, -(batch.max_sequence_length - 1):]
|
||||||
|
|
||||||
input_ids["past_key_values"][j][1][
|
input_ids["past_key_values"][j][1][
|
||||||
start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
|
start_index:end_index, :, -(batch.max_sequence_length - 1):, :
|
||||||
] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
|
] = past_values[:, :, -(batch.max_sequence_length - 1):, :]
|
||||||
|
|
||||||
# If we are on the last batch, we need to reshape the tensors
|
# If we are on the last batch, we need to reshape the tensors
|
||||||
if (i + 1) == len(batches):
|
if (i + 1) == len(batches):
|
||||||
|
@ -239,7 +242,7 @@ class BLOOM:
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_token(
|
def generate_token(
|
||||||
self, batch: Batch
|
self, batch: Batch
|
||||||
) -> Tuple[List[GeneratedText], Optional[Batch]]:
|
) -> Tuple[List[GeneratedText], Optional[Batch]]:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = self.forward(**batch.input_ids)
|
outputs = self.forward(**batch.input_ids)
|
||||||
|
@ -269,11 +272,11 @@ class BLOOM:
|
||||||
|
|
||||||
# For each member of the batch
|
# For each member of the batch
|
||||||
for i, (
|
for i, (
|
||||||
request,
|
request,
|
||||||
logits,
|
logits,
|
||||||
next_token_chooser,
|
next_token_chooser,
|
||||||
stopping_criteria,
|
stopping_criteria,
|
||||||
all_tokens,
|
all_tokens,
|
||||||
) in enumerate(iterator):
|
) in enumerate(iterator):
|
||||||
# Select next token
|
# Select next token
|
||||||
next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
|
next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
|
||||||
|
@ -450,3 +453,66 @@ class BLOOMSharded(BLOOM):
|
||||||
|
|
||||||
outputs.logits = logits
|
outputs.logits = logits
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class BLOOMDeepSpeed(BLOOM):
|
||||||
|
def __init__(self, model_name):
|
||||||
|
super(BLOOM, self).__init__()
|
||||||
|
|
||||||
|
import deepspeed
|
||||||
|
from deepspeed.comm import init_distributed, get_rank, get_world_size
|
||||||
|
|
||||||
|
init_distributed("nccl")
|
||||||
|
self.rank = get_rank()
|
||||||
|
self.world_size = get_world_size()
|
||||||
|
self.master = self.rank == 0
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
self.device = torch.device(f"cuda:{self.rank}")
|
||||||
|
dtype = torch.float16
|
||||||
|
else:
|
||||||
|
raise ValueError("DeepSpeed only supports CUDA")
|
||||||
|
|
||||||
|
model_root_dir = snapshot_download(
|
||||||
|
model_name, allow_patterns=["*"], local_files_only=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_name == "microsoft/bloom-deepspeed-inference-fp16":
|
||||||
|
if self.world_size != 8:
|
||||||
|
raise ValueError("microsoft/bloom-deepspeed-inference-fp16 only supports 8 GPUs")
|
||||||
|
checkpoints_json_path = Path(model_root_dir) / "ds_inference_config.json"
|
||||||
|
data = json.load(checkpoints_json_path.open("r"))
|
||||||
|
for key in data["checkpoints"].keys():
|
||||||
|
for i, v in enumerate(data["checkpoints"][key]):
|
||||||
|
data["checkpoints"][key][i] = str(Path(model_root_dir) / v)
|
||||||
|
else:
|
||||||
|
file_list = [str(entry) for entry in Path(model_root_dir).rglob("*.[bp][it][n]") if entry.is_file()]
|
||||||
|
data = {"type": "BLOOM", "checkpoints": file_list, "version": 1.0}
|
||||||
|
|
||||||
|
checkpoints_json = Path("checkpoints.json")
|
||||||
|
with checkpoints_json.open("w") as f:
|
||||||
|
json.dump(data, f)
|
||||||
|
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
|
||||||
|
config = AutoConfig.from_pretrained(model_name)
|
||||||
|
|
||||||
|
with deepspeed.OnDevice(dtype=dtype, device="meta", enabled=True):
|
||||||
|
model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
model = model.eval()
|
||||||
|
|
||||||
|
model = deepspeed.init_inference(
|
||||||
|
model=model,
|
||||||
|
mp_size=self.world_size,
|
||||||
|
dtype=dtype,
|
||||||
|
replace_with_kernel_inject=True,
|
||||||
|
replace_method="auto",
|
||||||
|
enable_cuda_graph=False,
|
||||||
|
checkpoint=str(checkpoints_json),
|
||||||
|
mpu=None,
|
||||||
|
args=None,
|
||||||
|
training_mp_size=1
|
||||||
|
)
|
||||||
|
|
||||||
|
self.model = model.module
|
||||||
|
self.num_heads = config.n_head // self.world_size
|
||||||
|
|
|
@ -6,7 +6,7 @@ from pathlib import Path
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
from bloom_inference.cache import Cache
|
from bloom_inference.cache import Cache
|
||||||
from bloom_inference.model import BLOOM, Batch, BLOOMSharded
|
from bloom_inference.model import BLOOM, Batch, BLOOMDeepSpeed
|
||||||
from bloom_inference.pb import generate_pb2_grpc, generate_pb2
|
from bloom_inference.pb import generate_pb2_grpc, generate_pb2
|
||||||
|
|
||||||
|
|
||||||
|
@ -116,9 +116,9 @@ def serve(model_name, sharded, shard_directory):
|
||||||
):
|
):
|
||||||
unix_socket_template = "unix:///tmp/bloom-inference-{}"
|
unix_socket_template = "unix:///tmp/bloom-inference-{}"
|
||||||
if sharded:
|
if sharded:
|
||||||
if shard_directory is None:
|
# if shard_directory is None:
|
||||||
raise ValueError("shard_directory must be set when sharded is True")
|
# raise ValueError("shard_directory must be set when sharded is True")
|
||||||
model = BLOOMSharded(model_name, shard_directory)
|
model = BLOOMDeepSpeed(model_name)
|
||||||
server_urls = [
|
server_urls = [
|
||||||
unix_socket_template.format(rank) for rank in range(model.world_size)
|
unix_socket_template.format(rank) for rank in range(model.world_size)
|
||||||
]
|
]
|
||||||
|
|
|
@ -197,7 +197,7 @@ python-versions = ">=3.7"
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "1.1"
|
lock-version = "1.1"
|
||||||
python-versions = "^3.9"
|
python-versions = "^3.9"
|
||||||
content-hash = "cedd0aebeb3731e2bbddf017a2ee6074c285866354272f8dfe930e9606437a25"
|
content-hash = "ed2061d45e5885616d5424a01e5ac8c57b9a52fd758a509a9c17406656ce5156"
|
||||||
|
|
||||||
[metadata.files]
|
[metadata.files]
|
||||||
accelerate = [
|
accelerate = [
|
||||||
|
|
|
@ -8,7 +8,6 @@ authors = ["Olivier Dehaene <olivier@huggingface.co>"]
|
||||||
python = "^3.9"
|
python = "^3.9"
|
||||||
protobuf = "^4.21.7"
|
protobuf = "^4.21.7"
|
||||||
grpcio = "^1.49.1"
|
grpcio = "^1.49.1"
|
||||||
torch = "^1.12.1"
|
|
||||||
typer = "^0.6.1"
|
typer = "^0.6.1"
|
||||||
grpcio-reflection = "^1.49.1"
|
grpcio-reflection = "^1.49.1"
|
||||||
accelerate = "^0.12.0"
|
accelerate = "^0.12.0"
|
||||||
|
|
Loading…
Reference in New Issue