feat(server): support OPT models (#55)

Not all OPT models have a `tokenizer.json` file on the Hub at the moment, so this can't be merged yet.
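As a side note (not part of this commit), a quick way to check whether a given checkpoint ships a fast tokenizer is to list its files on the Hub. A minimal sketch, assuming `huggingface_hub` is installed; the helper name and example model id are illustrative:

```python
from huggingface_hub import list_repo_files


def has_fast_tokenizer(model_id: str, revision: str = "main") -> bool:
    # A fast tokenizer is serialized as `tokenizer.json` at the repo root.
    return "tokenizer.json" in list_repo_files(model_id, revision=revision)


# Per the note above, some OPT checkpoints may only ship the slow
# (GPT-2 style) tokenizer files such as vocab.json and merges.txt.
print(has_fast_tokenizer("facebook/opt-66b"))
```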
OlivierDehaene 2023-04-11 19:16:41 +02:00 committed by GitHub
parent 299217c95c
commit f26dfd0dc1
9 changed files with 270 additions and 59 deletions

View File

@@ -54,11 +54,12 @@ to power LLMs api-inference widgets.
## Optimized architectures
- [BLOOM](https://huggingface.co/bigscience/bloom)
- [Galactica](https://huggingface.co/facebook/galactica-120b)
- [SantaCoder](https://huggingface.co/bigcode/santacoder)
- [GPT-Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
- [FLAN-T5](https://huggingface.co/google/flan-t5-xxl)
- [Galactica](https://huggingface.co/facebook/galactica-120b)
- [GPT-Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
- [Llama](https://github.com/facebookresearch/llama)
- [OPT](https://huggingface.co/facebook/opt-66b)
- [SantaCoder](https://huggingface.co/bigcode/santacoder)
Other architectures are supported on a best effort basis using:

View File

@@ -349,8 +349,8 @@ fn main() -> ExitCode {
Err(TryRecvError::Empty) => {
sleep(Duration::from_millis(100));
}
Ok(ShardStatus::Failed((rank, err))) => {
tracing::error!("Shard {} failed to start:\n{}", rank, err);
Ok(ShardStatus::Failed(rank)) => {
tracing::error!("Shard {} failed to start.", rank);
shutdown_shards(shutdown, &shutdown_receiver);
return ExitCode::FAILURE;
}
@@ -457,8 +457,8 @@ fn main() -> ExitCode {
let mut exit_code = ExitCode::SUCCESS;
while running.load(Ordering::SeqCst) {
if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
tracing::error!("Shard {rank} failed:\n{err}");
if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
tracing::error!("Shard {rank} failed.");
exit_code = ExitCode::FAILURE;
break;
};
@@ -488,7 +488,7 @@ fn main() -> ExitCode {
#[derive(Debug)]
enum ShardStatus {
Ready,
Failed((usize, String)),
Failed(usize),
}
#[allow(clippy::too_many_arguments)]
@@ -627,9 +627,7 @@ fn shard_manager(
tracing::error!("Please install it with `make install-server`")
}
}
status_sender
.send(ShardStatus::Failed((rank, err.to_string())))
.unwrap();
status_sender.send(ShardStatus::Failed(rank)).unwrap();
return;
}
};
@@ -658,11 +656,7 @@ fn shard_manager(
loop {
// Process exited
if p.poll().is_some() {
let mut err = String::new();
p.stderr.take().unwrap().read_to_string(&mut err).unwrap();
status_sender
.send(ShardStatus::Failed((rank, err)))
.unwrap();
status_sender.send(ShardStatus::Failed(rank)).unwrap();
return;
}

View File

@@ -1,4 +1,3 @@
import os
import torch
from loguru import logger
@@ -11,6 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOM, BLOOMSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.opt import OPT, OPTSharded
from text_generation_server.models.galactica import Galactica, GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.gpt_neox import GPTNeoxSharded
@@ -36,7 +36,11 @@ __all__ = [
"GalacticaSharded",
"GPTNeoxSharded",
"Seq2SeqLM",
"Galactica",
"GalacticaSharded",
"SantaCoder",
"OPT",
"OPTSharded",
"T5Sharded",
"get_model",
]
@@ -48,9 +52,11 @@ if FLASH_ATTENTION:
__all__.append(FlashLlama)
__all__.append(FlashLlamaSharded)
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention CUDA kernels to be installed.\n" \
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " \
"or install flash attention with `cd server && make install install-flash-attention`"
FLASH_ATT_ERROR_MESSAGE = (
"{} requires Flash Attention CUDA kernels to be installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
)
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
@@ -64,7 +70,7 @@ torch.set_grad_enabled(False)
def get_model(
model_id: str, revision: Optional[str], sharded: bool, quantize: bool
) -> Model:
if "facebook/galactica" in model_id:
if sharded:
@@ -100,13 +106,17 @@ def get_model(
if sharded:
if FLASH_ATTENTION:
return FlashLlamaSharded(model_id, revision, quantize=quantize)
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama")
)
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama"))
else:
llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
return llama_cls(model_id, revision, quantize=quantize)
if config.model_type == "opt":
if sharded:
return OPTSharded(model_id, revision, quantize=quantize)
else:
return OPT(model_id, revision, quantize=quantize)
if model_type == "t5":
if sharded:
return T5Sharded(model_id, revision, quantize=quantize)
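For illustration (not part of this diff), the dispatch above only inspects the checkpoint's config: an `opt` model type is routed to `OPTSharded` or `OPT`, with the same constructor signature as the other architectures. A minimal sketch; the model id is illustrative:

```python
from transformers import AutoConfig

# The dispatcher selects the branch purely from the config's `model_type`.
config = AutoConfig.from_pretrained("facebook/opt-1.3b")
if config.model_type == "opt":
    # In get_model() this becomes OPTSharded(...) when sharded, OPT(...) otherwise.
    print("opt checkpoint -> OPT / OPTSharded")
```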

View File

@@ -62,7 +62,7 @@ class BLOOMSharded(BLOOM):
self.master = self.rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else:
device = torch.device("cpu")
dtype = torch.float32
@@ -122,18 +122,11 @@ class BLOOMSharded(BLOOM):
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
if param_name == "weight":
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]

View File

@@ -19,8 +19,9 @@ from transformers.models.opt.parallel_layers import (
)
from text_generation_server.models import CausalLM
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.opt import OPT
from text_generation_server.utils import (
NextTokenChooser,
StoppingCriteria,
@@ -158,7 +159,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
)
class Galactica(CausalLM):
class Galactica(OPT):
@property
def batch_type(self) -> Type[CausalLMBatch]:
return GalacticaCausalLMBatch
@@ -192,7 +193,7 @@ class GalacticaSharded(Galactica):
self.master = self.rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else:
device = torch.device("cpu")
dtype = torch.float32
@@ -253,18 +254,11 @@ class GalacticaSharded(Galactica):
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
if param_name == "weight":
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]

View File

@@ -38,7 +38,7 @@ class GPTNeoxSharded(CausalLM):
self.master = self.rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else:
device = torch.device("cpu")
dtype = torch.float32
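The same dtype fallback is applied to every sharded model touched by this commit. In isolation it amounts to the following sketch (device index illustrative; the real code uses the shard's rank):

```python
import torch

# Prefer bfloat16 on GPUs that support it natively (Ampere and newer);
# otherwise fall back to float32, matching the change above.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else:
    device = torch.device("cpu")
    dtype = torch.float32

print(device, dtype)
```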

View File

@@ -0,0 +1,224 @@
import torch
import torch.distributed
from typing import List, Optional, Tuple
from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoConfig,
)
from transformers.models.opt.parallel_layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation_server.models import CausalLM
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
)
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except Exception as e:
HAS_BITS_AND_BYTES = False
class OPT(CausalLM):
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
"""Overwrite forward to ignore position_ids"""
# Model Forward
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, outputs.past_key_values
class OPTSharded(OPT):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else:
device = torch.device("cpu")
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, tp_parallel=True
)
tokenizer.pad_token_id = config.pad_token_id
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
rank=self.rank,
world_size=self.world_size,
)
self.model = model.eval().to(dtype)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
tokenizer=tokenizer,
device=device,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: bool,
device: torch.device,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if not quantize else "cpu"
) as f:
for name in f.keys():
if name == "lm_head.weight":
continue
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_tensor = parameters[name]
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
tensor = slice_[:]
if current_tensor.shape != tensor.shape:
raise ValueError(
f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous()
if quantize:
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor,
has_fp16_weights=False,
requires_grad=False,
).to(device)
state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.memory_efficient_backward = False
state.use_pool = True
state.CB = tensor.CB
state.SCB = tensor.SCB
tensor.CB = None
tensor.SCB = None
def replace_linear(state):
def linear(input, weight, bias):
out = bnb.matmul(
input,
weight,
state=state,
threshold=state.threshold,
bias=bias,
)
if state.CB is not None:
# we converted 8-bit row major to turing/ampere format
# in the first inference pass
# we no longer need the row-major weight
del state.CB
weight.data = state.CxB
return out
return linear
module.linear = replace_linear(state)
else:
tensor = tensor.to(device)
module._parameters[param_name] = tensor
if name == "model.decoder.embed_tokens.weight":
model.lm_head._parameters["weight"] = tensor
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
# Logits are sharded, so we need to gather them
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
logits = torch.cat(logits, dim=2)
return logits, outputs.past_key_values
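For context on the loading scheme in `OPTSharded.load_weights` above (the BLOOM and Galactica loaders follow the same pattern): column-parallel linear layers and embeddings are sharded along dimension 0, row-parallel weights along dimension 1, and the row-parallel bias is kept only on rank 0 so it is added exactly once when the partial outputs are reduced. A self-contained sketch with dummy tensors (world size and shapes illustrative):

```python
import torch

world_size = 4                    # illustrative number of shards
col_weight = torch.randn(32, 16)  # (out_features, in_features) of a column-parallel linear
row_weight = torch.randn(16, 32)  # (out_features, in_features) of a row-parallel linear
row_bias = torch.randn(16)        # bias of the row-parallel linear

for rank in range(world_size):
    # Column-parallel linears (and embeddings): shard along dim 0.
    block = col_weight.shape[0] // world_size
    col_shard = col_weight[rank * block : (rank + 1) * block]

    # Row-parallel weights: shard along dim 1 (the input features).
    block = row_weight.shape[1] // world_size
    row_shard = row_weight[:, rank * block : (rank + 1) * block]

    # Row-parallel bias: only rank 0 keeps the real values; the other ranks
    # store zeros so the bias is added exactly once across shards.
    bias_shard = row_bias if rank == 0 else torch.zeros_like(row_bias)

    print(rank, tuple(col_shard.shape), tuple(row_shard.shape), float(bias_shard.abs().sum()))
```

Because the tied `lm_head` weight is sharded along the vocabulary dimension as well, the sharded forward pass all-gathers the per-rank logits and concatenates them on the last dimension, as shown at the end of the file above.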

View File

@@ -38,7 +38,7 @@ class T5Sharded(Seq2SeqLM):
self.master = self.rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else:
device = torch.device("cpu")
dtype = torch.float32

View File

@@ -50,7 +50,6 @@ def try_to_load_from_cache(
refs_dir = repo_cache / "refs"
snapshots_dir = repo_cache / "snapshots"
no_exist_dir = repo_cache / ".no_exist"
# Resolve refs (for instance to convert main to the associated commit sha)
if refs_dir.is_dir():
@@ -59,10 +58,6 @@
with revision_file.open() as f:
revision = f.read()
# Check if file is cached as "no_exist"
if (no_exist_dir / revision / filename).is_file():
return None
# Check if revision folder exists
if not snapshots_dir.exists():
return None