feat(server): Support all AutoModelForCausalLM on a best effort basis

This commit is contained in:
OlivierDehaene 2022-10-28 19:24:00 +02:00
parent 09674e6df9
commit 3cf6368c77
26 changed files with 755 additions and 710 deletions

156
Cargo.lock generated
View File

@ -28,9 +28,9 @@ dependencies = [
[[package]]
name = "anyhow"
version = "1.0.65"
version = "1.0.66"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "98161a4e3e2184da77bb14f02184cdd111e83bbbcc9979dfee3c44b9a85f5602"
checksum = "216261ddc8289130e551ddcd5ce8a064710c0d064a4d2895c67151c92b5443f6"
[[package]]
name = "async-stream"
@ -83,9 +83,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "axum"
version = "0.5.16"
version = "0.5.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c9e3356844c4d6a6d6467b8da2cffb4a2820be256f50a3a386c9d152bab31043"
checksum = "acee9fd5073ab6b045a275b3e709c163dd36c90685219cb21804a147b58dba43"
dependencies = [
"async-trait",
"axum-core",
@ -114,9 +114,9 @@ dependencies = [
[[package]]
name = "axum-core"
version = "0.2.8"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9f0c0a60006f2a293d82d571f635042a72edf927539b7685bd62d361963839b"
checksum = "37e5939e02c56fecd5c017c37df4238c0a839fa76b7f97acdd7efb804fd181cc"
dependencies = [
"async-trait",
"bytes",
@ -130,9 +130,9 @@ dependencies = [
[[package]]
name = "base64"
version = "0.13.0"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd"
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
[[package]]
name = "bitflags"
@ -149,21 +149,6 @@ dependencies = [
"generic-array",
]
[[package]]
name = "bloom-inference-client"
version = "0.1.0"
dependencies = [
"futures",
"prost",
"thiserror",
"tokio",
"tonic",
"tonic-build",
"tower",
"tracing",
"tracing-error",
]
[[package]]
name = "bumpalo"
version = "3.11.1"
@ -255,9 +240,9 @@ dependencies = [
[[package]]
name = "clap"
version = "4.0.17"
version = "4.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06badb543e734a2d6568e19a40af66ed5364360b9226184926f89d229b4b4267"
checksum = "335867764ed2de42325fafe6d18b8af74ba97ee0c590fa016f157535b42ab04b"
dependencies = [
"atty",
"bitflags",
@ -270,9 +255,9 @@ dependencies = [
[[package]]
name = "clap_derive"
version = "4.0.13"
version = "4.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c42f169caba89a7d512b5418b09864543eeb4d497416c917d7137863bd2076ad"
checksum = "16a1b0f6422af32d5da0c58e2703320f379216ee70198241c84173a8c5ac28f3"
dependencies = [
"heck 0.4.0",
"proc-macro-error",
@ -532,14 +517,14 @@ dependencies = [
[[package]]
name = "filetime"
version = "0.2.17"
version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e94a7bbaa59354bc20dd75b67f23e2797b4490e9d6928203fb105c79e448c86c"
checksum = "4b9663d381d07ae25dc88dbdf27df458faa83a9b25336bcac83d5e452b5fc9d3"
dependencies = [
"cfg-if",
"libc",
"redox_syscall",
"windows-sys 0.36.1",
"windows-sys 0.42.0",
]
[[package]]
@ -600,9 +585,9 @@ dependencies = [
[[package]]
name = "futures"
version = "0.3.24"
version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f21eda599937fba36daeb58a22e8f5cee2d14c4a17b5b7739c7c8e5e3b8230c"
checksum = "38390104763dc37a5145a53c29c63c1290b5d316d6086ec32c293f6736051bb0"
dependencies = [
"futures-channel",
"futures-core",
@ -615,9 +600,9 @@ dependencies = [
[[package]]
name = "futures-channel"
version = "0.3.24"
version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30bdd20c28fadd505d0fd6712cdfcb0d4b5648baf45faef7f852afb2399bb050"
checksum = "52ba265a92256105f45b719605a571ffe2d1f0fea3807304b522c1d778f79eed"
dependencies = [
"futures-core",
"futures-sink",
@ -625,15 +610,15 @@ dependencies = [
[[package]]
name = "futures-core"
version = "0.3.24"
version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e5aa3de05362c3fb88de6531e6296e85cde7739cccad4b9dfeeb7f6ebce56bf"
checksum = "04909a7a7e4633ae6c4a9ab280aeb86da1236243a77b694a49eacd659a4bd3ac"
[[package]]
name = "futures-executor"
version = "0.3.24"
version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ff63c23854bee61b6e9cd331d523909f238fc7636290b96826e9cfa5faa00ab"
checksum = "7acc85df6714c176ab5edf386123fafe217be88c0840ec11f199441134a074e2"
dependencies = [
"futures-core",
"futures-task",
@ -642,15 +627,15 @@ dependencies = [
[[package]]
name = "futures-io"
version = "0.3.24"
version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbf4d2a7a308fd4578637c0b17c7e1c7ba127b8f6ba00b29f717e9655d85eb68"
checksum = "00f5fb52a06bdcadeb54e8d3671f8888a39697dcb0b81b23b55174030427f4eb"
[[package]]
name = "futures-macro"
version = "0.3.24"
version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42cd15d1c7456c04dbdf7e88bcd69760d74f3a798d6444e16974b505b0e62f17"
checksum = "bdfb8ce053d86b91919aad980c220b1fb8401a9394410e1c289ed7e66b61835d"
dependencies = [
"proc-macro2",
"quote",
@ -659,21 +644,21 @@ dependencies = [
[[package]]
name = "futures-sink"
version = "0.3.24"
version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21b20ba5a92e727ba30e72834706623d94ac93a725410b6a6b6fbc1b07f7ba56"
checksum = "39c15cf1a4aa79df40f1bb462fb39676d0ad9e366c2a33b590d7c66f4f81fcf9"
[[package]]
name = "futures-task"
version = "0.3.24"
version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6508c467c73851293f390476d4491cf4d227dbabcd4170f3bb6044959b294f1"
checksum = "2ffb393ac5d9a6eaa9d3fdf37ae2776656b706e200c8e16b1bdb227f5198e6ea"
[[package]]
name = "futures-util"
version = "0.3.24"
version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44fb6cb1be61cc1d2e43b262516aafcf63b241cffdb1d3fa115f91d9c7b09c90"
checksum = "197676987abd2f9cadff84926f410af1c183608d36641465df73ae8211dc65d6"
dependencies = [
"futures-channel",
"futures-core",
@ -699,9 +684,9 @@ dependencies = [
[[package]]
name = "getrandom"
version = "0.2.7"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4eb1a864a501629691edf6c15a593b7a51eebaa1e8468e9ddc623de7c9b58ec6"
checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31"
dependencies = [
"cfg-if",
"libc",
@ -716,9 +701,9 @@ checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574"
[[package]]
name = "h2"
version = "0.3.14"
version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ca32592cf21ac7ccab1825cd87f6c9b3d9022c44d086172ed0966bec8af30be"
checksum = "5f9f29bc9dda355256b2916cf526ab02ce0aeaaaf2bad60d65ef3f12f11dd0f4"
dependencies = [
"bytes",
"fnv",
@ -967,9 +952,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]]
name = "libc"
version = "0.2.135"
version = "0.2.137"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68783febc7782c6c5cb401fbda4de5a9898be1762314da0bb2c10ced61f18b0c"
checksum = "fc7fcc620a3bff7cdd7a365be3376c97191aeaccc2a603e600951e452615bf89"
[[package]]
name = "lock_api"
@ -992,9 +977,9 @@ dependencies = [
[[package]]
name = "macro_rules_attribute"
version = "0.1.2"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "258c86475e1616d6f2d8f5227cfaabd3dae1f6d5388b9597df8a199d4497aba7"
checksum = "cf0c9b980bf4f3a37fd7b1c066941dd1b1d0152ce6ee6e8fe8c49b9f6810d862"
dependencies = [
"macro_rules_attribute-proc_macro",
"paste",
@ -1002,9 +987,9 @@ dependencies = [
[[package]]
name = "macro_rules_attribute-proc_macro"
version = "0.1.2"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26a8d2502d5aa4d411ef494ba7470eb299f05725179ce3b5de77aa01a9ffdea"
checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d"
[[package]]
name = "matchit"
@ -1050,14 +1035,14 @@ dependencies = [
[[package]]
name = "mio"
version = "0.8.4"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57ee1c23c7c63b0c9250c339ffdc69255f110b298b901b9f6c82547b7b87caaf"
checksum = "e5d732bc30207a6423068df043e3d02e0735b155ad7ce1a6f76fe2baa5b158de"
dependencies = [
"libc",
"log",
"wasi 0.11.0+wasi-snapshot-preview1",
"windows-sys 0.36.1",
"windows-sys 0.42.0",
]
[[package]]
@ -1200,9 +1185,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
[[package]]
name = "openssl-sys"
version = "0.9.76"
version = "0.9.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5230151e44c0f05157effb743e8d517472843121cf9243e8b81393edb5acd9ce"
checksum = "b03b84c3b2d099b81f0953422b4d4ad58761589d0229b5506356afca05a3670a"
dependencies = [
"autocfg",
"cc",
@ -1213,9 +1198,9 @@ dependencies = [
[[package]]
name = "os_str_bytes"
version = "6.3.0"
version = "6.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ff7415e9ae3fff1225851df9e0d9e4e5479f947619774677a63572e55e80eff"
checksum = "3baf96e39c5359d2eb0dd6ccb42c62b91d9678aa68160d261b9e0ccbf9e9dea9"
[[package]]
name = "overload"
@ -1302,9 +1287,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "pkg-config"
version = "0.3.25"
version = "0.3.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae"
checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160"
[[package]]
name = "ppv-lite86"
@ -1602,18 +1587,18 @@ dependencies = [
[[package]]
name = "serde"
version = "1.0.145"
version = "1.0.147"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "728eb6351430bccb993660dfffc5a72f91ccc1295abaa8ce19b27ebe4f75568b"
checksum = "d193d69bae983fc11a79df82342761dfbf28a99fc8d203dca4c3c1b590948965"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.145"
version = "1.0.147"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81fa1584d3d1bcacd84c277a0dfe21f5b0f6accf4a23d04d4c6d61f1af522b4c"
checksum = "4f1d362ca8fc9c3e3a7484440752472d68a6caa98f1ab81d99b5dfe517cec852"
dependencies = [
"proc-macro2",
"quote",
@ -1622,9 +1607,9 @@ dependencies = [
[[package]]
name = "serde_json"
version = "1.0.86"
version = "1.0.87"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41feea4228a6f1cd09ec7a3593a682276702cd67b5273544757dae23c096f074"
checksum = "6ce777b7b150d76b9cf60d28b55f5847135a003f7d7350c6be7a773508ce7d45"
dependencies = [
"itoa",
"ryu",
@ -1739,9 +1724,9 @@ dependencies = [
[[package]]
name = "syn"
version = "1.0.102"
version = "1.0.103"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fcd952facd492f9be3ef0d0b7032a6e442ee9b361d4acc2b1d0c4aaa5f613a1"
checksum = "a864042229133ada95abf3b54fdc62ef5ccabe9515b64717bcb9a1919e59445d"
dependencies = [
"proc-macro2",
"quote",
@ -1798,11 +1783,26 @@ dependencies = [
"winapi",
]
[[package]]
name = "text-generation-client"
version = "0.1.0"
dependencies = [
"futures",
"prost",
"thiserror",
"tokio",
"tonic",
"tonic-build",
"tower",
"tracing",
"tracing-error",
]
[[package]]
name = "text-generation-launcher"
version = "0.1.0"
dependencies = [
"clap 4.0.17",
"clap 4.0.18",
"ctrlc",
"subprocess",
"tracing",
@ -1814,12 +1814,12 @@ name = "text-generation-router"
version = "0.1.0"
dependencies = [
"axum",
"bloom-inference-client",
"clap 4.0.17",
"clap 4.0.18",
"futures",
"parking_lot",
"serde",
"serde_json",
"text-generation-client",
"thiserror",
"tokenizers",
"tokio",

View File

@ -66,7 +66,7 @@ COPY proto proto
COPY server server
RUN cd server && \
make gen-server && \
/opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir
/opt/miniconda/envs/text-generation/bin/pip install ".[bnb]" --no-cache-dir
# Install router
COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router

View File

@ -22,7 +22,7 @@ run-bloom-560m-quantize:
text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2 --quantize
download-bloom:
bloom-inference-server download-weights bigscience/bloom
text-generation-server download-weights bigscience/bloom
run-bloom:
text-generation-launcher --model-name bigscience/bloom --num-shard 8

View File

@ -15,11 +15,13 @@ A Rust and gRPC server for large language models text generation inference.
- [Safetensors](https://github.com/huggingface/safetensors) weight loading
- 45ms per token generation for BLOOM with 8xA100 80GB
## Supported models
## Officially supported models
- BLOOM
- BLOOM-560m
Other models are supported on a best-effort basis using `AutoModelForCausalLM.from_pretrained(<model>, torch_dtype=torch.float16, device_map="auto")`.
## Load Tests for BLOOM
See `k6/load_test.js`

View File

@ -1,5 +1,5 @@
$schema: https://azuremlschemas.azureedge.net/latest/model.schema.json
name: bloom
name: bloom-safetensors
version: 1
path: ./bloom
path: ./bloom-safetensors
type: custom_model

View File

@ -256,7 +256,7 @@ fn shard_manager(
// Process args
let mut shard_argv = vec![
"bloom-inference-server".to_string(),
"text-generation-server".to_string(),
"serve".to_string(),
model_name,
"--uds-path".to_string(),
@ -311,7 +311,7 @@ fn shard_manager(
Err(err) => {
if let PopenError::IoError(ref err) = err {
if err.kind() == io::ErrorKind::NotFound {
tracing::error!("bloom-inference-server not found in PATH");
tracing::error!("text-generation-server not found in PATH");
tracing::error!("Please install it with `make install-server`")
}
}

View File

@ -14,7 +14,7 @@ path = "src/main.rs"
[dependencies]
axum = { version = "0.5.16", features = ["json", "serde_json"] }
bloom-inference-client = { path = "client" }
text-generation-client = { path = "client" }
clap = { version = "4.0.15", features = ["derive", "env"] }
futures = "0.3.24"
parking_lot = "0.12.1"

View File

@ -1,5 +1,5 @@
[package]
name = "bloom-inference-client"
name = "text-generation-client"
version = "0.1.0"
edition = "2021"

View File

@ -3,9 +3,9 @@ use crate::{Db, Entry};
use crate::{ErrorResponse, GenerateRequest};
use axum::http::StatusCode;
use axum::Json;
use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
use std::future::Future;
use std::sync::Arc;
use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
use thiserror::Error;
use tokio::sync::{oneshot, Notify};
use tokio::time::Instant;

View File

@ -1,10 +1,10 @@
use crate::InferResponse;
/// This code is massively inspired by Tokio mini-redis
use crate::{GenerateParameters, GenerateRequest};
use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
use parking_lot::Mutex;
use std::collections::BTreeMap;
use std::sync::Arc;
use text_generation_client::{Batch, ClientError, LogitsWarperParameters, Request};
use tokio::sync::oneshot::Sender;
use tokio::time::Instant;

View File

@ -1,7 +1,7 @@
/// Text Generation Inference webserver entrypoint
use bloom_inference_client::ShardedClient;
use clap::Parser;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
/// Text Generation Inference webserver entrypoint
use text_generation_client::ShardedClient;
use text_generation_router::server;
use tokenizers::Tokenizer;
@ -19,7 +19,7 @@ struct Args {
max_waiting_tokens: usize,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "/tmp/bloom-inference-0", long, env)]
#[clap(default_value = "/tmp/text-generation-0", long, env)]
master_shard_uds_path: String,
#[clap(default_value = "bigscience/bloom", long, env)]
tokenizer_name: String,

View File

@ -6,9 +6,9 @@ use axum::http::{HeaderMap, StatusCode};
use axum::response::IntoResponse;
use axum::routing::{get, post};
use axum::{Json, Router};
use bloom_inference_client::ShardedClient;
use std::net::SocketAddr;
use std::sync::Arc;
use text_generation_client::ShardedClient;
use tokenizers::Tokenizer;
use tokio::signal;
use tokio::sync::Semaphore;

4
server/.gitignore vendored
View File

@ -1,7 +1,7 @@
# Byte-compiled / optimized / DLL files
__pycache__/
bloom_inference/__pycache__/
bloom_inference/pb/__pycache__/
text_generation/__pycache__/
text_generation/pb/__pycache__/
*.py[cod]
*$py.class

View File

@ -1,10 +1,10 @@
gen-server:
# Compile protos
pip install grpcio-tools==1.49.1 --no-cache-dir
mkdir bloom_inference/pb || true
python -m grpc_tools.protoc -I../proto --python_out=bloom_inference/pb --grpc_python_out=bloom_inference/pb ../proto/generate.proto
find bloom_inference/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch bloom_inference/pb/__init__.py
mkdir text_generation/pb || true
python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch text_generation/pb/__init__.py
install-transformers:
# Install specific version of transformers
@ -36,4 +36,4 @@ install: gen-server install-torch install-transformers install-safetensors
pip install -e . --no-cache-dir
run-dev:
python -m torch.distributed.run --nproc_per_node=2 bloom_inference/cli.py serve bigscience/bloom-560m --sharded
python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded

View File

@ -1,582 +0,0 @@
import torch
import torch.distributed
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict
from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.models.bloom.parallel_layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from bloom_inference.pb import generate_pb2
from bloom_inference.utils import (
StoppingCriteria,
NextTokenChooser,
initialize_torch_distributed,
weight_files,
download_weights,
)
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except Exception as e:
HAS_BITS_AND_BYTES = False
torch.manual_seed(0)
@dataclass
class Batch:
batch_id: int
requests: List[generate_pb2.Request]
all_input_lengths: List[int]
input_ids: Dict[str, torch.Tensor]
all_input_ids: List[torch.Tensor]
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
size: int
max_sequence_length: int
def to_pb(self):
return generate_pb2.Batch(
id=self.batch_id,
requests=self.requests,
size=self.size,
max_sequence_length=self.max_sequence_length,
)
@classmethod
def from_pb(
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
) -> "Batch":
inputs = []
next_token_choosers = []
stopping_criterias = []
all_input_lengths = []
# Parse batch
for r in pb.requests:
inputs.append(r.inputs)
all_input_lengths.append(r.input_length)
next_token_choosers.append(
NextTokenChooser(
temperature=r.parameters.temperature,
top_k=r.parameters.top_k,
top_p=r.parameters.top_p,
do_sample=r.parameters.do_sample,
)
)
stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))
input_ids = tokenizer(
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
).to(device)
all_input_ids = input_ids["input_ids"].unsqueeze(-1)
return cls(
batch_id=pb.id,
requests=pb.requests,
all_input_lengths=all_input_lengths,
input_ids=input_ids,
all_input_ids=all_input_ids,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=pb.size,
max_sequence_length=pb.max_sequence_length,
)
@classmethod
def concatenate(cls, batches: List["Batch"]) -> "Batch":
# Used for padding
total_batch_size = sum(batch.size for batch in batches)
max_sequence_length = max(batch.max_sequence_length for batch in batches)
# Batch attributes
input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
requests = []
all_input_lengths = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
all_input_lengths.extend(batch.all_input_lengths)
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
# Slicing end index for this batch
end_index = start_index + batch.size
# We only concatenate batches that did at least one step
if batch.input_ids["input_ids"].shape[1] > 1:
raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
# Initialize tensors
if i == 0:
input_ids["input_ids"] = torch.empty(
(total_batch_size, 1),
dtype=batch.input_ids["input_ids"].dtype,
device=batch.input_ids["input_ids"].device,
)
input_ids["attention_mask"] = torch.zeros(
(total_batch_size, max_sequence_length),
dtype=batch.input_ids["attention_mask"].dtype,
device=batch.input_ids["attention_mask"].device,
)
# input_ids["input_ids"] is always of shape [batch_size, 1]
# We do not need to pad it
input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]
# We need to slice the attention mask to remove padding from previous steps
input_ids["attention_mask"][
start_index:end_index, -batch.max_sequence_length :
] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :]
for j, past in enumerate(batch.input_ids["past_key_values"]):
past_keys = past[0]
past_values = past[1]
_, head_dim, padded_sequence_length = past_keys.shape
# Reshape the tensors to make slicing easier
past_keys = past_keys.view(
batch.size, -1, head_dim, padded_sequence_length
)
past_values = past_values.view(
batch.size, -1, padded_sequence_length, head_dim
)
num_heads = past_keys.shape[1]
# Initialize tensors
# This will run only once per layer
if j == len(input_ids["past_key_values"]):
padded_past_keys = torch.zeros(
(
total_batch_size,
num_heads,
head_dim,
max_sequence_length - 1,
),
dtype=past_keys.dtype,
device=past_keys.device,
)
padded_past_values = torch.zeros(
(
total_batch_size,
num_heads,
max_sequence_length - 1,
head_dim,
),
dtype=past_values.dtype,
device=past_values.device,
)
input_ids["past_key_values"].append(
[padded_past_keys, padded_past_values]
)
# We slice the past keys and values to remove the padding from previous batches
input_ids["past_key_values"][j][0][
start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
input_ids["past_key_values"][j][1][
start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
# If we are on the last batch, we need to reshape the tensors
if (i + 1) == len(batches):
input_ids["past_key_values"][j][0] = input_ids["past_key_values"][
j
][0].view(total_batch_size * num_heads, head_dim, -1)
input_ids["past_key_values"][j][1] = input_ids["past_key_values"][
j
][1].view(total_batch_size * num_heads, -1, head_dim)
start_index += batch.size
return cls(
batch_id=batches[0].batch_id,
requests=requests,
all_input_lengths=all_input_lengths,
input_ids=input_ids,
all_input_ids=all_input_ids,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=total_batch_size,
max_sequence_length=max_sequence_length,
)
@dataclass
class GeneratedText:
request: generate_pb2.Request
output: str
def to_pb(self) -> generate_pb2.GeneratedText:
return generate_pb2.GeneratedText(request=self.request, output=self.output)
class BLOOM:
def __init__(self, model_name: str):
if torch.cuda.is_available():
self.device = torch.device("cuda")
dtype = torch.bfloat16
else:
self.device = torch.device("cpu")
dtype = torch.float32
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
self.model = (
AutoModelForCausalLM.from_pretrained(model_name)
.eval()
.to(self.device)
.to(dtype)
)
self.num_heads = self.model.base_model.num_heads
def forward(self, input_ids, attention_mask, past_key_values: Optional = None):
# Model Forward
return self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
def generate_token(
self, batch: Batch
) -> Tuple[List[GeneratedText], Optional[Batch]]:
with torch.inference_mode():
outputs = self.forward(**batch.input_ids)
# List of indices to cache
next_batch_keep_indices = []
next_batch_past_keep_indices = []
# New input_ids for next forward
next_batch_input_ids = []
next_batch_all_input_ids = []
next_all_input_lengths = []
next_batch_size = 0
next_batch_max_sequence_length = 0
# Finished requests
generated_texts: List[GeneratedText] = []
# Zipped iterator
iterator = zip(
batch.requests,
batch.all_input_lengths,
outputs.logits,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
)
# For each member of the batch
for i, (
request,
input_length,
logits,
next_token_chooser,
stopping_criteria,
all_tokens,
) in enumerate(iterator):
# Select next token
next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
# Append next token to all tokens
all_tokens = torch.cat([all_tokens, next_token])
# Evaluate stopping criteria
if stopping_criteria(all_tokens):
# Decode all tokens
output = self.tokenizer.decode(
all_tokens.squeeze(-1), skip_special_tokens=True
)
# Add to the list of finished generations with the original request
generated_texts.append(GeneratedText(request, output))
# add to the next batch
else:
next_batch_keep_indices.append(i)
# past_key_values is of shape [batch_size * num_heads, ...]
# so we need to take into account the `num_heads` stride here
next_batch_past_keep_indices.extend(
[j for j in range(i * self.num_heads, (i + 1) * self.num_heads)]
)
next_batch_input_ids.append(next_token)
next_batch_all_input_ids.append(all_tokens)
next_batch_size += 1
new_input_length = input_length + 1
next_all_input_lengths.append(new_input_length)
next_batch_max_sequence_length = max(
next_batch_max_sequence_length, new_input_length
)
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:
return generated_texts, None
# If we finished at least one generation
next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
if generated_texts:
# Apply indices to attention mask, past key values and other items that need to be cached
next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
next_batch_keep_indices
]
next_batch_input_ids["past_key_values"] = [
(
keys[next_batch_past_keep_indices],
values[next_batch_past_keep_indices],
)
for keys, values in outputs["past_key_values"]
]
next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
next_batch_next_token_choosers = [
batch.next_token_choosers[i] for i in next_batch_keep_indices
]
next_batch_stopping_criterias = [
batch.stopping_criterias[i] for i in next_batch_keep_indices
]
else:
next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
next_batch_input_ids["past_key_values"] = outputs["past_key_values"]
next_batch_requests = batch.requests
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# Update attention_mask with padding as we added a new token to input_ids
next_batch_input_ids["attention_mask"] = torch.cat(
[
next_batch_input_ids["attention_mask"],
torch.ones((next_batch_size, 1)).to(self.device),
],
dim=1,
)
next_batch = Batch(
batch_id=batch.batch_id,
requests=next_batch_requests,
all_input_lengths=next_all_input_lengths,
input_ids=next_batch_input_ids,
all_input_ids=next_batch_all_input_ids,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,
max_sequence_length=next_batch_max_sequence_length,
)
return generated_texts, next_batch
class BLOOMSharded(BLOOM):
def __init__(self, model_name: str, quantize: bool = False):
super(BLOOM, self).__init__()
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
if torch.cuda.is_available():
self.device = torch.device(f"cuda:{self.rank}")
dtype = torch.float16
else:
self.device = torch.device("cpu")
dtype = torch.float32
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
config = AutoConfig.from_pretrained(
model_name, slow_but_exact=False, tp_parallel=True
)
config.pad_token_id = 3
self.num_heads = config.n_head // self.process_group.size()
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True
# Only download weights for small models
if self.master and model_name == "bigscience/bloom-560m":
download_weights(model_name)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_name)
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=self.device,
rank=self.rank,
world_size=self.world_size,
)
self.model = model.eval().to(dtype)
torch.distributed.barrier(group=self.process_group)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: bool,
device: torch.device,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if not quantize else "cpu"
) as f:
for name in f.keys():
full_name = f"transformer.{name}"
module_name, param_name = full_name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_tensor = parameters[full_name]
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
if param_name == "weight":
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
tensor = tensor.transpose(1, 0)
else:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
tensor = tensor.transpose(1, 0)
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
tensor = slice_[:]
if current_tensor.shape != tensor.shape:
raise ValueError(
f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous()
if quantize:
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine"
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor.transpose(1, 0),
has_fp16_weights=False,
requires_grad=False,
).to(device)
state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.memory_efficient_backward = False
state.use_pool = True
state.CB = tensor.CB
state.SCB = tensor.SCB
tensor.CB = None
tensor.SCB = None
def replace_linear(state, in_features, out_features):
def linear(input, weight, bias):
size_out = input.size()[:-1] + (out_features,)
input = input.view(-1, in_features)
out = torch.empty(
size_out, device=input.device, dtype=input.dtype
)
out = bnb.matmul(
input,
weight,
out=out.view(-1, out_features),
state=state,
threshold=state.threshold,
bias=bias,
)
if state.CB is not None:
# we converted 8-bit row major to turing/ampere format
# in the first inference pass
# we no longer need the row-major weight
del state.CB
weight.data = state.CxB
return out.view(size_out)
return linear
module.linear = replace_linear(
state, module.in_features, module.out_features
)
else:
tensor = tensor.to(device)
module._parameters[param_name] = tensor
if name == "word_embeddings.weight":
model.lm_head._parameters["weight"] = tensor
def forward(self, input_ids, attention_mask, past_key_values: Optional = None):
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
# Logits are sharded, so we need to gather them
logits_shard = outputs.logits[:, -1, :].contiguous()
batch_size, vocab_shard_size = logits_shard.shape
vocab_size = self.world_size * vocab_shard_size
logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)]
torch.distributed.all_gather(logits, logits_shard, group=self.process_group)
logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size)
outputs.logits = logits
return outputs

View File

@ -1,11 +1,11 @@
[tool.poetry]
name = "bloom-inference"
name = "text-generation"
version = "0.1.0"
description = "BLOOM Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
[tool.poetry.scripts]
bloom-inference-server = 'bloom_inference.cli:app'
text-generation-server = 'text_generation.cli:app'
[tool.poetry.dependencies]
python = "^3.9"
@ -17,6 +17,9 @@ accelerate = "^0.12.0"
joblib = "^1.2.0"
bitsandbytes = "^0.35.1"
[tool.poetry.extras]
bnb = ["bitsandbytes"]
[tool.poetry.group.dev.dependencies]
grpcio-tools = "^1.49.1"

View File

@ -1,6 +1,7 @@
from bloom_inference.model import Batch
from typing import Dict, Optional
from text_generation.models.types import Batch
class Cache:
def __init__(self):

View File

@ -3,7 +3,7 @@ import typer
from pathlib import Path
from bloom_inference import server, utils
from text_generation import server, utils
app = typer.Typer()
@ -13,7 +13,7 @@ def serve(
model_name: str,
sharded: bool = False,
quantize: bool = False,
uds_path: Path = "/tmp/bloom-inference",
uds_path: Path = "/tmp/text-generation",
):
if sharded:
assert (
@ -35,8 +35,9 @@ def serve(
@app.command()
def download_weights(
model_name: str,
extension: str = ".safetensors",
):
utils.download_weights(model_name)
utils.download_weights(model_name, extension)
if __name__ == "__main__":

View File

@ -0,0 +1,22 @@
from text_generation.models.model import Model
from text_generation.models.bloom import BLOOMSharded
__all__ = ["Model", "BLOOMSharded"]
def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
if model_name.startswith("bigscience/bloom"):
if sharded:
return BLOOMSharded(model_name, quantize)
else:
if quantize:
raise ValueError("quantization is not supported for non-sharded BLOOM")
return Model(model_name)
else:
if sharded:
raise ValueError("sharded is only supported for BLOOM")
if quantize:
raise ValueError("Quantization is only supported for BLOOM models")
return Model(model_name)

View File

@ -0,0 +1,231 @@
import torch
import torch.distributed
from typing import List, Optional
from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.models.bloom.parallel_layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation.models import Model
from text_generation.utils import (
initialize_torch_distributed,
weight_files,
download_weights,
)
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except Exception as e:
HAS_BITS_AND_BYTES = False
torch.manual_seed(0)
class BLOOMSharded(Model):
def __init__(self, model_name: str, quantize: bool = False):
super(Model, self).__init__()
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
if torch.cuda.is_available():
self.device = torch.device(f"cuda:{self.rank}")
dtype = torch.float16
else:
self.device = torch.device("cpu")
dtype = torch.float32
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
config = AutoConfig.from_pretrained(
model_name, slow_but_exact=False, tp_parallel=True
)
config.pad_token_id = 3
self.num_heads = config.n_head // self.process_group.size()
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True
# Only download weights for small models
if self.master and model_name == "bigscience/bloom-560m":
download_weights(model_name, extension=".safetensors")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_name, extension=".safetensors")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=self.device,
rank=self.rank,
world_size=self.world_size,
)
self.model = model.eval().to(dtype)
torch.distributed.barrier(group=self.process_group)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: bool,
device: torch.device,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if not quantize else "cpu"
) as f:
for name in f.keys():
full_name = f"transformer.{name}"
module_name, param_name = full_name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_tensor = parameters[full_name]
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
if param_name == "weight":
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
tensor = tensor.transpose(1, 0)
else:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
tensor = tensor.transpose(1, 0)
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
tensor = slice_[:]
if current_tensor.shape != tensor.shape:
raise ValueError(
f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous()
if quantize:
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor.transpose(1, 0),
has_fp16_weights=False,
requires_grad=False,
).to(device)
state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.memory_efficient_backward = False
state.use_pool = True
state.CB = tensor.CB
state.SCB = tensor.SCB
tensor.CB = None
tensor.SCB = None
def replace_linear(state, in_features, out_features):
def linear(input, weight, bias):
size_out = input.size()[:-1] + (out_features,)
input = input.view(-1, in_features)
out = torch.empty(
size_out, device=input.device, dtype=input.dtype
)
out = bnb.matmul(
input,
weight,
out=out.view(-1, out_features),
state=state,
threshold=state.threshold,
bias=bias,
)
if state.CB is not None:
# we converted 8-bit row major to turing/ampere format
# in the first inference pass
# we no longer need the row-major weight
del state.CB
weight.data = state.CxB
return out.view(size_out)
return linear
module.linear = replace_linear(
state, module.in_features, module.out_features
)
else:
tensor = tensor.to(device)
module._parameters[param_name] = tensor
if name == "word_embeddings.weight":
model.lm_head._parameters["weight"] = tensor
def forward(self, input_ids, attention_mask, past_key_values: Optional = None):
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
# Logits are sharded, so we need to gather them
logits_shard = outputs.logits[:, -1, :].contiguous()
batch_size, vocab_shard_size = logits_shard.shape
vocab_size = self.world_size * vocab_shard_size
logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)]
torch.distributed.all_gather(logits, logits_shard, group=self.process_group)
logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size)
outputs.logits = logits
return outputs

View File

@ -0,0 +1,166 @@
import torch
import torch.distributed
from typing import List, Tuple, Optional
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from text_generation.models.types import Batch, GeneratedText
class Model:
def __init__(self, model_name: str):
if torch.cuda.is_available():
self.device = torch.device("cuda")
dtype = torch.float16
else:
self.device = torch.device("cpu")
dtype = torch.float32
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
self.model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map="auto"
).eval()
self.num_heads = self.model.config.num_attention_heads
def forward(
self, input_ids, attention_mask, past_key_values: Optional = None
) -> CausalLMOutputWithPast:
# Model Forward
return self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
def generate_token(
self, batch: Batch
) -> Tuple[List[GeneratedText], Optional[Batch]]:
# For some reason, inference_mode does not work well with GLOO which we use on CPU
context_manager = (
torch.no_grad if self.device.type == "cpu" else torch.inference_mode
)
with context_manager():
outputs = self.forward(**batch.input_ids)
# List of indices to cache
next_batch_keep_indices = []
next_batch_past_keep_indices = []
# New input_ids for next forward
next_batch_input_ids = []
next_batch_all_input_ids = []
next_all_input_lengths = []
next_batch_size = 0
next_batch_max_sequence_length = 0
# Finished requests
generated_texts: List[GeneratedText] = []
# Zipped iterator
iterator = zip(
batch.requests,
batch.all_input_lengths,
outputs.logits,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
)
# For each member of the batch
for i, (
request,
input_length,
logits,
next_token_chooser,
stopping_criteria,
all_tokens,
) in enumerate(iterator):
# Select next token
next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
# Append next token to all tokens
all_tokens = torch.cat([all_tokens, next_token])
# Evaluate stopping criteria
if stopping_criteria(all_tokens):
# Decode all tokens
output = self.tokenizer.decode(
all_tokens.squeeze(-1), skip_special_tokens=True
)
# Add to the list of finished generations with the original request
generated_texts.append(GeneratedText(request, output))
# add to the next batch
else:
next_batch_keep_indices.append(i)
# past_key_values is of shape [batch_size * num_heads, ...]
# so we need to take into account the `num_heads` stride here
next_batch_past_keep_indices.extend(
[j for j in range(i * self.num_heads, (i + 1) * self.num_heads)]
)
next_batch_input_ids.append(next_token)
next_batch_all_input_ids.append(all_tokens)
next_batch_size += 1
new_input_length = input_length + 1
next_all_input_lengths.append(new_input_length)
next_batch_max_sequence_length = max(
next_batch_max_sequence_length, new_input_length
)
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:
return generated_texts, None
# If we finished at least one generation
next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
if generated_texts:
# Apply indices to attention mask, past key values and other items that need to be cached
next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
next_batch_keep_indices
]
next_batch_input_ids["past_key_values"] = [
(
keys[next_batch_past_keep_indices],
values[next_batch_past_keep_indices],
)
for keys, values in outputs["past_key_values"]
]
next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
next_batch_next_token_choosers = [
batch.next_token_choosers[i] for i in next_batch_keep_indices
]
next_batch_stopping_criterias = [
batch.stopping_criterias[i] for i in next_batch_keep_indices
]
else:
next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
next_batch_input_ids["past_key_values"] = outputs["past_key_values"]
next_batch_requests = batch.requests
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# Update attention_mask with padding as we added a new token to input_ids
next_batch_input_ids["attention_mask"] = torch.cat(
[
next_batch_input_ids["attention_mask"],
torch.ones((next_batch_size, 1)).to(self.device),
],
dim=1,
)
next_batch = Batch(
batch_id=batch.batch_id,
requests=next_batch_requests,
all_input_lengths=next_all_input_lengths,
input_ids=next_batch_input_ids,
all_input_ids=next_batch_all_input_ids,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,
max_sequence_length=next_batch_max_sequence_length,
)
return generated_texts, next_batch

View File

@ -0,0 +1,206 @@
import torch
from dataclasses import dataclass
from typing import List, Dict
from transformers import AutoTokenizer
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria
@dataclass
class Batch:
batch_id: int
requests: List[generate_pb2.Request]
all_input_lengths: List[int]
input_ids: Dict[str, torch.Tensor]
all_input_ids: List[torch.Tensor]
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
size: int
max_sequence_length: int
def to_pb(self):
return generate_pb2.Batch(
id=self.batch_id,
requests=self.requests,
size=self.size,
max_sequence_length=self.max_sequence_length,
)
@classmethod
def from_pb(
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
) -> "Batch":
inputs = []
next_token_choosers = []
stopping_criterias = []
all_input_lengths = []
# Parse batch
for r in pb.requests:
inputs.append(r.inputs)
all_input_lengths.append(r.input_length)
next_token_choosers.append(
NextTokenChooser(
temperature=r.parameters.temperature,
top_k=r.parameters.top_k,
top_p=r.parameters.top_p,
do_sample=r.parameters.do_sample,
)
)
stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))
input_ids = tokenizer(
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
).to(device)
all_input_ids = input_ids["input_ids"].unsqueeze(-1)
return cls(
batch_id=pb.id,
requests=pb.requests,
all_input_lengths=all_input_lengths,
input_ids=input_ids,
all_input_ids=all_input_ids,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=pb.size,
max_sequence_length=pb.max_sequence_length,
)
@classmethod
def concatenate(cls, batches: List["Batch"]) -> "Batch":
# Used for padding
total_batch_size = sum(batch.size for batch in batches)
max_sequence_length = max(batch.max_sequence_length for batch in batches)
# Batch attributes
input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
requests = []
all_input_lengths = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
all_input_lengths.extend(batch.all_input_lengths)
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
# Slicing end index for this batch
end_index = start_index + batch.size
# We only concatenate batches that did at least one step
if batch.input_ids["input_ids"].shape[1] > 1:
raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
# Initialize tensors
if i == 0:
input_ids["input_ids"] = torch.empty(
(total_batch_size, 1),
dtype=batch.input_ids["input_ids"].dtype,
device=batch.input_ids["input_ids"].device,
)
input_ids["attention_mask"] = torch.zeros(
(total_batch_size, max_sequence_length),
dtype=batch.input_ids["attention_mask"].dtype,
device=batch.input_ids["attention_mask"].device,
)
# input_ids["input_ids"] is always of shape [batch_size, 1]
# We do not need to pad it
input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]
# We need to slice the attention mask to remove padding from previous steps
input_ids["attention_mask"][
start_index:end_index, -batch.max_sequence_length :
] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :]
for j, past in enumerate(batch.input_ids["past_key_values"]):
past_keys = past[0]
past_values = past[1]
_, head_dim, padded_sequence_length = past_keys.shape
# Reshape the tensors to make slicing easier
past_keys = past_keys.view(
batch.size, -1, head_dim, padded_sequence_length
)
past_values = past_values.view(
batch.size, -1, padded_sequence_length, head_dim
)
num_heads = past_keys.shape[1]
# Initialize tensors
# This will run only once per layer
if j == len(input_ids["past_key_values"]):
padded_past_keys = torch.zeros(
(
total_batch_size,
num_heads,
head_dim,
max_sequence_length - 1,
),
dtype=past_keys.dtype,
device=past_keys.device,
)
padded_past_values = torch.zeros(
(
total_batch_size,
num_heads,
max_sequence_length - 1,
head_dim,
),
dtype=past_values.dtype,
device=past_values.device,
)
input_ids["past_key_values"].append(
[padded_past_keys, padded_past_values]
)
# We slice the past keys and values to remove the padding from previous batches
input_ids["past_key_values"][j][0][
start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
input_ids["past_key_values"][j][1][
start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
# If we are on the last batch, we need to reshape the tensors
if (i + 1) == len(batches):
input_ids["past_key_values"][j][0] = input_ids["past_key_values"][
j
][0].view(total_batch_size * num_heads, head_dim, -1)
input_ids["past_key_values"][j][1] = input_ids["past_key_values"][
j
][1].view(total_batch_size * num_heads, -1, head_dim)
start_index += batch.size
return cls(
batch_id=batches[0].batch_id,
requests=requests,
all_input_lengths=all_input_lengths,
input_ids=input_ids,
all_input_ids=all_input_ids,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=total_batch_size,
max_sequence_length=max_sequence_length,
)
@dataclass
class GeneratedText:
request: generate_pb2.Request
output: str
def to_pb(self) -> generate_pb2.GeneratedText:
return generate_pb2.GeneratedText(request=self.request, output=self.output)

View File

@ -5,15 +5,16 @@ from grpc import aio
from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import Optional, List
from typing import List
from bloom_inference.cache import Cache
from bloom_inference.model import BLOOM, Batch, BLOOMSharded
from bloom_inference.pb import generate_pb2_grpc, generate_pb2
from text_generation.cache import Cache
from text_generation.models import Model, get_model
from text_generation.models.types import Batch
from text_generation.pb import generate_pb2_grpc, generate_pb2
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def __init__(self, model: BLOOM, cache: Cache, server_urls: List[str]):
def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
self.cache = cache
self.model = model
self.server_urls = server_urls
@ -78,21 +79,17 @@ def serve(
):
unix_socket_template = "unix://{}-{}"
if sharded:
model = BLOOMSharded(model_name, quantize)
server_urls = [
unix_socket_template.format(uds_path, rank)
for rank in range(model.world_size)
for rank in range(int(os.environ["WORLD_SIZE"]))
]
local_url = server_urls[model.rank]
local_url = server_urls[int(os.environ["RANK"])]
else:
if quantize:
raise ValueError(
"bitsandbytes quantization is only available when running in `sharded` mode."
)
model = BLOOM(model_name)
local_url = unix_socket_template.format(uds_path, 0)
server_urls = [local_url]
model = get_model(model_name, sharded, quantize)
server = aio.server()
generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
TextGenerationService(model, Cache(), server_urls), server

View File

@ -92,19 +92,17 @@ def initialize_torch_distributed():
return torch.distributed.distributed_c10d._get_default_group(), rank, world_size
def weight_hub_files(model_name):
def weight_hub_files(model_name, extension=".safetensors"):
"""Get the safetensors filenames on the hub"""
api = HfApi()
info = api.model_info(model_name)
filenames = [
s.rfilename for s in info.siblings if s.rfilename.endswith(".safetensors")
]
filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
return filenames
def weight_files(model_name):
def weight_files(model_name, extension=".safetensors"):
"""Get the local safetensors filenames"""
filenames = weight_hub_files(model_name)
filenames = weight_hub_files(model_name, extension)
files = []
for filename in filenames:
cache_file = try_to_load_from_cache(model_name, filename=filename)
@ -112,16 +110,16 @@ def weight_files(model_name):
raise LocalEntryNotFoundError(
f"File {filename} of model {model_name} not found in "
f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
f"Please run `bloom-inference-server download-weights {model_name}` first."
f"Please run `text-generation-server download-weights {model_name}` first."
)
files.append(cache_file)
return files
def download_weights(model_name):
def download_weights(model_name, extension=".safetensors"):
"""Download the safetensors files from the hub"""
filenames = weight_hub_files(model_name)
filenames = weight_hub_files(model_name, extension)
download_function = partial(
hf_hub_download, repo_id=model_name, local_files_only=False