feat(server): Support bitsandbytes
This commit is contained in:
parent
beb552127a
commit
09674e6df9
|
@ -26,6 +26,7 @@ ENV LANG=C.UTF-8 \
|
||||||
DEBIAN_FRONTEND=noninteractive \
|
DEBIAN_FRONTEND=noninteractive \
|
||||||
MODEL_BASE_PATH=/var/azureml-model \
|
MODEL_BASE_PATH=/var/azureml-model \
|
||||||
MODEL_NAME=bigscience/bloom \
|
MODEL_NAME=bigscience/bloom \
|
||||||
|
QUANTIZE=false \
|
||||||
NUM_GPUS=8 \
|
NUM_GPUS=8 \
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
||||||
NCCL_ASYNC_ERROR_HANDLING=1 \
|
NCCL_ASYNC_ERROR_HANDLING=1 \
|
||||||
|
@ -72,4 +73,4 @@ COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/loca
|
||||||
# Install launcher
|
# Install launcher
|
||||||
COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher
|
COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||||
|
|
||||||
CMD HUGGINGFACE_HUB_CACHE=$MODEL_BASE_PATH text-generation-launcher --model-name $MODEL_NAME --num-shard $NUM_GPUS
|
CMD HUGGINGFACE_HUB_CACHE=$MODEL_BASE_PATH text-generation-launcher --num-shard $NUM_GPUS
|
9
Makefile
9
Makefile
|
@ -18,5 +18,14 @@ router-dev:
|
||||||
run-bloom-560m:
|
run-bloom-560m:
|
||||||
text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2
|
text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2
|
||||||
|
|
||||||
|
run-bloom-560m-quantize:
|
||||||
|
text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2 --quantize
|
||||||
|
|
||||||
|
download-bloom:
|
||||||
|
bloom-inference-server download-weights bigscience/bloom
|
||||||
|
|
||||||
run-bloom:
|
run-bloom:
|
||||||
text-generation-launcher --model-name bigscience/bloom --num-shard 8
|
text-generation-launcher --model-name bigscience/bloom --num-shard 8
|
||||||
|
|
||||||
|
run-bloom-quantize:
|
||||||
|
text-generation-launcher --model-name bigscience/bloom --num-shard 8 --quantize
|
42
README.md
42
README.md
|
@ -8,22 +8,26 @@
|
||||||
|
|
||||||
A Rust and gRPC server for large language models text generation inference.
|
A Rust and gRPC server for large language models text generation inference.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
|
||||||
|
- [Dynamic bathing of incoming requests](https://github.com/huggingface/text-generation-inference/blob/main/router/src/batcher.rs#L88) for increased total throughput
|
||||||
|
- [Safetensors](https://github.com/huggingface/safetensors) weight loading
|
||||||
|
- 45ms per token generation for BLOOM with 8xA100 80GB
|
||||||
|
|
||||||
|
## Supported models
|
||||||
|
|
||||||
|
- BLOOM
|
||||||
|
- BLOOM-560m
|
||||||
|
|
||||||
## Load Tests for BLOOM
|
## Load Tests for BLOOM
|
||||||
|
|
||||||
See `k6/load_test.js`
|
See `k6/load_test.js`
|
||||||
We send the default examples with a 1 second delay between requests.
|
|
||||||
|
|
||||||
Stages:
|
|
||||||
- Ramp up to 50 vus in 1min
|
|
||||||
- Ramp up from 50 to 100 vus in 2min
|
|
||||||
- Ramp down to 0 vus in 1min
|
|
||||||
|
|
||||||
|
|
||||||
| | avg | min | med | max | p(90) | p(95) | RPS |
|
| | avg | min | med | max | p(90) | p(95) | RPS |
|
||||||
|--------------------------------------------------------------|-----------|--------------|-----------|------------|-----------|-----------|----------|
|
|--------------------------------------------------------------|-----------|--------------|-----------|------------|-----------|-----------|----------|
|
||||||
| [Original code](https://github.com/huggingface/transformers_bloom_parallel) | 8.9s | 1s | 9.12s | 16.69s | 13.7s | 14.26s | 5.9 |
|
| [Original code](https://github.com/huggingface/transformers_bloom_parallel) | 8.9s | 1s | 9.12s | 16.69s | 13.7s | 14.26s | 5.9 |
|
||||||
| ISO with original code | 8.88s | **959.53ms** | 8.89s | 17.08s | 13.34s | 14.12s | 5.94 |
|
| New batching logic | **5.44s** | **959.53ms** | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** |
|
||||||
| New batching logic | **5.44s** | 1.27s | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** |
|
|
||||||
|
|
||||||
## Install
|
## Install
|
||||||
|
|
||||||
|
@ -33,10 +37,30 @@ make install
|
||||||
|
|
||||||
## Run
|
## Run
|
||||||
|
|
||||||
|
### BLOOM 560-m
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
make run-bloom-560m
|
make run-bloom-560m
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### BLOOM
|
||||||
|
|
||||||
|
First you need to download the weights:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
make download-bloom
|
||||||
|
```
|
||||||
|
|
||||||
|
```shell
|
||||||
|
make run-bloom # Requires 8xA100 80GB
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also quantize the weights with bitsandbytes to reduce the VRAM requirement:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
make run-bloom-quantize # Requires 8xA100 40GB
|
||||||
|
```
|
||||||
|
|
||||||
## Test
|
## Test
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
|
|
|
@ -21,6 +21,8 @@ struct Args {
|
||||||
model_name: String,
|
model_name: String,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
num_shard: Option<usize>,
|
num_shard: Option<usize>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
quantize: bool,
|
||||||
#[clap(default_value = "128", long, env)]
|
#[clap(default_value = "128", long, env)]
|
||||||
max_concurrent_requests: usize,
|
max_concurrent_requests: usize,
|
||||||
#[clap(default_value = "1000", long, env)]
|
#[clap(default_value = "1000", long, env)]
|
||||||
|
@ -46,6 +48,7 @@ fn main() -> ExitCode {
|
||||||
let Args {
|
let Args {
|
||||||
model_name,
|
model_name,
|
||||||
num_shard,
|
num_shard,
|
||||||
|
quantize,
|
||||||
max_concurrent_requests,
|
max_concurrent_requests,
|
||||||
max_input_length,
|
max_input_length,
|
||||||
max_batch_size,
|
max_batch_size,
|
||||||
|
@ -87,6 +90,7 @@ fn main() -> ExitCode {
|
||||||
thread::spawn(move || {
|
thread::spawn(move || {
|
||||||
shard_manager(
|
shard_manager(
|
||||||
model_name,
|
model_name,
|
||||||
|
quantize,
|
||||||
uds_path,
|
uds_path,
|
||||||
rank,
|
rank,
|
||||||
num_shard,
|
num_shard,
|
||||||
|
@ -169,6 +173,8 @@ fn main() -> ExitCode {
|
||||||
tracing::error!("text-generation-router not found in PATH");
|
tracing::error!("text-generation-router not found in PATH");
|
||||||
tracing::error!("Please install it with `make install-router`")
|
tracing::error!("Please install it with `make install-router`")
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
tracing::error!("{}", err);
|
||||||
}
|
}
|
||||||
|
|
||||||
shutdown_shards(shutdown, &shutdown_receiver);
|
shutdown_shards(shutdown, &shutdown_receiver);
|
||||||
|
@ -232,6 +238,7 @@ enum ShardStatus {
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
fn shard_manager(
|
fn shard_manager(
|
||||||
model_name: String,
|
model_name: String,
|
||||||
|
quantize: bool,
|
||||||
uds_path: String,
|
uds_path: String,
|
||||||
rank: usize,
|
rank: usize,
|
||||||
world_size: usize,
|
world_size: usize,
|
||||||
|
@ -260,6 +267,10 @@ fn shard_manager(
|
||||||
shard_argv.push("--sharded".to_string());
|
shard_argv.push("--sharded".to_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if quantize {
|
||||||
|
shard_argv.push("--quantize".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
let mut env = vec![
|
let mut env = vec![
|
||||||
("RANK".parse().unwrap(), rank.to_string().parse().unwrap()),
|
("RANK".parse().unwrap(), rank.to_string().parse().unwrap()),
|
||||||
(
|
(
|
||||||
|
@ -338,11 +349,9 @@ fn shard_manager(
|
||||||
tracing::info!("Shard {} ready in {:?}", rank, start_time.elapsed());
|
tracing::info!("Shard {} ready in {:?}", rank, start_time.elapsed());
|
||||||
status_sender.send(ShardStatus::Ready).unwrap();
|
status_sender.send(ShardStatus::Ready).unwrap();
|
||||||
ready = true;
|
ready = true;
|
||||||
} else if !ready {
|
} else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
|
||||||
if wait_time.elapsed() > Duration::from_secs(10) {
|
tracing::info!("Waiting for shard {} to be ready...", rank);
|
||||||
tracing::info!("Waiting for shard {} to be ready...", rank);
|
wait_time = Instant::now();
|
||||||
wait_time = Instant::now();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
sleep(Duration::from_millis(100));
|
sleep(Duration::from_millis(100));
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,9 +12,7 @@ pub struct ShardedClient {
|
||||||
|
|
||||||
impl ShardedClient {
|
impl ShardedClient {
|
||||||
fn new(clients: Vec<Client>) -> Self {
|
fn new(clients: Vec<Client>) -> Self {
|
||||||
Self {
|
Self { clients }
|
||||||
clients,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new ShardedClient from a master client. The master client will communicate with
|
/// Create a new ShardedClient from a master client. The master client will communicate with
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
/// Batching and inference logic
|
/// Batching and inference logic
|
||||||
use crate::GenerateRequest;
|
|
||||||
use crate::{Db, Entry};
|
use crate::{Db, Entry};
|
||||||
|
use crate::{ErrorResponse, GenerateRequest};
|
||||||
use axum::http::StatusCode;
|
use axum::http::StatusCode;
|
||||||
|
use axum::Json;
|
||||||
use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
|
use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
@ -213,10 +214,15 @@ pub enum InferError {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert to Axum supported format
|
/// Convert to Axum supported format
|
||||||
impl From<InferError> for (StatusCode, String) {
|
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
||||||
fn from(err: InferError) -> Self {
|
fn from(err: InferError) -> Self {
|
||||||
match err {
|
match err {
|
||||||
InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
|
InferError::GenerationError(_) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(ErrorResponse {
|
||||||
|
error: err.to_string(),
|
||||||
|
}),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -64,3 +64,8 @@ pub(crate) struct GenerateRequest {
|
||||||
pub(crate) struct GeneratedText {
|
pub(crate) struct GeneratedText {
|
||||||
pub generated_text: String,
|
pub generated_text: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
pub(crate) struct ErrorResponse {
|
||||||
|
pub error: String,
|
||||||
|
}
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
use crate::{Batcher, GenerateParameters, GenerateRequest, GeneratedText, Validation};
|
use crate::{
|
||||||
|
Batcher, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation,
|
||||||
|
};
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, StatusCode};
|
use axum::http::{HeaderMap, StatusCode};
|
||||||
use axum::response::IntoResponse;
|
use axum::response::IntoResponse;
|
||||||
|
@ -23,7 +25,7 @@ struct ServerState {
|
||||||
|
|
||||||
/// Health check method
|
/// Health check method
|
||||||
#[instrument(skip(state), fields(time, time_per_token))]
|
#[instrument(skip(state), fields(time, time_per_token))]
|
||||||
async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> {
|
async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
||||||
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
|
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
|
||||||
// be a bit too slow for a health check.
|
// be a bit too slow for a health check.
|
||||||
// What we should do instead if check if the gRPC channels are still healthy.
|
// What we should do instead if check if the gRPC channels are still healthy.
|
||||||
|
@ -32,7 +34,9 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, String
|
||||||
let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
|
let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
|
||||||
(
|
(
|
||||||
StatusCode::TOO_MANY_REQUESTS,
|
StatusCode::TOO_MANY_REQUESTS,
|
||||||
"Model is overloaded".to_string(),
|
Json(ErrorResponse {
|
||||||
|
error: "Model is overloaded".to_string(),
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
@ -70,13 +74,16 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, String
|
||||||
async fn generate(
|
async fn generate(
|
||||||
state: Extension<ServerState>,
|
state: Extension<ServerState>,
|
||||||
req: Json<GenerateRequest>,
|
req: Json<GenerateRequest>,
|
||||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||||
let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
|
let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
|
||||||
|
tracing::error!("Model is overloaded");
|
||||||
(
|
(
|
||||||
StatusCode::TOO_MANY_REQUESTS,
|
StatusCode::TOO_MANY_REQUESTS,
|
||||||
"Model is overloaded".to_string(),
|
Json(ErrorResponse {
|
||||||
|
error: "Model is overloaded".to_string(),
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
@ -88,10 +95,21 @@ async fn generate(
|
||||||
inputs: req.inputs.clone(),
|
inputs: req.inputs.clone(),
|
||||||
parameters: req.parameters.clone(),
|
parameters: req.parameters.clone(),
|
||||||
})
|
})
|
||||||
.await?;
|
.await
|
||||||
|
.map_err(|err| {
|
||||||
|
tracing::error!("{}", err.to_string());
|
||||||
|
err
|
||||||
|
})?;
|
||||||
|
|
||||||
// Inference
|
// Inference
|
||||||
let response = state.batcher.infer(input_length, validated_request).await?;
|
let response = state
|
||||||
|
.batcher
|
||||||
|
.infer(input_length, validated_request)
|
||||||
|
.await
|
||||||
|
.map_err(|err| {
|
||||||
|
tracing::error!("{}", err.to_string());
|
||||||
|
err
|
||||||
|
})?;
|
||||||
|
|
||||||
// Timings
|
// Timings
|
||||||
let total_time = start_time.elapsed();
|
let total_time = start_time.elapsed();
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
/// Payload validation logic
|
/// Payload validation logic
|
||||||
use crate::GenerateRequest;
|
use crate::{ErrorResponse, GenerateRequest};
|
||||||
use axum::http::StatusCode;
|
use axum::http::StatusCode;
|
||||||
|
use axum::Json;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokenizers::tokenizer::Tokenizer;
|
use tokenizers::tokenizer::Tokenizer;
|
||||||
use tokenizers::{
|
use tokenizers::{
|
||||||
|
@ -146,20 +147,25 @@ type ValidationRequest = (
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum ValidationError {
|
pub enum ValidationError {
|
||||||
#[error("Temperature must be strictly positive")]
|
#[error("temperature must be strictly positive")]
|
||||||
Temperature,
|
Temperature,
|
||||||
#[error("Top p must be >= 0.0 or < 1.0")]
|
#[error("top_p must be >= 0.0 or < 1.0")]
|
||||||
TopP,
|
TopP,
|
||||||
#[error("Top k must be strictly positive")]
|
#[error("top_k must be strictly positive")]
|
||||||
TopK,
|
TopK,
|
||||||
#[error("Max New Tokens must be <= 512")]
|
#[error("max_new_tokens must be <= 512")]
|
||||||
MaxNewTokens,
|
MaxNewTokens,
|
||||||
#[error("Inputs must have less than {1} tokens. Given: {0}")]
|
#[error("inputs must have less than {1} tokens. Given: {0}")]
|
||||||
InputLength(usize, usize),
|
InputLength(usize, usize),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<ValidationError> for (StatusCode, String) {
|
impl From<ValidationError> for (StatusCode, Json<ErrorResponse>) {
|
||||||
fn from(err: ValidationError) -> Self {
|
fn from(err: ValidationError) -> Self {
|
||||||
(StatusCode::BAD_REQUEST, err.to_string())
|
(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(ErrorResponse {
|
||||||
|
error: err.to_string(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@ app = typer.Typer()
|
||||||
def serve(
|
def serve(
|
||||||
model_name: str,
|
model_name: str,
|
||||||
sharded: bool = False,
|
sharded: bool = False,
|
||||||
|
quantize: bool = False,
|
||||||
uds_path: Path = "/tmp/bloom-inference",
|
uds_path: Path = "/tmp/bloom-inference",
|
||||||
):
|
):
|
||||||
if sharded:
|
if sharded:
|
||||||
|
@ -28,7 +29,7 @@ def serve(
|
||||||
os.getenv("MASTER_PORT", None) is not None
|
os.getenv("MASTER_PORT", None) is not None
|
||||||
), "MASTER_PORT must be set when sharded is True"
|
), "MASTER_PORT must be set when sharded is True"
|
||||||
|
|
||||||
server.serve(model_name, sharded, uds_path)
|
server.serve(model_name, sharded, quantize, uds_path)
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
|
|
|
@ -19,9 +19,16 @@ from bloom_inference.utils import (
|
||||||
NextTokenChooser,
|
NextTokenChooser,
|
||||||
initialize_torch_distributed,
|
initialize_torch_distributed,
|
||||||
weight_files,
|
weight_files,
|
||||||
download_weights
|
download_weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
HAS_BITS_AND_BYTES = True
|
||||||
|
try:
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
from bitsandbytes.nn import Int8Params
|
||||||
|
except Exception as e:
|
||||||
|
HAS_BITS_AND_BYTES = False
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
|
|
||||||
|
|
||||||
|
@ -68,7 +75,9 @@ class Batch:
|
||||||
)
|
)
|
||||||
stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))
|
stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))
|
||||||
|
|
||||||
input_ids = tokenizer(inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8).to(device)
|
input_ids = tokenizer(
|
||||||
|
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
|
||||||
|
).to(device)
|
||||||
all_input_ids = input_ids["input_ids"].unsqueeze(-1)
|
all_input_ids = input_ids["input_ids"].unsqueeze(-1)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
|
@ -250,7 +259,7 @@ class BLOOM:
|
||||||
def generate_token(
|
def generate_token(
|
||||||
self, batch: Batch
|
self, batch: Batch
|
||||||
) -> Tuple[List[GeneratedText], Optional[Batch]]:
|
) -> Tuple[List[GeneratedText], Optional[Batch]]:
|
||||||
with torch.no_grad():
|
with torch.inference_mode():
|
||||||
outputs = self.forward(**batch.input_ids)
|
outputs = self.forward(**batch.input_ids)
|
||||||
|
|
||||||
# List of indices to cache
|
# List of indices to cache
|
||||||
|
@ -374,13 +383,13 @@ class BLOOM:
|
||||||
|
|
||||||
|
|
||||||
class BLOOMSharded(BLOOM):
|
class BLOOMSharded(BLOOM):
|
||||||
def __init__(self, model_name: str):
|
def __init__(self, model_name: str, quantize: bool = False):
|
||||||
super(BLOOM, self).__init__()
|
super(BLOOM, self).__init__()
|
||||||
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
|
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
|
||||||
self.master = self.rank == 0
|
self.master = self.rank == 0
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
self.device = torch.device(f"cuda:{self.rank}")
|
self.device = torch.device(f"cuda:{self.rank}")
|
||||||
dtype = torch.bfloat16
|
dtype = torch.float16
|
||||||
else:
|
else:
|
||||||
self.device = torch.device("cpu")
|
self.device = torch.device("cpu")
|
||||||
dtype = torch.float32
|
dtype = torch.float32
|
||||||
|
@ -414,6 +423,7 @@ class BLOOMSharded(BLOOM):
|
||||||
self.load_weights(
|
self.load_weights(
|
||||||
model,
|
model,
|
||||||
filenames,
|
filenames,
|
||||||
|
quantize=quantize,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
rank=self.rank,
|
rank=self.rank,
|
||||||
world_size=self.world_size,
|
world_size=self.world_size,
|
||||||
|
@ -423,11 +433,18 @@ class BLOOMSharded(BLOOM):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_weights(
|
def load_weights(
|
||||||
model, filenames: List[str], device: torch.device, rank: int, world_size: int
|
model,
|
||||||
|
filenames: List[str],
|
||||||
|
quantize: bool,
|
||||||
|
device: torch.device,
|
||||||
|
rank: int,
|
||||||
|
world_size: int,
|
||||||
):
|
):
|
||||||
parameters = dict(model.named_parameters())
|
parameters = dict(model.named_parameters())
|
||||||
for file in filenames:
|
for file in filenames:
|
||||||
with safe_open(file, framework="pt", device=str(device)) as f:
|
with safe_open(
|
||||||
|
file, framework="pt", device=str(device) if not quantize else "cpu"
|
||||||
|
) as f:
|
||||||
for name in f.keys():
|
for name in f.keys():
|
||||||
full_name = f"transformer.{name}"
|
full_name = f"transformer.{name}"
|
||||||
|
|
||||||
|
@ -479,6 +496,67 @@ class BLOOMSharded(BLOOM):
|
||||||
)
|
)
|
||||||
|
|
||||||
tensor = tensor.contiguous()
|
tensor = tensor.contiguous()
|
||||||
|
|
||||||
|
if quantize:
|
||||||
|
if not HAS_BITS_AND_BYTES:
|
||||||
|
raise ImportError(
|
||||||
|
"bitsandbytes is not available on your machine"
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
type(module)
|
||||||
|
in [TensorParallelRowLinear, TensorParallelColumnLinear]
|
||||||
|
and param_name == "weight"
|
||||||
|
):
|
||||||
|
tensor = Int8Params(
|
||||||
|
tensor.transpose(1, 0),
|
||||||
|
has_fp16_weights=False,
|
||||||
|
requires_grad=False,
|
||||||
|
).to(device)
|
||||||
|
state = bnb.MatmulLtState()
|
||||||
|
state.threshold = 6.0
|
||||||
|
state.has_fp16_weights = False
|
||||||
|
state.memory_efficient_backward = False
|
||||||
|
state.use_pool = True
|
||||||
|
state.CB = tensor.CB
|
||||||
|
state.SCB = tensor.SCB
|
||||||
|
tensor.CB = None
|
||||||
|
tensor.SCB = None
|
||||||
|
|
||||||
|
def replace_linear(state, in_features, out_features):
|
||||||
|
def linear(input, weight, bias):
|
||||||
|
size_out = input.size()[:-1] + (out_features,)
|
||||||
|
input = input.view(-1, in_features)
|
||||||
|
out = torch.empty(
|
||||||
|
size_out, device=input.device, dtype=input.dtype
|
||||||
|
)
|
||||||
|
out = bnb.matmul(
|
||||||
|
input,
|
||||||
|
weight,
|
||||||
|
out=out.view(-1, out_features),
|
||||||
|
state=state,
|
||||||
|
threshold=state.threshold,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
if state.CB is not None:
|
||||||
|
# we converted 8-bit row major to turing/ampere format
|
||||||
|
# in the first inference pass
|
||||||
|
# we no longer need the row-major weight
|
||||||
|
del state.CB
|
||||||
|
weight.data = state.CxB
|
||||||
|
|
||||||
|
return out.view(size_out)
|
||||||
|
|
||||||
|
return linear
|
||||||
|
|
||||||
|
module.linear = replace_linear(
|
||||||
|
state, module.in_features, module.out_features
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
tensor = tensor.to(device)
|
||||||
|
|
||||||
module._parameters[param_name] = tensor
|
module._parameters[param_name] = tensor
|
||||||
if name == "word_embeddings.weight":
|
if name == "word_embeddings.weight":
|
||||||
model.lm_head._parameters["weight"] = tensor
|
model.lm_head._parameters["weight"] = tensor
|
||||||
|
|
|
@ -68,21 +68,27 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||||
def serve(
|
def serve(
|
||||||
model_name: str,
|
model_name: str,
|
||||||
sharded: bool,
|
sharded: bool,
|
||||||
|
quantize: bool,
|
||||||
uds_path: Path,
|
uds_path: Path,
|
||||||
):
|
):
|
||||||
async def serve_inner(
|
async def serve_inner(
|
||||||
model_name: str,
|
model_name: str,
|
||||||
sharded: bool = False,
|
sharded: bool = False,
|
||||||
|
quantize: bool = False,
|
||||||
):
|
):
|
||||||
unix_socket_template = "unix://{}-{}"
|
unix_socket_template = "unix://{}-{}"
|
||||||
if sharded:
|
if sharded:
|
||||||
model = BLOOMSharded(model_name)
|
model = BLOOMSharded(model_name, quantize)
|
||||||
server_urls = [
|
server_urls = [
|
||||||
unix_socket_template.format(uds_path, rank)
|
unix_socket_template.format(uds_path, rank)
|
||||||
for rank in range(model.world_size)
|
for rank in range(model.world_size)
|
||||||
]
|
]
|
||||||
local_url = server_urls[model.rank]
|
local_url = server_urls[model.rank]
|
||||||
else:
|
else:
|
||||||
|
if quantize:
|
||||||
|
raise ValueError(
|
||||||
|
"bitsandbytes quantization is only available when running in `sharded` mode."
|
||||||
|
)
|
||||||
model = BLOOM(model_name)
|
model = BLOOM(model_name)
|
||||||
local_url = unix_socket_template.format(uds_path, 0)
|
local_url = unix_socket_template.format(uds_path, 0)
|
||||||
server_urls = [local_url]
|
server_urls = [local_url]
|
||||||
|
@ -105,4 +111,4 @@ def serve(
|
||||||
print("Signal received. Shutting down")
|
print("Signal received. Shutting down")
|
||||||
await server.stop(0)
|
await server.stop(0)
|
||||||
|
|
||||||
asyncio.run(serve_inner(model_name, sharded))
|
asyncio.run(serve_inner(model_name, sharded, quantize))
|
||||||
|
|
|
@ -22,6 +22,14 @@ test_prod = ["parameterized", "pytest", "pytest-subtests", "pytest-xdist"]
|
||||||
test_trackers = ["comet-ml", "tensorboard", "wandb"]
|
test_trackers = ["comet-ml", "tensorboard", "wandb"]
|
||||||
testing = ["datasets", "deepspeed (<0.7.0)", "evaluate", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scipy", "sklearn", "tqdm", "transformers"]
|
testing = ["datasets", "deepspeed (<0.7.0)", "evaluate", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scipy", "sklearn", "tqdm", "transformers"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bitsandbytes"
|
||||||
|
version = "0.35.1"
|
||||||
|
description = "8-bit optimizers and matrix multiplication routines."
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = "*"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "click"
|
name = "click"
|
||||||
version = "8.1.3"
|
version = "8.1.3"
|
||||||
|
@ -205,13 +213,17 @@ python-versions = ">=3.7"
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "1.1"
|
lock-version = "1.1"
|
||||||
python-versions = "^3.9"
|
python-versions = "^3.9"
|
||||||
content-hash = "a4eef5f52e8d046aa883082c865b0865047f611a3240b18250487d4b6e831496"
|
content-hash = "50d9d44577a0222f125c770732d5f88807378573bd7386036eb5c79fc2a7c552"
|
||||||
|
|
||||||
[metadata.files]
|
[metadata.files]
|
||||||
accelerate = [
|
accelerate = [
|
||||||
{file = "accelerate-0.12.0-py3-none-any.whl", hash = "sha256:7742ca5c9f15dd1e0a283305599c196e260af4717a561d1f544aeab27d828af6"},
|
{file = "accelerate-0.12.0-py3-none-any.whl", hash = "sha256:7742ca5c9f15dd1e0a283305599c196e260af4717a561d1f544aeab27d828af6"},
|
||||||
{file = "accelerate-0.12.0.tar.gz", hash = "sha256:e8b119c94fac31877620d5f9de311164ec81fa9dc9e175f0d0d4f50fc8d79473"},
|
{file = "accelerate-0.12.0.tar.gz", hash = "sha256:e8b119c94fac31877620d5f9de311164ec81fa9dc9e175f0d0d4f50fc8d79473"},
|
||||||
]
|
]
|
||||||
|
bitsandbytes = [
|
||||||
|
{file = "bitsandbytes-0.35.1-py3-none-any.whl", hash = "sha256:4506a9e3778359a743938aa5592d8d043fa91d1df66cd01ba8cc6486e64dea45"},
|
||||||
|
{file = "bitsandbytes-0.35.1.tar.gz", hash = "sha256:63a6f59c87b713a731a685e43d68c19789ee6381e62196cafab293b87eca5d46"},
|
||||||
|
]
|
||||||
click = [
|
click = [
|
||||||
{file = "click-8.1.3-py3-none-any.whl", hash = "sha256:bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48"},
|
{file = "click-8.1.3-py3-none-any.whl", hash = "sha256:bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48"},
|
||||||
{file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"},
|
{file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"},
|
||||||
|
|
|
@ -15,6 +15,7 @@ typer = "^0.6.1"
|
||||||
grpcio-reflection = "^1.49.1"
|
grpcio-reflection = "^1.49.1"
|
||||||
accelerate = "^0.12.0"
|
accelerate = "^0.12.0"
|
||||||
joblib = "^1.2.0"
|
joblib = "^1.2.0"
|
||||||
|
bitsandbytes = "^0.35.1"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
grpcio-tools = "^1.49.1"
|
grpcio-tools = "^1.49.1"
|
||||||
|
|
Loading…
Reference in New Issue