diff --git a/Dockerfile b/Dockerfile index 78f66c0..08561b6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -26,6 +26,7 @@ ENV LANG=C.UTF-8 \ DEBIAN_FRONTEND=noninteractive \ MODEL_BASE_PATH=/var/azureml-model \ MODEL_NAME=bigscience/bloom \ + QUANTIZE=false \ NUM_GPUS=8 \ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ NCCL_ASYNC_ERROR_HANDLING=1 \ @@ -72,4 +73,4 @@ COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/loca # Install launcher COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher -CMD HUGGINGFACE_HUB_CACHE=$MODEL_BASE_PATH text-generation-launcher --model-name $MODEL_NAME --num-shard $NUM_GPUS \ No newline at end of file +CMD HUGGINGFACE_HUB_CACHE=$MODEL_BASE_PATH text-generation-launcher --num-shard $NUM_GPUS \ No newline at end of file diff --git a/Makefile b/Makefile index e22309f..a52ed2b 100644 --- a/Makefile +++ b/Makefile @@ -18,5 +18,14 @@ router-dev: run-bloom-560m: text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2 +run-bloom-560m-quantize: + text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2 --quantize + +download-bloom: + bloom-inference-server download-weights bigscience/bloom + run-bloom: text-generation-launcher --model-name bigscience/bloom --num-shard 8 + +run-bloom-quantize: + text-generation-launcher --model-name bigscience/bloom --num-shard 8 --quantize \ No newline at end of file diff --git a/README.md b/README.md index 1af90dc..e51d8c8 100644 --- a/README.md +++ b/README.md @@ -8,22 +8,26 @@ A Rust and gRPC server for large language models text generation inference. 
+## Features + +- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) +- [Dynamic batching of incoming requests](https://github.com/huggingface/text-generation-inference/blob/main/router/src/batcher.rs#L88) for increased total throughput +- [Safetensors](https://github.com/huggingface/safetensors) weight loading +- 45ms per token generation for BLOOM with 8xA100 80GB + +## Supported models + +- BLOOM +- BLOOM-560m + ## Load Tests for BLOOM See `k6/load_test.js` -We send the default examples with a 1 second delay between requests. - -Stages: -- Ramp up to 50 vus in 1min -- Ramp up from 50 to 100 vus in 2min -- Ramp down to 0 vus in 1min - | | avg | min | med | max | p(90) | p(95) | RPS | |--------------------------------------------------------------|-----------|--------------|-----------|------------|-----------|-----------|----------| | [Original code](https://github.com/huggingface/transformers_bloom_parallel) | 8.9s | 1s | 9.12s | 16.69s | 13.7s | 14.26s | 5.9 | -| ISO with original code | 8.88s | **959.53ms** | 8.89s | 17.08s | 13.34s | 14.12s | 5.94 | -| New batching logic | **5.44s** | 1.27s | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** | +| New batching logic | **5.44s** | **959.53ms** | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** | ## Install @@ -33,10 +37,30 @@ make install ## Run +### BLOOM 560-m + ```shell make run-bloom-560m ``` +### BLOOM + +First you need to download the weights: + +```shell +make download-bloom +``` + +```shell +make run-bloom # Requires 8xA100 80GB +``` + +You can also quantize the weights with bitsandbytes to reduce the VRAM requirement: + +```shell +make run-bloom-quantize # Requires 8xA100 40GB +``` + ## Test ```shell diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 6b529d6..3e82516 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -21,6 +21,8 @@ struct Args { model_name: String, #[clap(long, env)] num_shard: Option, + #[clap(long, env)] + quantize: 
bool, #[clap(default_value = "128", long, env)] max_concurrent_requests: usize, #[clap(default_value = "1000", long, env)] @@ -46,6 +48,7 @@ fn main() -> ExitCode { let Args { model_name, num_shard, + quantize, max_concurrent_requests, max_input_length, max_batch_size, @@ -87,6 +90,7 @@ fn main() -> ExitCode { thread::spawn(move || { shard_manager( model_name, + quantize, uds_path, rank, num_shard, @@ -169,6 +173,8 @@ fn main() -> ExitCode { tracing::error!("text-generation-router not found in PATH"); tracing::error!("Please install it with `make install-router`") } + } else { + tracing::error!("{}", err); } shutdown_shards(shutdown, &shutdown_receiver); @@ -232,6 +238,7 @@ enum ShardStatus { #[allow(clippy::too_many_arguments)] fn shard_manager( model_name: String, + quantize: bool, uds_path: String, rank: usize, world_size: usize, @@ -260,6 +267,10 @@ fn shard_manager( shard_argv.push("--sharded".to_string()); } + if quantize { + shard_argv.push("--quantize".to_string()) + } + let mut env = vec![ ("RANK".parse().unwrap(), rank.to_string().parse().unwrap()), ( @@ -338,11 +349,9 @@ fn shard_manager( tracing::info!("Shard {} ready in {:?}", rank, start_time.elapsed()); status_sender.send(ShardStatus::Ready).unwrap(); ready = true; - } else if !ready { - if wait_time.elapsed() > Duration::from_secs(10) { - tracing::info!("Waiting for shard {} to be ready...", rank); - wait_time = Instant::now(); - } + } else if !ready && wait_time.elapsed() > Duration::from_secs(10) { + tracing::info!("Waiting for shard {} to be ready...", rank); + wait_time = Instant::now(); } sleep(Duration::from_millis(100)); } diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index 6fddfd7..6c70afc 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -12,9 +12,7 @@ pub struct ShardedClient { impl ShardedClient { fn new(clients: Vec) -> Self { - Self { - clients, - } + Self { clients } } /// Create a new ShardedClient 
from a master client. The master client will communicate with diff --git a/router/src/batcher.rs b/router/src/batcher.rs index f71428e..f131bf9 100644 --- a/router/src/batcher.rs +++ b/router/src/batcher.rs @@ -1,7 +1,8 @@ /// Batching and inference logic -use crate::GenerateRequest; use crate::{Db, Entry}; +use crate::{ErrorResponse, GenerateRequest}; use axum::http::StatusCode; +use axum::Json; use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient}; use std::future::Future; use std::sync::Arc; @@ -213,10 +214,15 @@ pub enum InferError { } /// Convert to Axum supported format -impl From for (StatusCode, String) { +impl From for (StatusCode, Json) { fn from(err: InferError) -> Self { match err { - InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()), + InferError::GenerationError(_) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: err.to_string(), + }), + ), } } } diff --git a/router/src/lib.rs b/router/src/lib.rs index 6604a91..8646ad5 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -64,3 +64,8 @@ pub(crate) struct GenerateRequest { pub(crate) struct GeneratedText { pub generated_text: String, } + +#[derive(Serialize)] +pub(crate) struct ErrorResponse { + pub error: String, +} diff --git a/router/src/server.rs b/router/src/server.rs index 02c4a49..5698f4e 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,4 +1,6 @@ -use crate::{Batcher, GenerateParameters, GenerateRequest, GeneratedText, Validation}; +use crate::{ + Batcher, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation, +}; use axum::extract::Extension; use axum::http::{HeaderMap, StatusCode}; use axum::response::IntoResponse; @@ -23,7 +25,7 @@ struct ServerState { /// Health check method #[instrument(skip(state), fields(time, time_per_token))] -async fn health(state: Extension) -> Result<(), (StatusCode, String)> { +async fn health(state: Extension) -> Result<(), 
(StatusCode, Json)> { // TODO: while this is the best health check we can do, it is a bit on the heavy side and might // be a bit too slow for a health check. // What we should do instead if check if the gRPC channels are still healthy. @@ -32,7 +34,9 @@ async fn health(state: Extension) -> Result<(), (StatusCode, String let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| { ( StatusCode::TOO_MANY_REQUESTS, - "Model is overloaded".to_string(), + Json(ErrorResponse { + error: "Model is overloaded".to_string(), + }), ) })?; @@ -70,13 +74,16 @@ async fn health(state: Extension) -> Result<(), (StatusCode, String async fn generate( state: Extension, req: Json, -) -> Result { +) -> Result)> { let start_time = Instant::now(); // Limit concurrent requests by acquiring a permit from the semaphore let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| { + tracing::error!("Model is overloaded"); ( StatusCode::TOO_MANY_REQUESTS, - "Model is overloaded".to_string(), + Json(ErrorResponse { + error: "Model is overloaded".to_string(), + }), ) })?; @@ -88,10 +95,21 @@ async fn generate( inputs: req.inputs.clone(), parameters: req.parameters.clone(), }) - .await?; + .await + .map_err(|err| { + tracing::error!("{}", err.to_string()); + err + })?; // Inference - let response = state.batcher.infer(input_length, validated_request).await?; + let response = state + .batcher + .infer(input_length, validated_request) + .await + .map_err(|err| { + tracing::error!("{}", err.to_string()); + err + })?; // Timings let total_time = start_time.elapsed(); diff --git a/router/src/validation.rs b/router/src/validation.rs index 11712d0..43c246f 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -1,6 +1,7 @@ /// Payload validation logic -use crate::GenerateRequest; +use crate::{ErrorResponse, GenerateRequest}; use axum::http::StatusCode; +use axum::Json; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; use tokenizers::{ @@ -146,20 
+147,25 @@ type ValidationRequest = ( #[derive(Error, Debug)] pub enum ValidationError { - #[error("Temperature must be strictly positive")] + #[error("temperature must be strictly positive")] Temperature, - #[error("Top p must be >= 0.0 or < 1.0")] + #[error("top_p must be >= 0.0 or < 1.0")] TopP, - #[error("Top k must be strictly positive")] + #[error("top_k must be strictly positive")] TopK, - #[error("Max New Tokens must be <= 512")] + #[error("max_new_tokens must be <= 512")] MaxNewTokens, - #[error("Inputs must have less than {1} tokens. Given: {0}")] + #[error("inputs must have less than {1} tokens. Given: {0}")] InputLength(usize, usize), } -impl From for (StatusCode, String) { +impl From for (StatusCode, Json) { fn from(err: ValidationError) -> Self { - (StatusCode::BAD_REQUEST, err.to_string()) + ( + StatusCode::BAD_REQUEST, + Json(ErrorResponse { + error: err.to_string(), + }), + ) } } diff --git a/server/bloom_inference/cli.py b/server/bloom_inference/cli.py index 600c527..6360aec 100644 --- a/server/bloom_inference/cli.py +++ b/server/bloom_inference/cli.py @@ -12,6 +12,7 @@ app = typer.Typer() def serve( model_name: str, sharded: bool = False, + quantize: bool = False, uds_path: Path = "/tmp/bloom-inference", ): if sharded: @@ -28,7 +29,7 @@ def serve( os.getenv("MASTER_PORT", None) is not None ), "MASTER_PORT must be set when sharded is True" - server.serve(model_name, sharded, uds_path) + server.serve(model_name, sharded, quantize, uds_path) @app.command() diff --git a/server/bloom_inference/model.py b/server/bloom_inference/model.py index eac5400..a045005 100644 --- a/server/bloom_inference/model.py +++ b/server/bloom_inference/model.py @@ -19,9 +19,16 @@ from bloom_inference.utils import ( NextTokenChooser, initialize_torch_distributed, weight_files, - download_weights + download_weights, ) +HAS_BITS_AND_BYTES = True +try: + import bitsandbytes as bnb + from bitsandbytes.nn import Int8Params +except Exception as e: + HAS_BITS_AND_BYTES = False + 
torch.manual_seed(0) @@ -68,7 +75,9 @@ class Batch: ) stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens)) - input_ids = tokenizer(inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8).to(device) + input_ids = tokenizer( + inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8 + ).to(device) all_input_ids = input_ids["input_ids"].unsqueeze(-1) return cls( @@ -250,7 +259,7 @@ class BLOOM: def generate_token( self, batch: Batch ) -> Tuple[List[GeneratedText], Optional[Batch]]: - with torch.no_grad(): + with torch.inference_mode(): outputs = self.forward(**batch.input_ids) # List of indices to cache @@ -374,13 +383,13 @@ class BLOOM: class BLOOMSharded(BLOOM): - def __init__(self, model_name: str): + def __init__(self, model_name: str, quantize: bool = False): super(BLOOM, self).__init__() self.process_group, self.rank, self.world_size = initialize_torch_distributed() self.master = self.rank == 0 if torch.cuda.is_available(): self.device = torch.device(f"cuda:{self.rank}") - dtype = torch.bfloat16 + dtype = torch.float16 else: self.device = torch.device("cpu") dtype = torch.float32 @@ -414,6 +423,7 @@ class BLOOMSharded(BLOOM): self.load_weights( model, filenames, + quantize=quantize, device=self.device, rank=self.rank, world_size=self.world_size, @@ -423,11 +433,18 @@ class BLOOMSharded(BLOOM): @staticmethod def load_weights( - model, filenames: List[str], device: torch.device, rank: int, world_size: int + model, + filenames: List[str], + quantize: bool, + device: torch.device, + rank: int, + world_size: int, ): parameters = dict(model.named_parameters()) for file in filenames: - with safe_open(file, framework="pt", device=str(device)) as f: + with safe_open( + file, framework="pt", device=str(device) if not quantize else "cpu" + ) as f: for name in f.keys(): full_name = f"transformer.{name}" @@ -479,6 +496,67 @@ class BLOOMSharded(BLOOM): ) tensor = tensor.contiguous() + + if quantize: + if not HAS_BITS_AND_BYTES: + raise 
ImportError( + "bitsandbytes is not available on your machine" + ) + + if ( + type(module) + in [TensorParallelRowLinear, TensorParallelColumnLinear] + and param_name == "weight" + ): + tensor = Int8Params( + tensor.transpose(1, 0), + has_fp16_weights=False, + requires_grad=False, + ).to(device) + state = bnb.MatmulLtState() + state.threshold = 6.0 + state.has_fp16_weights = False + state.memory_efficient_backward = False + state.use_pool = True + state.CB = tensor.CB + state.SCB = tensor.SCB + tensor.CB = None + tensor.SCB = None + + def replace_linear(state, in_features, out_features): + def linear(input, weight, bias): + size_out = input.size()[:-1] + (out_features,) + input = input.view(-1, in_features) + out = torch.empty( + size_out, device=input.device, dtype=input.dtype + ) + out = bnb.matmul( + input, + weight, + out=out.view(-1, out_features), + state=state, + threshold=state.threshold, + bias=bias, + ) + + if state.CB is not None: + # we converted 8-bit row major to turing/ampere format + # in the first inference pass + # we no longer need the row-major weight + del state.CB + weight.data = state.CxB + + return out.view(size_out) + + return linear + + module.linear = replace_linear( + state, module.in_features, module.out_features + ) + + else: + tensor = tensor.to(device) + module._parameters[param_name] = tensor if name == "word_embeddings.weight": model.lm_head._parameters["weight"] = tensor diff --git a/server/bloom_inference/server.py b/server/bloom_inference/server.py index fb02db6..ad40a52 100644 --- a/server/bloom_inference/server.py +++ b/server/bloom_inference/server.py @@ -68,21 +68,27 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): def serve( model_name: str, sharded: bool, + quantize: bool, uds_path: Path, ): async def serve_inner( model_name: str, sharded: bool = False, + quantize: bool = False, ): unix_socket_template = "unix://{}-{}" if sharded: - model = BLOOMSharded(model_name) + model = 
BLOOMSharded(model_name, quantize) server_urls = [ unix_socket_template.format(uds_path, rank) for rank in range(model.world_size) ] local_url = server_urls[model.rank] else: + if quantize: + raise ValueError( + "bitsandbytes quantization is only available when running in `sharded` mode." + ) model = BLOOM(model_name) local_url = unix_socket_template.format(uds_path, 0) server_urls = [local_url] @@ -105,4 +111,4 @@ def serve( print("Signal received. Shutting down") await server.stop(0) - asyncio.run(serve_inner(model_name, sharded)) + asyncio.run(serve_inner(model_name, sharded, quantize)) diff --git a/server/poetry.lock b/server/poetry.lock index 38f14d3..5e635a6 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -22,6 +22,14 @@ test_prod = ["parameterized", "pytest", "pytest-subtests", "pytest-xdist"] test_trackers = ["comet-ml", "tensorboard", "wandb"] testing = ["datasets", "deepspeed (<0.7.0)", "evaluate", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scipy", "sklearn", "tqdm", "transformers"] +[[package]] +name = "bitsandbytes" +version = "0.35.1" +description = "8-bit optimizers and matrix multiplication routines." 
+category = "main" +optional = false +python-versions = "*" + [[package]] name = "click" version = "8.1.3" @@ -205,13 +213,17 @@ python-versions = ">=3.7" [metadata] lock-version = "1.1" python-versions = "^3.9" -content-hash = "a4eef5f52e8d046aa883082c865b0865047f611a3240b18250487d4b6e831496" +content-hash = "50d9d44577a0222f125c770732d5f88807378573bd7386036eb5c79fc2a7c552" [metadata.files] accelerate = [ {file = "accelerate-0.12.0-py3-none-any.whl", hash = "sha256:7742ca5c9f15dd1e0a283305599c196e260af4717a561d1f544aeab27d828af6"}, {file = "accelerate-0.12.0.tar.gz", hash = "sha256:e8b119c94fac31877620d5f9de311164ec81fa9dc9e175f0d0d4f50fc8d79473"}, ] +bitsandbytes = [ + {file = "bitsandbytes-0.35.1-py3-none-any.whl", hash = "sha256:4506a9e3778359a743938aa5592d8d043fa91d1df66cd01ba8cc6486e64dea45"}, + {file = "bitsandbytes-0.35.1.tar.gz", hash = "sha256:63a6f59c87b713a731a685e43d68c19789ee6381e62196cafab293b87eca5d46"}, +] click = [ {file = "click-8.1.3-py3-none-any.whl", hash = "sha256:bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48"}, {file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"}, diff --git a/server/pyproject.toml b/server/pyproject.toml index 3d38f51..4e5e98b 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -15,6 +15,7 @@ typer = "^0.6.1" grpcio-reflection = "^1.49.1" accelerate = "^0.12.0" joblib = "^1.2.0" +bitsandbytes = "^0.35.1" [tool.poetry.group.dev.dependencies] grpcio-tools = "^1.49.1"