feat(server): Support bitsandbytes

This commit is contained in:
OlivierDehaene 2022-10-27 14:25:29 +02:00
parent beb552127a
commit 09674e6df9
14 changed files with 221 additions and 47 deletions

View File

@ -26,6 +26,7 @@ ENV LANG=C.UTF-8 \
DEBIAN_FRONTEND=noninteractive \
MODEL_BASE_PATH=/var/azureml-model \
MODEL_NAME=bigscience/bloom \
QUANTIZE=false \
NUM_GPUS=8 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NCCL_ASYNC_ERROR_HANDLING=1 \
@ -72,4 +73,4 @@ COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/loca
# Install launcher
COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher
CMD HUGGINGFACE_HUB_CACHE=$MODEL_BASE_PATH text-generation-launcher --model-name $MODEL_NAME --num-shard $NUM_GPUS
CMD HUGGINGFACE_HUB_CACHE=$MODEL_BASE_PATH text-generation-launcher --num-shard $NUM_GPUS

View File

@ -18,5 +18,14 @@ router-dev:
run-bloom-560m:
text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2
run-bloom-560m-quantize:
text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2 --quantize
download-bloom:
bloom-inference-server download-weights bigscience/bloom
run-bloom:
text-generation-launcher --model-name bigscience/bloom --num-shard 8
run-bloom-quantize:
text-generation-launcher --model-name bigscience/bloom --num-shard 8 --quantize

View File

