feat(server): add paged attention to flash models (#516)

Closes #478
This commit is contained in:
OlivierDehaene 2023-06-30 19:09:59 +02:00 committed by GitHub
parent 70f485bf9f
commit e74bd41e0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 1045 additions and 888 deletions

View File

@ -88,7 +88,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
RUN /opt/conda/bin/conda install -c "nvidia/label/cuda-11.8.0" cuda==11.8 && \
/opt/conda/bin/conda clean -ya
# Build Flash Attention CUDA kernels
FROM kernel-builder as flash-att-builder
@ -109,6 +108,16 @@ COPY server/custom_kernels/ .
# Build specific version of transformers
RUN python setup.py build
# Build vllm CUDA kernels
FROM kernel-builder as vllm-builder
WORKDIR /usr/src
COPY server/Makefile-vllm Makefile
# Build specific version of vllm
RUN make build-vllm
# Text Generation Inference base image
FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base
@ -137,9 +146,12 @@ COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cp
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Copy build artifacts from transformers builder
# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels
# Copy build artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Install flash-attention dependencies
RUN pip install einops --no-cache-dir

View File

@ -43,8 +43,8 @@ to power LLMs api-inference widgets.
- Tensor Parallelism for faster inference on multiple GPUs
- Token streaming using Server-Sent Events (SSE)
- [Continuous batching of incoming requests](https://github.com/huggingface/text-generation-inference/tree/main/router) for increased total throughput
- Optimized transformers code for inference using [flash-attention](https://github.com/HazyResearch/flash-attention) on the most popular architectures
- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
- Optimized transformers code for inference using [flash-attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures
- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323)
- [Safetensors](https://github.com/huggingface/safetensors) weight loading
- Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
- Logits warper (temperature scaling, top-p, top-k, repetition penalty; for more details, see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor))

View File

@ -13,6 +13,7 @@ async def flash_neox(flash_neox_handle):
return flash_neox_handle.client
@pytest.mark.skip
@pytest.mark.asyncio
async def test_flash_neox(flash_neox, response_snapshot):
response = await flash_neox.generate(
@ -25,6 +26,7 @@ async def test_flash_neox(flash_neox, response_snapshot):
assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio
async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):
responses = await generate_load(

View File

@ -115,12 +115,6 @@ struct Args {
#[clap(default_value = "1512", long, env)]
max_total_tokens: usize,
/// The maximum allowed batch size during dynamic batching.
/// Using `max_batch_total_tokens` should be favored in general
/// as it's a finer way to control RAM usage.
#[clap(long, env)]
max_batch_size: Option<usize>,
/// This represents the ratio of waiting queries vs running queries where
/// you want to start considering pausing the running queries to include the waiting
/// ones into the same batch.
@ -134,6 +128,12 @@ struct Args {
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
/// Limits the number of tokens for the prefill operation.
/// Since this operation takes the most memory and is compute bound, it is interesting
/// to limit the number of requests that can be sent.
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
/// **IMPORTANT** This is one critical control to allow maximum usage
/// of the available hardware.
///
@ -146,19 +146,12 @@ struct Args {
/// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100`
/// or a single query of `1000` tokens.
///
/// So you don't have to control that finely
/// `max_batch_size` or `max_total_tokens`. In fact you could mostly relax them if you
/// want maximum flexibility. However, for your users if they are asking for the full amount of
/// total tokens, they are likely to wait for a very long time to get a spot
/// in the batch (since they are going to be alone) so setting `max_batch_size`
/// and `max_total_tokens` can still be useful to prevent those long waiting times.
///
/// Overall this number should be the largest possible amount that fits the
/// remaining memory (after the model is loaded). Since the actual memory overhead
/// depends on other parameters like if you're using quantization, flash attention
/// or the model implementation, text-generation-inference cannot infer this number
/// automatically.
#[clap(default_value = "32000", long, env)]
#[clap(default_value = "16000", long, env)]
max_batch_total_tokens: u32,
/// This setting defines how many tokens can be passed before forcing the waiting
@ -180,9 +173,9 @@ struct Args {
/// for end users.
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(default_value = "3000", long, short, env)]
/// The port to listen on.
#[clap(default_value = "3000", long, short, env)]
port: u16,
/// The name of the socket for gRPC communication between the webserver
@ -329,6 +322,12 @@ fn shard_manager(
// Copy current process env
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
// Use cuda allocator. It leads to less memory fragmentation
env.push((
"PYTORCH_CUDA_ALLOC_CONF".into(),
"backend:cudaMallocAsync".into(),
));
// Torch Distributed Env vars
env.push(("RANK".into(), rank.to_string().into()));
env.push(("WORLD_SIZE".into(), world_size.to_string().into()));
@ -446,7 +445,7 @@ fn shard_manager(
// We received a shutdown signal
if *shutdown.lock().unwrap() {
p.terminate().unwrap();
p.kill().unwrap();
let _ = p.wait_timeout(Duration::from_secs(90));
tracing::info!("Shard {rank} terminated");
return;
@ -822,6 +821,10 @@ fn spawn_webserver(
args.max_input_length.to_string(),
"--max-total-tokens".to_string(),
args.max_total_tokens.to_string(),
"--max-batch-prefill-tokens".to_string(),
args.max_batch_prefill_tokens.to_string(),
"--max-batch-total-tokens".to_string(),
args.max_batch_total_tokens.to_string(),
"--waiting-served-ratio".to_string(),
args.waiting_served_ratio.to_string(),
"--max-waiting-tokens".to_string(),
@ -834,15 +837,6 @@ fn spawn_webserver(
args.model_id,
];
// Deprecate max_batch_size
if let Some(max_batch_size) = args.max_batch_size {
argv.push("--max-batch-size".to_string());
argv.push(max_batch_size.to_string())
} else {
argv.push("--max-batch-total-tokens".to_string());
argv.push(args.max_batch_total_tokens.to_string())
}
// Model optional revision
if let Some(ref revision) = args.revision {
argv.push("--revision".to_string());

View File

@ -11,6 +11,8 @@ service TextGenerationService {
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
/// Remove requests from a cached batch
rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
/// Warmup the model and compute max cache size
rpc Warmup (WarmupRequest) returns (WarmupResponse);
/// Prefill batch and decode first token
rpc Prefill (PrefillRequest) returns (PrefillResponse);
/// Decode token for a list of prefilled batches
@ -192,3 +194,13 @@ message DecodeResponse {
/// Next batch (cached)
optional CachedBatch batch = 2;
}
message WarmupRequest {
/// Batch to warmup on
Batch batch = 1;
/// Maximum number of tokens that the client will send
uint32 max_total_tokens = 2;
}
/// Empty response
message WarmupResponse {}

View File

@ -3,6 +3,7 @@ use crate::pb::generate::v1::text_generation_service_client::TextGenerationServi
use crate::pb::generate::v1::*;
use crate::Result;
use grpc_metadata::InjectTelemetryContext;
use std::cmp::min;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
@ -94,6 +95,63 @@ impl Client {
Ok(filtered_batch.batch)
}
/// Warmup on a max size batch
///
/// Builds synthetic requests until the running token count reaches
/// `max_prefill_tokens`, wraps them in a single `Batch` and sends that batch to the
/// shard in a `WarmupRequest` together with `max_total_tokens`.
///
/// Returns `Ok(())` when the call succeeds; the response body is discarded.
/// An error from the gRPC call is propagated to the caller.
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
) -> Result<()> {
let mut n_tokens = 0;
let mut requests = Vec::new();
// Create requests
// NOTE(review): `n_tokens` only advances by `max_input_length`, so this loop does not
// terminate when `max_input_length == 0` and `max_prefill_tokens > 0` — confirm that
// callers validate the value before calling.
while n_tokens < max_prefill_tokens {
requests.push(Request {
// Every warmup request uses the same id
id: 0,
// We truncate the input on the server side to be sure that it has the correct size
inputs: "_test ".to_string().repeat(max_input_length as usize),
// The last request is capped to whatever is left of the prefill budget
truncate: min(max_input_length, max_prefill_tokens - n_tokens),
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
top_k: 10,
top_p: 0.9,
typical_p: 0.9,
do_sample: false,
seed: 0,
repetition_penalty: 1.2,
watermark: true,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 2,
stop_sequences: vec![],
ignore_eos_token: false,
}),
prefill_logprobs: true,
});
n_tokens += max_input_length;
}
// Single batch holding every generated request
let batch = Batch {
id: 0,
size: requests.len() as u32,
requests,
max_tokens: 0,
};
let request = tonic::Request::new(WarmupRequest {
batch: Some(batch),
max_total_tokens,
})
.inject_context();
// Only success or failure matters here: the inner response is dropped
self.stub.warmup(request).await?.into_inner();
Ok(())
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch

View File

@ -87,6 +87,27 @@ impl ShardedClient {
join_all(futures).await.pop().unwrap()
}
/// Warmup on a max size batch
///
/// Calls `warmup` on every client with the same arguments and waits for all of
/// the calls to finish.
///
/// Returns the result of the last client only; results of the other clients
/// are dropped. Panics if there are no clients, because `pop()` then yields `None`.
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
) -> Result<()> {
// One boxed future per shard client
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| {
Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
})
.collect();
// all shards return the same message
join_all(futures).await.pop().unwrap()
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch

View File

@ -45,6 +45,7 @@ impl Infer {
client: ShardedClient,
validation: Validation,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_concurrent_requests: usize,
@ -61,6 +62,7 @@ impl Infer {
tokio::spawn(batching_task(
client,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
queue.clone(),
@ -240,9 +242,11 @@ impl Infer {
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[allow(clippy::too_many_arguments)]
async fn batching_task(
mut client: ShardedClient,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
queue: Queue,
@ -257,8 +261,9 @@ async fn batching_task(
// Get the next batch from the queue
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue
while let Some((mut entries, batch, span)) =
queue.next_batch(None, max_batch_total_tokens).await
while let Some((mut entries, batch, span)) = queue
.next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens)
.await
{
let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
.instrument(span)
@ -284,11 +289,12 @@ async fn batching_task(
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
};
let token_budget = max_batch_total_tokens - batch_max_tokens;
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) =
queue.next_batch(min_size, token_budget).await
if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_batch_prefill_tokens, token_budget)
.await
{
// Tracking metrics
if min_size.is_some() {

View File

@ -32,10 +32,10 @@ struct Args {
max_input_length: usize,
#[clap(default_value = "1512", long, env)]
max_total_tokens: usize,
#[clap(long, env)]
max_batch_size: Option<usize>,
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
#[clap(default_value = "32000", long, env)]
max_batch_total_tokens: u32,
#[clap(default_value = "20", long, env)]
@ -78,9 +78,9 @@ fn main() -> Result<(), std::io::Error> {
max_stop_sequences,
max_input_length,
max_total_tokens,
max_batch_size,
waiting_served_ratio,
mut max_batch_total_tokens,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
port,
master_shard_uds_path,
@ -141,12 +141,6 @@ fn main() -> Result<(), std::io::Error> {
.block_on(async {
init_logging(otlp_endpoint, json_output);
if let Some(max_batch_size) = max_batch_size {
tracing::warn!("`max-batch-size` is deprecated. Use `max-batch-total-tokens` instead");
max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32;
tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}");
}
if tokenizer.is_none() {
tracing::warn!(
"Could not find a fast tokenizer implementation for {tokenizer_name}"
@ -161,9 +155,15 @@ fn main() -> Result<(), std::io::Error> {
sha: None,
pipeline_tag: None,
},
false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or_else(|| {
false => get_model_info(&tokenizer_name, &revision, authorization_token)
.await
.unwrap_or_else(|| {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None }
HubModelInfo {
model_id: tokenizer_name.to_string(),
sha: None,
pipeline_tag: None,
}
}),
};
@ -190,6 +190,17 @@ fn main() -> Result<(), std::io::Error> {
.info()
.await
.expect("Unable to get shard info");
// Warmup model
tracing::info!("Warming up model");
sharded_client
.warmup(
max_input_length as u32,
max_batch_prefill_tokens,
max_batch_total_tokens,
)
.await
.expect("Unable to warmup model");
tracing::info!("Connected");
// Binds on localhost
@ -206,6 +217,7 @@ fn main() -> Result<(), std::io::Error> {
max_input_length,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
sharded_client,

View File

@ -58,6 +58,7 @@ impl Queue {
pub(crate) async fn next_batch(
&self,
min_size: Option<usize>,
prefill_token_budget: u32,
token_budget: u32,
) -> Option<NextBatch> {
// Create response channel
@ -67,6 +68,7 @@ impl Queue {
self.queue_sender
.send(QueueCommand::NextBatch {
min_size,
prefill_token_budget,
token_budget,
response_sender,
span: Span::current(),
@ -90,11 +92,12 @@ async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueComma
}
QueueCommand::NextBatch {
min_size,
prefill_token_budget,
token_budget,
response_sender,
span,
} => span.in_scope(|| {
let next_batch = state.next_batch(min_size, token_budget);
let next_batch = state.next_batch(min_size, prefill_token_budget, token_budget);
response_sender.send(next_batch).unwrap();
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
}),
@ -140,7 +143,12 @@ impl State {
}
// Get the next batch
fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> {
fn next_batch(
&mut self,
min_size: Option<usize>,
prefill_token_budget: u32,
token_budget: u32,
) -> Option<NextBatch> {
if self.entries.is_empty() {
return None;
}
@ -184,7 +192,9 @@ impl State {
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
if (prefill_tokens + decode_tokens) > token_budget {
if prefill_tokens > prefill_token_budget
|| (prefill_tokens + decode_tokens) > token_budget
{
// Entry is over budget
// Add it back to the front
self.entries.push_front((id, entry));
@ -259,6 +269,7 @@ enum QueueCommand {
Append(Box<Entry>, Span),
NextBatch {
min_size: Option<usize>,
prefill_token_budget: u32,
token_budget: u32,
response_sender: oneshot::Sender<Option<NextBatch>>,
span: Span,
@ -328,8 +339,8 @@ mod tests {
fn test_next_batch_empty() {
let mut state = State::new(false);
assert!(state.next_batch(None, 1).is_none());
assert!(state.next_batch(Some(1), 1).is_none());
assert!(state.next_batch(None, 1, 1).is_none());
assert!(state.next_batch(Some(1), 1, 1).is_none());
}
#[test]
@ -340,7 +351,7 @@ mod tests {
state.append(entry1);
state.append(entry2);
let (entries, batch, _) = state.next_batch(None, 2).unwrap();
let (entries, batch, _) = state.next_batch(None, 2, 2).unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1));
@ -356,7 +367,7 @@ mod tests {
let (entry3, _guard3) = default_entry();
state.append(entry3);
assert!(state.next_batch(Some(2), 2).is_none());
assert!(state.next_batch(Some(2), 2, 2).is_none());
assert_eq!(state.next_id, 3);
assert_eq!(state.entries.len(), 1);
@ -372,7 +383,7 @@ mod tests {
state.append(entry1);
state.append(entry2);
let (entries, batch, _) = state.next_batch(None, 1).unwrap();
let (entries, batch, _) = state.next_batch(None, 1, 1).unwrap();
assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0);
@ -385,7 +396,7 @@ mod tests {
let (entry3, _guard3) = default_entry();
state.append(entry3);
let (entries, batch, _) = state.next_batch(None, 3).unwrap();
let (entries, batch, _) = state.next_batch(None, 3, 3).unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2));
@ -408,8 +419,8 @@ mod tests {
async fn test_queue_next_batch_empty() {
let queue = Queue::new(false);
assert!(queue.next_batch(None, 1).await.is_none());
assert!(queue.next_batch(Some(1), 1).await.is_none());
assert!(queue.next_batch(None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
}
#[tokio::test]
@ -420,7 +431,7 @@ mod tests {
queue.append(entry1);
queue.append(entry2);
let (entries, batch, _) = queue.next_batch(None, 2).await.unwrap();
let (entries, batch, _) = queue.next_batch(None, 2, 2).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1));
@ -433,11 +444,11 @@ mod tests {
queue.append(entry3);
// Not enough requests pending
assert!(queue.next_batch(Some(2), 2).await.is_none());
assert!(queue.next_batch(Some(2), 2, 2).await.is_none());
// Not enough token budget
assert!(queue.next_batch(Some(1), 0).await.is_none());
assert!(queue.next_batch(Some(1), 0, 0).await.is_none());
// Ok
let (entries2, batch2, _) = queue.next_batch(Some(1), 2).await.unwrap();
let (entries2, batch2, _) = queue.next_batch(Some(1), 2, 2).await.unwrap();
assert_eq!(entries2.len(), 1);
assert!(entries2.contains_key(&2));
assert!(entries2.get(&2).unwrap().batch_time.is_some());
@ -453,7 +464,7 @@ mod tests {
queue.append(entry1);
queue.append(entry2);
let (entries, batch, _) = queue.next_batch(None, 1).await.unwrap();
let (entries, batch, _) = queue.next_batch(None, 1, 1).await.unwrap();
assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0);
@ -462,7 +473,7 @@ mod tests {
let (entry3, _guard3) = default_entry();
queue.append(entry3);
let (entries, batch, _) = queue.next_batch(None, 3).await.unwrap();
let (entries, batch, _) = queue.next_batch(None, 3, 3).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2));
@ -476,6 +487,6 @@ mod tests {
let (entry, _) = default_entry();
queue.append(entry);
assert!(queue.next_batch(None, 1).await.is_none());
assert!(queue.next_batch(None, 1, 1).await.is_none());
}
}

View File

@ -514,6 +514,7 @@ pub async fn run(
max_input_length: usize,
max_total_tokens: usize,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
client: ShardedClient,
@ -582,6 +583,7 @@ pub async fn run(
client,
validation,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_concurrent_requests,

13
server/Makefile-vllm Normal file
View File

@ -0,0 +1,13 @@
# Commit of the vllm fork that build-vllm checks out
vllm_commit := d284b831c17f42a8ea63369a06138325f73c4cf9
# Fetch the vllm sources
vllm:
# Clone vllm
git clone https://github.com/OlivierDehaene/vllm.git
# Check out the pinned commit and build vllm in place
build-vllm: vllm
cd vllm && git fetch && git checkout $(vllm_commit)
cd vllm && python setup.py build
# Replace any installed vllm with the freshly built one
install-vllm: build-vllm
# `|| true` keeps going when vllm is not currently installed
pip uninstall vllm -y || true
cd vllm && python setup.py install

View File

@ -22,7 +22,9 @@ class Cache:
del batch
def clear(self):
self.cache.clear()
keys = list(self.cache.keys())
for k in keys:
self.delete(k)
def __len__(self):
return len(self.cache.keys())

View File

@ -122,7 +122,7 @@ class CausalLMBatch(Batch):
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
max_tokens = len(inputs) * max_input_length + max_decode_tokens
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
return cls(
batch_id=pb.id,

View File

@ -23,12 +23,16 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional
from typing import Optional, List, Tuple
# Flash attention imports
import flash_attn_cuda
import dropout_layer_norm
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -106,7 +110,7 @@ class FlashLlamaAttention(torch.nn.Module):
prefix=f"{prefix}.rotary_emb", weights=weights
)
self.softmax_scale = self.head_size ** (-0.5)
self.softmax_scale = self.head_size**-0.5
self.num_heads = self.num_heads // weights.process_group.size()
self.query_key_value = TensorParallelColumnLinear.load_multi(
@ -122,20 +126,22 @@ class FlashLlamaAttention(torch.nn.Module):
weights=weights,
bias=False,
)
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)
def forward(
self,
hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@ -144,23 +150,25 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin)
# Prefill
if prefill:
# Copy to layer past
layer_past[...] = qkv[:, 1:]
vllm_cache_ops.reshape_and_cache(
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
)
# output
# output tensor
attn_output = torch.empty_like(qkv[:, 0])
# Prefill
if start_seq_prefill is not None:
# flash attention
flash_attn_cuda.fwd(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
attn_output,
start_seq,
end_seq,
start_seq,
end_seq,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
max_s,
max_s,
0.0,
@ -173,31 +181,19 @@ class FlashLlamaAttention(torch.nn.Module):
)
# Decode
else:
query = qkv[:, 0]
# Add present to the layer_past tensor at the correct indices
layer_past[past_present_indices] = qkv[:, 1:]
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
layer_past[:, 0],
layer_past[:, 1],
# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
attn_output,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
qkv[:, 0],
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
False,
False,
False,
0,
None,
block_tables,
input_lengths,
block_size,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -265,14 +261,13 @@ class FlashLlamaLayer(nn.Module):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -281,14 +276,13 @@ class FlashLlamaLayer(nn.Module):
normed_hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)
# faster post attention rms norm
@ -333,40 +327,18 @@ class FlashLlamaModel(torch.nn.Module):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
len(input_ids),
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# Decode
else:
prefill = False
# Get rotary cos and sin for this forward
# Avoid indexing in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@ -380,34 +352,18 @@ class FlashLlamaModel(torch.nn.Module):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
past_key_values[:, i],
past_present_indices,
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states, past_key_values
return hidden_states
class FlashLlamaForCausalLM(torch.nn.Module):
@ -423,31 +379,29 @@ class FlashLlamaForCausalLM(torch.nn.Module):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.model(
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
return logits, present
return logits

View File

@ -25,11 +25,15 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig
from typing import Optional
from typing import Optional, List, Tuple
# Flash attention imports
import flash_attn_cuda
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -110,20 +114,22 @@ class FlashNeoxAttention(torch.nn.Module):
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=True
)
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)
def forward(
self,
hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@ -132,23 +138,25 @@ class FlashNeoxAttention(torch.nn.Module):
self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin)
# Prefill
if prefill:
# Copy to layer past
layer_past[...] = qkv[:, 1:]
vllm_cache_ops.reshape_and_cache(
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
)
# output
# output tensor
attn_output = torch.empty_like(qkv[:, 0])
# Prefill
if start_seq_prefill is not None:
# flash attention
flash_attn_cuda.fwd(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
attn_output,
start_seq,
end_seq,
start_seq,
end_seq,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
max_s,
max_s,
0.0,
@ -161,31 +169,19 @@ class FlashNeoxAttention(torch.nn.Module):
)
# Decode
else:
query = qkv[:, 0]
# Add present to the layer_past tensor at the correct indices
layer_past[past_present_indices] = qkv[:, 1:]
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
layer_past[:, 0],
layer_past[:, 1],
# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
attn_output,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
qkv[:, 0],
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
False,
False,
False,
0,
None,
block_tables,
input_lengths,
block_size,
max_s,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -250,14 +246,13 @@ class FlashNeoXLayer(nn.Module):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
if self.use_parallel_residual:
ln1_hidden_states, _ = self.input_layernorm(hidden_states)
@ -266,14 +261,13 @@ class FlashNeoXLayer(nn.Module):
ln1_hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)
ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
@ -292,14 +286,13 @@ class FlashNeoXLayer(nn.Module):
hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)
hidden_states, residual = self.post_attention_layernorm(
@ -346,40 +339,18 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_in(input_ids)
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
len(input_ids),
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# Decode
else:
prefill = False
# Get rotary cos and sin for this forward
# Avoid indexing in each layer
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(
@ -393,34 +364,18 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
past_key_values[:, i],
past_present_indices,
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.final_layer_norm(hidden_states, residual)
return hidden_states, past_key_values
return hidden_states
class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
@ -434,31 +389,29 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.gpt_neox(
) -> torch.Tensor:
hidden_states = self.gpt_neox(
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.embed_out(hidden_states)
return logits, present
return logits

View File

@ -4,11 +4,15 @@ import torch.distributed
from torch import nn
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from typing import Optional
from typing import Optional, List, Tuple
# Flash attention imports
import flash_attn_cuda
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -126,19 +130,27 @@ class FlashRWAttention(torch.nn.Module):
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)
if self.num_heads_kv == 1:
self.kv_head_mapping = torch.zeros(
self.num_heads, dtype=torch.int32, device=weights.device
)
else:
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)
def forward(
self,
hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)
@ -156,25 +168,29 @@ class FlashRWAttention(torch.nn.Module):
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
# Prefill
if prefill:
# Copy to layer past
layer_past[...] = kv
# Expand to query shape
kv = kv.expand(-1, 2, self.num_heads, self.head_size)
vllm_cache_ops.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
# output
attn_output = torch.empty_like(query)
# Prefill
if start_seq_prefill is not None:
if self.num_heads_kv == 1:
# Expand to query shape
kv = kv.expand(-1, 2, self.num_heads, self.head_size)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
start_seq,
end_seq,
start_seq,
end_seq,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
max_s,
max_s,
0.0,
@ -187,32 +203,19 @@ class FlashRWAttention(torch.nn.Module):
)
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[past_present_indices] = kv
# Expand to query shape
kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
# kv_cache[1] => [num_blocks, num_heads_kv, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
attn_output,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
False,
False,
False,
0,
None,
block_tables,
input_lengths,
block_size,
max_s,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -264,19 +267,22 @@ class FlashRWLargeAttention(torch.nn.Module):
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)
self.kv_head_mapping = torch.arange(
0, self.num_groups, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_heads)
def forward(
self,
hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
@ -293,10 +299,19 @@ class FlashRWLargeAttention(torch.nn.Module):
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
vllm_cache_ops.reshape_and_cache(
kv[:, :, 0].contiguous(),
kv[:, :, 1].contiguous(),
kv_cache[0],
kv_cache[1],
slots,
)
# output
attn_output = torch.empty_like(query)
# Prefill
if prefill:
# Copy to layer past
layer_past[...] = kv
if start_seq_prefill is not None:
# Expand to query shape
kv = (
kv.unsqueeze(2)
@ -304,18 +319,16 @@ class FlashRWLargeAttention(torch.nn.Module):
.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
attn_output,
start_seq,
end_seq,
start_seq,
end_seq,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
max_s,
max_s,
0.0,
@ -328,36 +341,19 @@ class FlashRWLargeAttention(torch.nn.Module):
)
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[past_present_indices] = kv
# Expand to query shape
kv = (
layer_past.unsqueeze(2)
.expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
# kv_cache[1] => [num_blocks, num_groups, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
attn_output,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
False,
False,
False,
0,
None,
block_tables,
input_lengths,
block_size,
max_s,
)
return self.dense(
@ -432,14 +428,13 @@ class FlashRWLayer(nn.Module):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
if self.parallel_attn:
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@ -448,14 +443,13 @@ class FlashRWLayer(nn.Module):
ln_hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)
mlp_output = self.mlp(ln_hidden_states)
@ -472,14 +466,13 @@ class FlashRWLayer(nn.Module):
hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)
hidden_states, residual = self.post_attention_layernorm(
@ -523,14 +516,13 @@ class FlashRWLargeLayer(nn.Module):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
ln_attn, residual = self.ln_attn(hidden_states, residual)
ln_mlp, _ = self.ln_mlp(residual)
@ -540,14 +532,13 @@ class FlashRWLargeLayer(nn.Module):
ln_attn,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)
# MLP.
@ -580,11 +571,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
for layer_id in range(config.num_hidden_layers)
]
)
self.cache_size = (
2,
self.h[0].self_attention.num_heads_kv,
self.h[0].self_attention.head_size,
)
self.cache_size = self.h[0].self_attention.num_heads_kv
elif config.model_type == "RefinedWeb":
self.h = nn.ModuleList(
[
@ -592,11 +579,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
for layer_id in range(config.num_hidden_layers)
]
)
self.cache_size = (
self.h[0].self_attention.num_groups,
2,
self.h[0].self_attention.head_size,
)
self.cache_size = self.h[0].self_attention.num_groups
else:
raise NotImplementedError(
f"model_type {config.model_type} is not supported."
@ -612,38 +595,18 @@ class FlashRWModel(FlashRWPreTrainedModel):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
len(input_ids),
len(self.h),
*self.cache_size,
)
)
# Decode
else:
prefill = False
# Get rotary cos and sin for this forward
# Avoid indexing in each layer
cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(
@ -657,32 +620,18 @@ class FlashRWModel(FlashRWPreTrainedModel):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
torch.select(past_key_values, dim=1, index=i),
past_present_indices,
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.h),
*self.cache_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states, past_key_values
return hidden_states
class FlashRWForCausalLM(FlashRWPreTrainedModel):
@ -697,31 +646,29 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.transformer(
) -> torch.Tensor:
hidden_states = self.transformer(
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
return logits, present
return logits

View File

@ -3,11 +3,15 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional
from typing import Optional, List, Tuple
# Flash attention imports
import flash_attn_cuda
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -221,18 +225,20 @@ class FlashMQAttention(torch.nn.Module):
self.c_proj = load_row(
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
)
self.kv_head_mapping = torch.zeros(
self.num_heads, dtype=torch.int32, device=weights.device
)
def forward(
self,
hidden_states,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
qkv = self.c_attn(hidden_states)
@ -245,25 +251,28 @@ class FlashMQAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size)
# Prefill
if prefill:
# Copy to layer past
layer_past[...] = key_value
# Expand from 1 to num_heads
key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
vllm_cache_ops.reshape_and_cache(
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
)
# output
attn_output = torch.empty_like(query)
# Prefill
if start_seq_prefill is not None:
# Expand from 1 to num_heads
key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
attn_output,
start_seq,
end_seq,
start_seq,
end_seq,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
max_s,
max_s,
0.0,
@ -276,32 +285,19 @@ class FlashMQAttention(torch.nn.Module):
)
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[past_present_indices] = key_value
# Expand from 1 to num_heads
key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
# kv_cache[1] => [num_blocks, 1, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
attn_output,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
False,
False,
False,
0,
None,
block_tables,
input_lengths,
block_size,
max_s,
)
return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -361,27 +357,25 @@ class Block(nn.Module):
self,
hidden_states,
residual,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
hidden_states, residual = self.ln_1(hidden_states, residual)
hidden_states = self.attn(
hidden_states,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)
hidden_states, residual = self.ln_2(hidden_states, residual)
@ -427,64 +421,38 @@ class FlashSantacoderModel(nn.Module):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
if self.process_group.size() > 1:
torch.distributed.all_reduce(hidden_states, group=self.process_group)
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_zeros(
(len(input_ids), len(self.h), 2, 1, self.head_size)
)
# Decode
else:
prefill = False
residual = None
for i, layer in enumerate(self.h):
hidden_states, residual = layer(
hidden_states,
residual,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
torch.select(past_key_values, dim=1, index=i),
past_present_indices,
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states, past_key_values
return hidden_states
class FlashSantacoderForCausalLM(nn.Module):
@ -497,31 +465,29 @@ class FlashSantacoderForCausalLM(nn.Module):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.transformer(
) -> torch.Tensor:
hidden_states = self.transformer(
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
return logits, present
return logits

View File

@ -1004,7 +1004,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
try:
self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
except RuntimeError:
self.shared = TensorParallelEmbedding(prefix="encoder.embed_tokens", weights=weights)
self.shared = TensorParallelEmbedding(
prefix="encoder.embed_tokens", weights=weights
)
encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False

View File

@ -1,11 +1,14 @@
import math
import itertools
import torch
import torch.distributed
import numpy as np
from dataclasses import dataclass
from loguru import logger
from opentelemetry import trace
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict
from text_generation_server.models import Model
@ -20,6 +23,92 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke
tracer = trace.get_tracer(__name__)
BLOCK_SIZE = 16
# Will be set in warmup
CACHE_MANAGER: Optional["CacheManager"] = None
class CacheManager:
    """Owns the paged-attention KV cache and hands out cache blocks to batches.

    The cache is split into `num_blocks` blocks of `block_size` token slots.
    A CPU-side mask tracks which blocks are free (1) or in use (0).
    """

    def __init__(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """Pre-allocate the per-layer cache tensors and the block bookkeeping.

        Args:
            num_blocks: total number of cache blocks to allocate.
            num_layers: number of (key, value) cache pairs to create.
            num_heads: number of heads stored in each block.
            head_size: size of each head.
            dtype: dtype of the cache tensors.
            device: device on which the cache tensors are allocated.
        """
        self.block_size = BLOCK_SIZE
        element_size = torch.tensor([], dtype=dtype).element_size()
        # Innermost packing factor of the first (key) cache tensor.
        # NOTE(review): this is derived from the block size; presumably the
        # attention kernels expect a fixed 16-byte packing, which only
        # coincides with this value while BLOCK_SIZE == 16 -- confirm before
        # changing BLOCK_SIZE.
        x = self.block_size // element_size

        # One (first, second) tensor pair per layer:
        #   first:  [num_blocks, num_heads, head_size // x, block_size, x]
        #   second: [num_blocks, num_heads, head_size, block_size]
        self.kv_cache = [
            (
                torch.empty(
                    (num_blocks, num_heads, head_size // x, self.block_size, x),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, num_heads, head_size, self.block_size),
                    dtype=dtype,
                    device=device,
                ),
            )
            for _ in range(num_layers)
        ]
        # 1 => block is free, 0 => block is allocated. Kept on CPU.
        self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
        # slots[b] holds the `block_size` global slot ids that belong to block b
        self.slots = torch.arange(
            0, num_blocks * self.block_size, dtype=torch.int32
        ).view(num_blocks, self.block_size)

    def allocate(self, batch: "FlashCausalLMBatch"):
        """Reserve cache blocks for every sequence of `batch`.

        Reads `batch.blocks`, `batch.max_blocks` and
        `batch.needed_blocks_slots`, then sets `batch.block_tables`,
        `batch.block_tables_tensor` and `batch.slots` (the tensors are moved
        to the device of `batch.input_ids`) and clears
        `batch.needed_blocks_slots`.

        Raises:
            AssertionError: if fewer free blocks remain than `batch.blocks`.
        """
        # Get free blocks indices by finding values in mask that are not set to 0
        free_block_indices = self.free_block_mask.nonzero()
        if len(free_block_indices) < batch.blocks:
            # Raised explicitly instead of using `assert`: asserts are stripped
            # under `python -O`, and allocating with too few blocks would make
            # sequences share cache blocks. The exception type is kept so
            # existing callers observe the same error.
            raise AssertionError(
                f"Out of available cache blocks: asked {batch.blocks}, only {len(free_block_indices)} free blocks"
            )

        # Slice by the number of required blocks
        block_indices = free_block_indices[: batch.blocks].flatten()

        # Padded block tables: rows are right-padded with 0 up to max_blocks
        block_tables_tensor = torch.zeros(
            (len(batch), batch.max_blocks), dtype=torch.int32
        )

        # Allocate paged attention blocks
        cumulative_blocks = 0
        slots = []
        block_tables = []
        for i, (needed_blocks, needed_slots) in enumerate(batch.needed_blocks_slots):
            # Get allocated blocks for this sequence
            allocated_blocks = block_indices[
                cumulative_blocks : cumulative_blocks + needed_blocks
            ]
            # Get slots for the allocated blocks; the last block may be only
            # partially used, hence the truncation to `needed_slots`
            allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots]

            slots.append(allocated_slots)
            block_tables.append(allocated_blocks.tolist())
            block_tables_tensor[i, :needed_blocks] = allocated_blocks
            cumulative_blocks += needed_blocks

        batch.needed_blocks_slots = None
        batch.block_tables = block_tables
        batch.block_tables_tensor = block_tables_tensor.to(batch.input_ids.device)
        batch.slots = torch.concat(slots).to(batch.input_ids.device)

        # Allocate the required number of blocks by setting the mask to 0
        self.free_block_mask[block_indices] = 0

    def free(self, block_indices: Optional[List[int]]):
        """Mark the given blocks as free again; `None` or an empty list is a no-op."""
        if block_indices:
            # Reset mask
            self.free_block_mask[block_indices] = 1
@dataclass
class FlashCausalLMBatch(Batch):
@ -32,23 +121,29 @@ class FlashCausalLMBatch(Batch):
input_ids: torch.Tensor
position_ids: torch.Tensor
# Indices to copy present to the correct indices in the pre-allocated past key values
past_present_indices: torch.Tensor
# tensor of length b holding starting offset of each sequence
start_seq: torch.Tensor
# tensor of length b holding ending offset of each sequence
end_seq: torch.Tensor
# tensor of length b holding starting offset of each sequence, only used in prefill
start_seq_prefill: Optional[torch.Tensor]
# tensor of length b holding ending offset of each sequence, only used in prefill
end_seq_prefill: Optional[torch.Tensor]
# tensor of length b holding starting offset of each query sequence, only used in decode
start_seq_q: Optional[torch.Tensor]
# tensor of length b holding ending offset of each query sequence, only used in decode
end_seq_q: Optional[torch.Tensor]
# past key values, only used in decode
past_key_values: Optional[torch.Tensor]
# Paged Attention values
# Set when creating the batch
# CPU tensor of length b indicating the start of each sequence in slots
start_slots: torch.Tensor
# tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
slot_indices: torch.Tensor
# List of tuple of ints representing the number of blocks and slots needed by each sequence
needed_blocks_slots: Optional[List[Tuple[int, int]]]
# Set in prefill by the CacheManager
# list of length b of list of length s_i // block_size
block_tables: Optional[List[List[int]]]
# tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
block_tables_tensor: Optional[torch.Tensor]
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
slots: Optional[torch.Tensor]
max_seqlen: int
# Prefill metadata tensors to efficiently compute logprobs
@ -62,6 +157,7 @@ class FlashCausalLMBatch(Batch):
# Lengths of all generations present in the batch
input_lengths: List[int]
input_lengths_tensor: torch.Tensor
prefix_offsets: List[Optional[int]]
read_offsets: List[Optional[int]]
@ -69,15 +165,17 @@ class FlashCausalLMBatch(Batch):
next_token_chooser: HeterogeneousNextTokenChooser
stopping_criterias: List[StoppingCriteria]
# Maximum number of tokens this batch will grow to
max_tokens: int
# Number of blocks in this batch
blocks: int
# Maximum number of blocks
max_blocks: int
def to_pb(self) -> generate_pb2.CachedBatch:
return generate_pb2.CachedBatch(
id=self.batch_id,
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
max_tokens=self.blocks * BLOCK_SIZE,
)
@classmethod
@ -99,12 +197,11 @@ class FlashCausalLMBatch(Batch):
)["input_ids"]
position_ids = []
past_present_indices = []
start_seq = []
end_seq = []
start_seq_prefill = []
end_seq_prefill = []
max_seqlen = 0
needed_blocks_slots = []
start_slots = []
slot_indices = []
input_lengths = []
prefix_offsets = []
@ -126,7 +223,10 @@ class FlashCausalLMBatch(Batch):
cumulative_max_length = 0
prefill_out_cumulative_length = 0
blocks = 0
max_seqlen = 0
max_length = 0
max_blocks = 0
# Parse batch
for i, (r, tokenized_input) in enumerate(
@ -138,7 +238,6 @@ class FlashCausalLMBatch(Batch):
tokenized_input = tokenized_input[-r.truncate :]
input_length = len(tokenized_input)
max_seqlen = max(max_seqlen, input_length)
input_lengths.append(input_length)
prefix_offsets.append(input_length - 5)
@ -153,8 +252,6 @@ class FlashCausalLMBatch(Batch):
# Add cumulative lengths of all previous inputs
start_seq_prefill.append(cumulative_length)
end_seq_prefill.append(cumulative_length + input_length)
start_seq.append(cumulative_max_length)
end_seq.append(cumulative_max_length + input_length)
next_token_chooser_parameters.append(r.parameters)
@ -164,6 +261,21 @@ class FlashCausalLMBatch(Batch):
max_new_tokens = stopping_criteria.max_new_tokens
stopping_criterias.append(stopping_criteria)
# Paged attention
# Remove one as the first token does not have a past
total_tokens = input_length + max_new_tokens - 1
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
blocks += needed_blocks
needed_blocks_slots.append((needed_blocks, total_tokens))
start_slots.append(cumulative_max_length)
request_slot_indices = torch.arange(
cumulative_max_length,
cumulative_max_length + input_length,
dtype=torch.int64,
)
slot_indices.append(request_slot_indices)
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
@ -184,22 +296,17 @@ class FlashCausalLMBatch(Batch):
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
request_past_present_indices = torch.arange(
cumulative_max_length,
cumulative_max_length + input_length,
dtype=torch.int64,
)
past_present_indices.append(request_past_present_indices)
# Update
# Remove one as the first token does not have a past
cumulative_length += input_length
cumulative_max_length += input_length + max_new_tokens - 1
cumulative_max_length += total_tokens
max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, needed_blocks)
max_length = max(max_length, input_length + max_new_tokens)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device
)
start_slots = torch.tensor(start_slots, dtype=torch.int64)
# Padded all_input_ids_tensor
all_input_ids_tensor = np.zeros(
@ -212,14 +319,15 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor = torch.tensor(
all_input_ids_tensor, dtype=torch.int64, device=device
)
start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
if len(pb.requests) > 1:
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
position_ids = torch.cat(position_ids)
past_present_indices = np.concatenate(past_present_indices, dtype=np.int64)
slot_indices = torch.cat(slot_indices)
else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]
slot_indices = slot_indices[0]
start_seq_prefill = torch.tensor(
start_seq_prefill, device=device, dtype=torch.int32
@ -227,19 +335,12 @@ class FlashCausalLMBatch(Batch):
end_seq_prefill = torch.tensor(
end_seq_prefill, device=device, dtype=torch.int32
)
else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]
past_present_indices = past_present_indices[0]
start_seq_prefill = start_seq
end_seq_prefill = end_seq
position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
past_present_indices = torch.tensor(
past_present_indices, device=device, dtype=torch.int64
input_lengths_tensor = torch.tensor(
input_lengths, dtype=torch.int32, device=device
)
if all_prefill_logprobs:
@ -262,26 +363,28 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
past_present_indices=past_present_indices,
start_seq=start_seq,
end_seq=end_seq,
start_seq_prefill=start_seq_prefill,
end_seq_prefill=end_seq_prefill,
start_seq_q=None,
end_seq_q=None,
start_slots=start_slots,
slot_indices=slot_indices,
needed_blocks_slots=needed_blocks_slots,
block_tables=None,
block_tables_tensor=None,
slots=None,
max_seqlen=max_seqlen,
prefill_head_indices=prefill_head_indices,
prefill_next_token_indices=prefill_next_token_indices,
prefill_cu_outlens=prefill_cu_outlens,
past_key_values=None,
input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
max_tokens=cumulative_max_length,
blocks=blocks,
max_blocks=max_blocks,
)
@tracer.start_as_current_span("filter")
@ -294,28 +397,24 @@ class FlashCausalLMBatch(Batch):
device = self.input_ids.device
# Cumulative length
cumulative_max_length = 0
# New values after filtering
requests_idx_mapping = {}
# Used to index into tensors
indices = []
# past indices to keep
past_indices = torch.zeros(
self.past_key_values.shape[0], dtype=torch.bool, device=device
# slots to keep after filtering
slot_filtering_indices = torch.zeros(
self.slots.shape[0], dtype=torch.bool, device=device
)
# Create on CPU to only move to GPU once instead of at every copy
start_seq = torch.empty(len(request_ids), dtype=torch.int32)
end_seq = torch.empty(len(request_ids), dtype=torch.int32)
start_seq_q = self.start_seq_q[: len(request_ids)]
end_seq_q = self.end_seq_q[: len(request_ids)]
slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
max_seqlen = 0
requests = []
start_slots = []
block_tables = []
all_input_ids = []
input_lengths = []
@ -324,6 +423,11 @@ class FlashCausalLMBatch(Batch):
stopping_criterias = []
blocks = 0
max_blocks = 0
# Cumulative length
cumulative_max_length = 0
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
indices.append(idx)
@ -348,28 +452,51 @@ class FlashCausalLMBatch(Batch):
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
request_block_table = self.block_tables[idx]
blocks += len(request_block_table)
block_tables.append(request_block_table)
start_slots.append(cumulative_max_length)
# Copy to tensor (CPU)
start_seq[i] = cumulative_max_length
end_seq[i] = cumulative_max_length + request_input_length
slot_indices[i] = cumulative_max_length + request_input_length - 1
# Set slice
past_indices[
self.start_seq[idx] : self.end_seq[idx] + remaining_tokens - 1
slot_filtering_indices[
self.start_slots[idx] : self.start_slots[idx]
+ request_input_length
+ remaining_tokens
- 1
] = True
cumulative_max_length += request_input_length + remaining_tokens - 1
max_blocks = max(max_blocks, len(request_block_table))
global CACHE_MANAGER
block_indices_to_free = []
# Iterate on all requests
for i, r in enumerate(self.requests):
# Filter requests that are not part of the new batch
if r.id not in requests_idx_mapping.keys():
block_indices_to_free.extend(self.block_tables[i])
# Free blocks
CACHE_MANAGER.free(block_indices_to_free)
# Needed to avoid dropping blocks when the batches will go out of scope
self.block_tables = None
# Index into tensors
input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices]
all_input_ids_tensor = self.all_input_ids_tensor[indices]
block_tables_tensor = self.block_tables_tensor[indices]
input_lengths_tensor = self.input_lengths_tensor[indices]
slots = self.slots[slot_filtering_indices]
next_token_chooser = self.next_token_chooser.filter(indices)
past_key_values = self.past_key_values[past_indices]
start_slots = torch.tensor(start_slots, dtype=torch.int64)
# Move to GPU now that we have the whole tensor
start_seq = start_seq.to(device)
end_seq = end_seq.to(device)
past_present_indices = end_seq - 1
slot_indices = slot_indices.to(device)
return FlashCausalLMBatch(
batch_id=self.batch_id,
@ -377,26 +504,28 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
past_present_indices=past_present_indices,
start_seq=start_seq,
end_seq=end_seq,
start_seq_prefill=None,
end_seq_prefill=None,
start_seq_q=start_seq_q,
end_seq_q=end_seq_q,
start_slots=start_slots,
slot_indices=slot_indices,
needed_blocks_slots=None,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
slots=slots,
max_seqlen=max_seqlen,
prefill_head_indices=None,
prefill_next_token_indices=None,
prefill_cu_outlens=None,
past_key_values=past_key_values,
input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
max_tokens=cumulative_max_length,
blocks=blocks,
max_blocks=max_blocks,
)
@classmethod
@ -406,22 +535,46 @@ class FlashCausalLMBatch(Batch):
requests = []
requests_idx_mapping = {}
total_batch_size = sum([len(b) for b in batches])
dtype = batches[0].past_key_values.dtype
device = batches[0].input_ids.device
blocks = 0
total_batch_size = 0
total_slots = 0
max_blocks = 0
max_length = 0
max_seqlen = 0
for b in batches:
total_batch_size += len(b)
total_slots += len(b.slots)
blocks += b.blocks
max_blocks = max(max_blocks, b.max_blocks)
max_seqlen = max(max_seqlen, b.max_seqlen)
max_length = max(
max_length,
max(
input_length
+ stopping_criteria.max_new_tokens
- stopping_criteria.current_tokens
for input_length, stopping_criteria in zip(
b.input_lengths, b.stopping_criterias
)
),
)
input_ids = batches[0].input_ids.new_empty(total_batch_size)
position_ids = batches[0].position_ids.new_empty(total_batch_size)
start_seq = batches[0].start_seq.new_empty(total_batch_size)
end_seq = batches[0].end_seq.new_empty(total_batch_size)
start_seq_q = torch.arange(
0, total_batch_size, device=device, dtype=torch.int32
slots = batches[0].slots.new_empty(total_slots)
slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
total_batch_size
)
block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
(total_batch_size, max_blocks)
)
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
(total_batch_size, max_length)
)
end_seq_q = start_seq_q + 1
max_seqlen = 0
past_key_values = []
start_slots = []
block_tables = []
all_input_ids = []
input_lengths = []
@ -433,8 +586,7 @@ class FlashCausalLMBatch(Batch):
# Cumulative length
cumulative_batch_size = 0
max_tokens = 0
max_length = 0
cumulative_slots = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
@ -448,16 +600,27 @@ class FlashCausalLMBatch(Batch):
start_index = cumulative_batch_size
end_index = cumulative_batch_size + len(batch)
slots_start_index = cumulative_slots
slots_end_index = cumulative_slots + len(batch.slots)
# Copy tensors (GPU)
input_ids[start_index:end_index] = batch.input_ids
position_ids[start_index:end_index] = batch.position_ids
slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
slots[slots_start_index:slots_end_index] = batch.slots
start_seq[start_index:end_index] = batch.start_seq + max_tokens
end_seq[start_index:end_index] = batch.end_seq + max_tokens
all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
] = batch.all_input_ids_tensor[:, :max_length]
max_seqlen = max(max_seqlen, batch.max_seqlen)
block_tables_tensor[
start_index:end_index, : batch.block_tables_tensor.shape[1]
] = batch.block_tables_tensor[:, :max_blocks]
start_slots.append(batch.start_slots + cumulative_slots)
block_tables.extend(batch.block_tables)
all_input_ids.extend(batch.all_input_ids)
input_lengths.extend(batch.input_lengths)
@ -466,73 +629,59 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
stopping_criterias.extend(batch.stopping_criterias)
past_key_values.append(batch.past_key_values)
# Update
cumulative_batch_size += len(batch)
max_tokens += batch.max_tokens
max_length = max(
max_length,
max(
input_length
+ stopping_criteria.max_new_tokens
- stopping_criteria.current_tokens
for input_length, stopping_criteria in zip(
batch.input_lengths, batch.stopping_criterias
)
),
)
cumulative_slots += len(batch.slots)
past_key_values = torch.cat(past_key_values, dim=0)
past_present_indices = end_seq - 1
all_input_ids_tensor = torch.zeros(
(total_batch_size, max_length), dtype=torch.int64, device=device
)
cumulative_batch_size = 0
for i, batch in enumerate(batches):
start_index = cumulative_batch_size
end_index = cumulative_batch_size + len(batch)
all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
] = batch.all_input_ids_tensor[:, :max_length]
cumulative_batch_size += len(batch)
start_slots = torch.concat(start_slots)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype=dtype, device=device
next_token_chooser_parameters,
dtype=batches[0].next_token_chooser.dtype,
device=batches[0].next_token_chooser.device,
)
# Needed to avoid dropping blocks when the batches will go out of scope
for b in batches:
b.block_tables = None
return FlashCausalLMBatch(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
past_present_indices=past_present_indices,
start_seq=start_seq,
end_seq=end_seq,
start_seq_prefill=None,
end_seq_prefill=None,
start_seq_q=start_seq_q,
end_seq_q=end_seq_q,
start_slots=start_slots,
slot_indices=slot_indices,
needed_blocks_slots=None,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
slots=slots,
max_seqlen=max_seqlen,
prefill_head_indices=None,
prefill_next_token_indices=None,
prefill_cu_outlens=None,
past_key_values=past_key_values,
input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
max_tokens=max_tokens,
blocks=blocks,
max_blocks=max_blocks,
)
def __del__(self):
    """Give this batch's cache blocks back to the cache manager.

    Nothing is freed when ``block_tables`` is ``None`` or empty.
    """
    global CACHE_MANAGER

    tables = self.block_tables
    if tables is None or not tables:
        return

    # Flatten the per-request block tables into a single list
    blocks_to_free = [block for table in tables for block in table]
    # Free blocks
    CACHE_MANAGER.free(blocks_to_free)
def __len__(self):
    """Return the number of requests held by this batch."""
    request_count = len(self.requests)
    return request_count
@ -540,32 +689,19 @@ class FlashCausalLMBatch(Batch):
class FlashCausalLM(Model):
def __init__(
self,
model_cls: Type[PreTrainedModel],
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase,
num_layers: int,
num_kv_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
rank: int = 0,
world_size: int = 1,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16
else:
raise NotImplementedError("FlashCausalLM is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
model = model_cls.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
).to(device)
self.num_layers = num_layers
self.num_kv_heads = num_kv_heads
self.head_size = head_size
super(FlashCausalLM, self).__init__(
model=model,
@ -573,12 +709,38 @@ class FlashCausalLM(Model):
requires_padding=False,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@property
def batch_type(self) -> Type[FlashCausalLMBatch]:
    """Batch class associated with this model."""
    batch_cls = FlashCausalLMBatch
    return batch_cls
def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int):
    """Create the global cache manager and run one generation step.

    Replaces the module-level ``CACHE_MANAGER`` with a new ``CacheManager``
    sized from ``max_total_tokens``, then calls ``generate_token`` once on
    ``batch``.

    Args:
        batch: Batch handed to ``generate_token`` for the warmup step.
        max_total_tokens: Token budget used to compute the number of cache
            blocks (``BLOCK_SIZE`` tokens per block, plus a fixed margin).

    Raises:
        Exception: Any error raised while building the cache or running the
            warmup step is logged and re-raised unchanged.
    """
    global CACHE_MANAGER

    torch.cuda.empty_cache()
    try:
        CACHE_MANAGER = CacheManager(
            # Adds some wiggle room
            math.ceil(max_total_tokens / BLOCK_SIZE) + 10,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.dtype,
            self.device,
        )
        _, batch = self.generate_token(batch)
    except Exception:
        logger.exception(
            f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
            f"prefill tokens. "
            f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
        )
        # Bare raise re-raises the active exception without rebinding it
        raise
    # Drop the local reference to the batch returned by the warmup step
    # NOTE(review): presumably so its cache blocks can be released - confirm
    del batch
def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
return self.tokenizer.decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
@ -588,28 +750,27 @@ class FlashCausalLM(Model):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq: torch.Tensor,
end_seq: torch.Tensor,
start_seq_q: Optional[torch.Tensor],
end_seq_q: Optional[torch.Tensor],
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
past_present_indices: torch.Tensor,
past_key_values: Optional = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
global CACHE_MANAGER
# Model Forward
return self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
start_seq=start_seq,
end_seq=end_seq,
start_seq_q=start_seq_q,
end_seq_q=end_seq_q,
start_seq_prefill=start_seq_prefill,
end_seq_prefill=end_seq_prefill,
kv_cache=CACHE_MANAGER.kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
past_present_indices=past_present_indices,
past_key_values=past_key_values,
pre_allocate_past_size=pre_allocate_past_size,
lm_head_indices=lm_head_indices,
)
@ -617,31 +778,22 @@ class FlashCausalLM(Model):
def generate_token(
self, batch: FlashCausalLMBatch
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
prefill = batch.past_key_values is None
prefill = batch.start_seq_prefill is not None
prefill_logprobs = batch.prefill_next_token_indices is not None
if prefill:
# Ask to pre-allocate kv to its max size
# == Sum over batch size (number of tokens + max_new_tokens) - batch size
pre_allocate_past_size = batch.max_tokens
start_seq = batch.start_seq_prefill
end_seq = batch.end_seq_prefill
else:
pre_allocate_past_size = None
start_seq = batch.start_seq
end_seq = batch.end_seq
if batch.needed_blocks_slots:
# Allocate blocks to this batch
CACHE_MANAGER.allocate(batch)
out, present = self.forward(
out = self.forward(
batch.input_ids,
batch.position_ids,
start_seq,
end_seq,
batch.start_seq_q,
batch.end_seq_q,
batch.start_seq_prefill,
batch.end_seq_prefill,
batch.block_tables_tensor,
batch.slots[batch.slot_indices],
batch.input_lengths_tensor,
batch.max_seqlen,
batch.past_present_indices,
batch.past_key_values,
pre_allocate_past_size,
batch.prefill_head_indices,
)
@ -662,12 +814,8 @@ class FlashCausalLM(Model):
# When batch == 1, we will just use the batch.input_ids values directly
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
# Create batch.start_seq_q and batch.end_seq_q for decode
batch.start_seq_q = torch.arange(
0, len(batch), device=self.device, dtype=torch.int32
)
batch.end_seq_q = batch.start_seq_q + 1
next_position_ids = batch.position_ids.new_empty(len(batch))
batch.slot_indices = batch.slot_indices[batch.end_seq_prefill - 1]
# We do not need start_seq_prefill and end_seq_prefill anymore
batch.start_seq_prefill = None
batch.end_seq_prefill = None
@ -731,8 +879,8 @@ class FlashCausalLM(Model):
# Set values in batch
batch.input_ids = next_input_ids
batch.position_ids = next_position_ids + 1
batch.past_present_indices = batch.end_seq
batch.end_seq = batch.end_seq + 1
batch.input_lengths_tensor += 1
batch.slot_indices += 1
if prefill and prefill_logprobs:
# Get prefill logprobs
@ -755,7 +903,6 @@ class FlashCausalLM(Model):
batch.read_offsets,
batch.stopping_criterias,
batch.all_input_ids,
batch.all_input_ids_tensor,
batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds,
next_token_ids,
@ -770,7 +917,6 @@ class FlashCausalLM(Model):
read_offset,
stopping_criteria,
all_input_ids,
all_input_ids_tensor,
do_sample,
seed,
next_token_id,
@ -845,19 +991,20 @@ class FlashCausalLM(Model):
generations.append(generation)
new_input_length = input_length + 1
# Update values
batch.input_lengths[i] = new_input_length
batch.input_lengths[i] = input_length + 1
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids
if stopped:
del batch
# No need to return a batch if we know that all requests stopped
return generations, None
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
batch.max_seqlen = batch.max_seqlen + 1
batch.past_key_values = present
# No need to return a batch if we know that all requests stopped
return generations, batch if not stopped else None
return generations, batch

View File

@ -64,10 +64,12 @@ class FlashLlama(FlashCausalLM):
model = FlashLlamaForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
super(FlashLlama, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=False,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,

View File

@ -55,10 +55,12 @@ class FlashNeoXSharded(FlashCausalLM):
model = FlashGPTNeoXForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
super(FlashNeoXSharded, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
num_layers=len(model.gpt_neox.layers),
num_kv_heads=model.gpt_neox.num_heads,
head_size=model.gpt_neox.head_size,
dtype=dtype,
device=device,
rank=rank,

View File

@ -55,10 +55,12 @@ class FlashRWSharded(FlashCausalLM):
model = FlashRWForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
super(FlashRWSharded, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
num_layers=len(model.transformer.h),
num_kv_heads=model.transformer.cache_size,
head_size=model.transformer.head_size,
dtype=dtype,
device=device,
rank=rank,

View File

@ -52,17 +52,22 @@ class FlashSantacoderSharded(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group,
aliases = {"transformer.wte.weight": ["lm_head.weight"]}
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
aliases={"transformer.wte.weight": ["lm_head.weight"]},
)
model = FlashSantacoderForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
super(FlashSantacoderSharded, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
num_layers=len(model.transformer.h),
num_kv_heads=1,
head_size=model.transformer.head_size,
dtype=dtype,
device=device,
rank=rank,

View File

@ -22,6 +22,9 @@ class Model(ABC):
rank: int = 0,
world_size: int = 1,
):
if torch.cuda.is_available():
torch.cuda.set_per_process_memory_fraction(1.0)
self.model = model.eval()
self.tokenizer = tokenizer
self.all_special_ids = set(tokenizer.all_special_ids)
@ -55,6 +58,9 @@ class Model(ABC):
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
raise NotImplementedError
def warmup(self, batch: B, max_total_tokens: int):
    """Run a single generation step on ``batch``.

    ``max_total_tokens`` is accepted but not used by this implementation.
    """
    # The result of the step is intentionally discarded
    _ = self.generate_token(batch)
def decode_token(
self,
all_input_ids: List[int],

View File

@ -127,7 +127,7 @@ class Seq2SeqLMBatch(Batch):
read_offsets.append(1)
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
max_tokens = len(inputs) * max_input_length + max_decode_tokens
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
return cls(
batch_id=pb.id,

View File

@ -53,6 +53,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Warmup(self, request, context):
    """Build a batch from the request, warm the model up, and acknowledge.

    Returns an empty ``WarmupResponse`` once ``model.warmup`` has returned.
    """
    model = self.model
    warmup_batch = model.batch_type.from_pb(
        request.batch, model.tokenizer, model.dtype, model.device
    )
    model.warmup(warmup_batch, request.max_total_tokens)

    return generate_pb2.WarmupResponse()
async def Prefill(self, request, context):
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.dtype, self.model.device

View File

@ -216,6 +216,8 @@ class HeterogeneousNextTokenChooser:
self.seeds = seeds
self.do_sample = do_sample
self.dtype = dtype
self.device = device
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
if self.watermark_processor is not None:

View File

@ -5,7 +5,14 @@ import torch
class Weights:
def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None):
def __init__(
self,
filenames: List[Path],
device,
dtype,
process_group,
aliases: Optional[Dict[str, List[str]]] = None,
):
routing = {}
for filename in filenames:
with safe_open(filename, framework="pytorch") as f:
@ -43,7 +50,7 @@ class Weights:
return str(filename), tensor_name
def _get_slice(self, tensor_name: str):
filename, tensor_name= self.get_filename(tensor_name)
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name)
return slice_
@ -94,12 +101,20 @@ class Weights:
def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
if quantize == "gptq":
try:
qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
qweight = torch.cat(
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1)
scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1)
qzeros = torch.cat(
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
scales = torch.cat(
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
@ -118,7 +133,9 @@ class Weights:
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)