@ -8,22 +8,26 @@
A Rust and gRPC server for large language models text generation inference.
## Features
- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
- [Dynamic bathing of incoming requests](https://github.com/huggingface/text-generation-inference/blob/main/router/src/batcher.rs#L88) for increased total throughput
- [Safetensors](https://github.com/huggingface/safetensors) weight loading
- 45ms per token generation for BLOOM with 8xA100 80GB
## Supported models
- BLOOM
- BLOOM-560m
## Load Tests for BLOOM
See `k6/load_test.js`
We send the default examples with a 1 second delay between requests.
Stages:
- Ramp up to 50 vus in 1min
- Ramp up from 50 to 100 vus in 2min
- Ramp down to 0 vus in 1min
| | avg | min | med | max | p(90) | p(95) | RPS |
|--------------------------------------------------------------|-----------|--------------|-----------|------------|-----------|-----------|----------|
| [Original code](https://github.com/huggingface/transformers_bloom_parallel) | 8.9s | 1s | 9.12s | 16.69s | 13.7s | 14.26s | 5.9 |
| ISO with original code | 8.88s | **959.53ms** | 8.89s | 17.08s | 13.34s | 14.12s | 5.94 |
| New batching logic | **5.44s** | 1.27s | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** |
| New batching logic | **5.44s** | **959.53ms** | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** |
## Install
@ -33,10 +37,30 @@ make install
## Run
### BLOOM 560-m
```shell
make run-bloom-560m
```
### BLOOM
First you need to download the weights:
```shell
make download-bloom
```
```shell
make run-bloom # Requires 8xA100 80GB
```
You can also quantize the weights with bitsandbytes to reduce the VRAM requirement:
```shell
make run-bloom-quantize # Requires 8xA100 40GB
```
## Test
```shell

View File

@ -21,6 +21,8 @@ struct Args {
model_name: String,
#[clap(long, env)]
num_shard: Option<usize>,
#[clap(long, env)]
quantize: bool,
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "1000", long, env)]
@ -46,6 +48,7 @@ fn main() -> ExitCode {
let Args {
model_name,
num_shard,
quantize,
max_concurrent_requests,
max_input_length,
max_batch_size,
@ -87,6 +90,7 @@ fn main() -> ExitCode {
thread::spawn(move || {
shard_manager(
model_name,
quantize,
uds_path,
rank,
num_shard,
@ -169,6 +173,8 @@ fn main() -> ExitCode {
tracing::error!("text-generation-router not found in PATH");
tracing::error!("Please install it with `make install-router`")
}
} else {
tracing::error!("{}", err);
}
shutdown_shards(shutdown, &shutdown_receiver);
@ -232,6 +238,7 @@ enum ShardStatus {
#[allow(clippy::too_many_arguments)]
fn shard_manager(
model_name: String,
quantize: bool,
uds_path: String,
rank: usize,
world_size: usize,
@ -260,6 +267,10 @@ fn shard_manager(
shard_argv.push("--sharded".to_string());
}
if quantize {
shard_argv.push("--quantize".to_string())
}
let mut env = vec![
("RANK".parse().unwrap(), rank.to_string().parse().unwrap()),
(
@ -338,12 +349,10 @@ fn shard_manager(
tracing::info!("Shard {} ready in {:?}", rank, start_time.elapsed());
status_sender.send(ShardStatus::Ready).unwrap();
ready = true;
} else if !ready {
if wait_time.elapsed() > Duration::from_secs(10) {
} else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
tracing::info!("Waiting for shard {} to be ready...", rank);
wait_time = Instant::now();
}
}
sleep(Duration::from_millis(100));
}
}

View File

@ -12,9 +12,7 @@ pub struct ShardedClient {
impl ShardedClient {
fn new(clients: Vec<Client>) -> Self {
Self {
clients,
}
Self { clients }
}
/// Create a new ShardedClient from a master client. The master client will communicate with

View File

@ -1,7 +1,8 @@
/// Batching and inference logic
use crate::GenerateRequest;
use crate::{Db, Entry};
use crate::{ErrorResponse, GenerateRequest};
use axum::http::StatusCode;
use axum::Json;
use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
use std::future::Future;
use std::sync::Arc;
@ -213,10 +214,15 @@ pub enum InferError {
}
/// Convert to Axum supported format
impl From<InferError> for (StatusCode, String) {
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
fn from(err: InferError) -> Self {
match err {
InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
InferError::GenerationError(_) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: err.to_string(),
}),
),
}
}
}

View File

@ -64,3 +64,8 @@ pub(crate) struct GenerateRequest {
pub(crate) struct GeneratedText {
pub generated_text: String,
}
#[derive(Serialize)]
pub(crate) struct ErrorResponse {
pub error: String,
}

View File

@ -1,4 +1,6 @@
use crate::{Batcher, GenerateParameters, GenerateRequest, GeneratedText, Validation};
use crate::{
Batcher, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, StatusCode};
use axum::response::IntoResponse;
@ -23,7 +25,7 @@ struct ServerState {
/// Health check method
#[instrument(skip(state), fields(time, time_per_token))]
async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> {
async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
// be a bit too slow for a health check.
// What we should do instead if check if the gRPC channels are still healthy.
@ -32,7 +34,9 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, String
let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
(
StatusCode::TOO_MANY_REQUESTS,
"Model is overloaded".to_string(),
Json(ErrorResponse {
error: "Model is overloaded".to_string(),
}),
)
})?;
@ -70,13 +74,16 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, String
async fn generate(
state: Extension<ServerState>,
req: Json<GenerateRequest>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
let start_time = Instant::now();
// Limit concurrent requests by acquiring a permit from the semaphore
let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
tracing::error!("Model is overloaded");
(
StatusCode::TOO_MANY_REQUESTS,
"Model is overloaded".to_string(),
Json(ErrorResponse {
error: "Model is overloaded".to_string(),
}),
)
})?;
@ -88,10 +95,21 @@ async fn generate(
inputs: req.inputs.clone(),
parameters: req.parameters.clone(),
})
.await?;
.await
.map_err(|err| {
tracing::error!("{}", err.to_string());
err
})?;
// Inference
let response = state.batcher.infer(input_length, validated_request).await?;
let response = state
.batcher
.infer(input_length, validated_request)
.await
.map_err(|err| {
tracing::error!("{}", err.to_string());
err
})?;
// Timings
let total_time = start_time.elapsed();

View File

@ -1,6 +1,7 @@
/// Payload validation logic
use crate::GenerateRequest;
use crate::{ErrorResponse, GenerateRequest};
use axum::http::StatusCode;
use axum::Json;
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokenizers::{
@ -146,20 +147,25 @@ type ValidationRequest = (
#[derive(Error, Debug)]
pub enum ValidationError {
#[error("Temperature must be strictly positive")]
#[error("temperature must be strictly positive")]
Temperature,
#[error("Top p must be >= 0.0 or < 1.0")]
#[error("top_p must be >= 0.0 or < 1.0")]
TopP,
#[error("Top k must be strictly positive")]
#[error("top_k must be strictly positive")]
TopK,
#[error("Max New Tokens must be <= 512")]
#[error("max_new_tokens must be <= 512")]
MaxNewTokens,
#[error("Inputs must have less than {1} tokens. Given: {0}")]
#[error("inputs must have less than {1} tokens. Given: {0}")]
InputLength(usize, usize),
}
impl From<ValidationError> for (StatusCode, String) {
impl From<ValidationError> for (StatusCode, Json<ErrorResponse>) {
fn from(err: ValidationError) -> Self {
(StatusCode::BAD_REQUEST, err.to_string())
(
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: err.to_string(),
}),
)
}
}

View File

@ -12,6 +12,7 @@ app = typer.Typer()
def serve(
model_name: str,
sharded: bool = False,
quantize: bool = False,
uds_path: Path = "/tmp/bloom-inference",
):
if sharded:
@ -28,7 +29,7 @@ def serve(
os.getenv("MASTER_PORT", None) is not None
), "MASTER_PORT must be set when sharded is True"
server.serve(model_name, sharded, uds_path)
server.serve(model_name, sharded, quantize, uds_path)
@app.command()

View File

@ -19,9 +19,16 @@ from bloom_inference.utils import (
NextTokenChooser,
initialize_torch_distributed,
weight_files,
download_weights
download_weights,
)
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except Exception as e:
HAS_BITS_AND_BYTES = False
torch.manual_seed(0)
@ -68,7 +75,9 @@ class Batch:
)
stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))
input_ids = tokenizer(inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8).to(device)
input_ids = tokenizer(
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
).to(device)
all_input_ids = input_ids["input_ids"].unsqueeze(-1)
return cls(
@ -250,7 +259,7 @@ class BLOOM:
def generate_token(
self, batch: Batch
) -> Tuple[List[GeneratedText], Optional[Batch]]:
with torch.no_grad():
with torch.inference_mode():
outputs = self.forward(**batch.input_ids)
# List of indices to cache
@ -374,13 +383,13 @@ class BLOOM:
class BLOOMSharded(BLOOM):
def __init__(self, model_name: str):
def __init__(self, model_name: str, quantize: bool = False):
super(BLOOM, self).__init__()
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
if torch.cuda.is_available():
self.device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16
dtype = torch.float16
else:
self.device = torch.device("cpu")
dtype = torch.float32
@ -414,6 +423,7 @@ class BLOOMSharded(BLOOM):
self.load_weights(
model,
filenames,
quantize=quantize,
device=self.device,
rank=self.rank,
world_size=self.world_size,
@ -423,11 +433,18 @@ class BLOOMSharded(BLOOM):
@staticmethod
def load_weights(
model, filenames: List[str], device: torch.device, rank: int, world_size: int
model,
filenames: List[str],
quantize: bool,
device: torch.device,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(file, framework="pt", device=str(device)) as f:
with safe_open(
file, framework="pt", device=str(device) if not quantize else "cpu"
) as f:
for name in f.keys():
full_name = f"transformer.{name}"
@ -479,6 +496,67 @@ class BLOOMSharded(BLOOM):
)
tensor = tensor.contiguous()
if quantize:
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine"
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor.transpose(1, 0),
has_fp16_weights=False,
requires_grad=False,
).to(device)
state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.memory_efficient_backward = False
state.use_pool = True
state.CB = tensor.CB
state.SCB = tensor.SCB
tensor.CB = None
tensor.SCB = None
def replace_linear(state, in_features, out_features):
def linear(input, weight, bias):
size_out = input.size()[:-1] + (out_features,)
input = input.view(-1, in_features)
out = torch.empty(
size_out, device=input.device, dtype=input.dtype
)
out = bnb.matmul(
input,
weight,
out=out.view(-1, out_features),
state=state,
threshold=state.threshold,
bias=bias,
)
if state.CB is not None:
# we converted 8-bit row major to turing/ampere format
# in the first inference pass
# we no longer need the row-major weight
del state.CB
weight.data = state.CxB
return out.view(size_out)
return linear
module.linear = replace_linear(
state, module.in_features, module.out_features
)
else:
tensor = tensor.to(device)
module._parameters[param_name] = tensor
if name == "word_embeddings.weight":
model.lm_head._parameters["weight"] = tensor

View File

@ -68,21 +68,27 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def serve(
model_name: str,
sharded: bool,
quantize: bool,
uds_path: Path,
):
async def serve_inner(
model_name: str,
sharded: bool = False,
quantize: bool = False,
):
unix_socket_template = "unix://{}-{}"
if sharded:
model = BLOOMSharded(model_name)
model = BLOOMSharded(model_name, quantize)
server_urls = [
unix_socket_template.format(uds_path, rank)
for rank in range(model.world_size)
]
local_url = server_urls[model.rank]
else:
if quantize:
raise ValueError(
"bitsandbytes quantization is only available when running in `sharded` mode."
)
model = BLOOM(model_name)
local_url = unix_socket_template.format(uds_path, 0)
server_urls = [local_url]
@ -105,4 +111,4 @@ def serve(
print("Signal received. Shutting down")
await server.stop(0)
asyncio.run(serve_inner(model_name, sharded))
asyncio.run(serve_inner(model_name, sharded, quantize))

14
server/poetry.lock generated
View File

@ -22,6 +22,14 @@ test_prod = ["parameterized", "pytest", "pytest-subtests", "pytest-xdist"]
test_trackers = ["comet-ml", "tensorboard", "wandb"]
testing = ["datasets", "deepspeed (<0.7.0)", "evaluate", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scipy", "sklearn", "tqdm", "transformers"]
[[package]]
name = "bitsandbytes"
version = "0.35.1"
description = "8-bit optimizers and matrix multiplication routines."
category = "main"
optional = false
python-versions = "*"
[[package]]
name = "click"
version = "8.1.3"
@ -205,13 +213,17 @@ python-versions = ">=3.7"
[metadata]
lock-version = "1.1"
python-versions = "^3.9"
content-hash = "a4eef5f52e8d046aa883082c865b0865047f611a3240b18250487d4b6e831496"
content-hash = "50d9d44577a0222f125c770732d5f88807378573bd7386036eb5c79fc2a7c552"
[metadata.files]
accelerate = [
{file = "accelerate-0.12.0-py3-none-any.whl", hash = "sha256:7742ca5c9f15dd1e0a283305599c196e260af4717a561d1f544aeab27d828af6"},
{file = "accelerate-0.12.0.tar.gz", hash = "sha256:e8b119c94fac31877620d5f9de311164ec81fa9dc9e175f0d0d4f50fc8d79473"},
]
bitsandbytes = [
{file = "bitsandbytes-0.35.1-py3-none-any.whl", hash = "sha256:4506a9e3778359a743938aa5592d8d043fa91d1df66cd01ba8cc6486e64dea45"},
{file = "bitsandbytes-0.35.1.tar.gz", hash = "sha256:63a6f59c87b713a731a685e43d68c19789ee6381e62196cafab293b87eca5d46"},
]
click = [
{file = "click-8.1.3-py3-none-any.whl", hash = "sha256:bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48"},
{file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"},

View File

@ -15,6 +15,7 @@ typer = "^0.6.1"
grpcio-reflection = "^1.49.1"
accelerate = "^0.12.0"
joblib = "^1.2.0"
bitsandbytes = "^0.35.1"
[tool.poetry.group.dev.dependencies]
grpcio-tools = "^1.49.1"