parent
9ffe1f1e67
commit
8aece3bd68
|
@ -26,7 +26,12 @@ incremental = true
|
|||
inherits = "release"
|
||||
debug = 1
|
||||
incremental = true
|
||||
panic = "abort"
|
||||
|
||||
[profile.release-opt]
|
||||
inherits = "release"
|
||||
debug = 0
|
||||
incremental = false
|
||||
lto = "fat"
|
||||
opt-level = 3
|
||||
codegen-units = 1
|
||||
panic = "abort"
|
||||
|
|
10
Dockerfile
10
Dockerfile
|
@ -25,7 +25,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
|
|||
rm -f $PROTOC_ZIP
|
||||
|
||||
COPY --from=planner /usr/src/recipe.json recipe.json
|
||||
RUN cargo chef cook --release --recipe-path recipe.json
|
||||
RUN cargo chef cook --profile release-opt --recipe-path recipe.json
|
||||
|
||||
COPY Cargo.toml Cargo.toml
|
||||
COPY rust-toolchain.toml rust-toolchain.toml
|
||||
|
@ -33,7 +33,7 @@ COPY proto proto
|
|||
COPY benchmark benchmark
|
||||
COPY router router
|
||||
COPY launcher launcher
|
||||
RUN cargo build --release
|
||||
RUN cargo build --profile release-opt
|
||||
|
||||
# Python builder
|
||||
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
|
||||
|
@ -226,11 +226,11 @@ RUN cd server && \
|
|||
pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir
|
||||
|
||||
# Install benchmarker
|
||||
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
# Install router
|
||||
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
|
||||
# Install launcher
|
||||
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
|
|
|
@ -25,7 +25,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
|
|||
rm -f $PROTOC_ZIP
|
||||
|
||||
COPY --from=planner /usr/src/recipe.json recipe.json
|
||||
RUN cargo chef cook --release --recipe-path recipe.json
|
||||
RUN cargo chef cook --profile release-opt --recipe-path recipe.json
|
||||
|
||||
COPY Cargo.toml Cargo.toml
|
||||
COPY rust-toolchain.toml rust-toolchain.toml
|
||||
|
@ -33,7 +33,7 @@ COPY proto proto
|
|||
COPY benchmark benchmark
|
||||
COPY router router
|
||||
COPY launcher launcher
|
||||
RUN cargo build --release
|
||||
RUN cargo build --profile release-opt
|
||||
|
||||
# Text Generation Inference base image for RoCm
|
||||
FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update as base
|
||||
|
@ -193,11 +193,11 @@ RUN cd server && \
|
|||
pip install ".[accelerate, peft, outlines]" --no-cache-dir
|
||||
|
||||
# Install benchmarker
|
||||
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
# Install router
|
||||
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
|
||||
# Install launcher
|
||||
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||
|
||||
# AWS Sagemaker compatible image
|
||||
FROM base as sagemaker
|
||||
|
|
|
@ -24,7 +24,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
|
|||
rm -f $PROTOC_ZIP
|
||||
|
||||
COPY --from=planner /usr/src/recipe.json recipe.json
|
||||
RUN cargo chef cook --release --recipe-path recipe.json
|
||||
RUN cargo chef cook --profile release-opt --recipe-path recipe.json
|
||||
|
||||
COPY Cargo.toml Cargo.toml
|
||||
COPY rust-toolchain.toml rust-toolchain.toml
|
||||
|
@ -32,7 +32,7 @@ COPY proto proto
|
|||
COPY benchmark benchmark
|
||||
COPY router router
|
||||
COPY launcher launcher
|
||||
RUN cargo build --release
|
||||
RUN cargo build --profile release-opt
|
||||
|
||||
|
||||
# Text Generation Inference base image for Intel
|
||||
|
@ -78,11 +78,11 @@ ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mp
|
|||
ENV CCL_ZE_IPC_EXCHANGE=sockets
|
||||
|
||||
# Install benchmarker
|
||||
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
# Install router
|
||||
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
|
||||
# Install launcher
|
||||
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||
|
||||
# Final image
|
||||
FROM base
|
||||
|
|
|
@ -155,6 +155,8 @@ async fn prefill(
|
|||
ignore_eos_token: true, // Will not stop even if a eos token is generated
|
||||
}),
|
||||
top_n_tokens: top_n_tokens.unwrap_or(0),
|
||||
blocks: vec![],
|
||||
slots: vec![],
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
@ -163,6 +165,7 @@ async fn prefill(
|
|||
requests,
|
||||
size: batch_size,
|
||||
max_tokens: batch_size * (sequence_length + decode_length),
|
||||
max_blocks: 0,
|
||||
};
|
||||
|
||||
// Run prefill
|
||||
|
|
|
@ -130,6 +130,10 @@ message Request {
|
|||
bool prefill_logprobs = 6;
|
||||
/// Return most likely n tokens
|
||||
uint32 top_n_tokens = 7;
|
||||
/// Paged attention blocks
|
||||
repeated uint32 blocks = 9;
|
||||
/// Paged attention slots
|
||||
repeated uint32 slots = 10;
|
||||
}
|
||||
|
||||
message Batch {
|
||||
|
@ -141,6 +145,8 @@ message Batch {
|
|||
uint32 size = 3;
|
||||
/// Maximum number of tokens this batch will grow to
|
||||
uint32 max_tokens = 4;
|
||||
/// Maximum number of Paged Attention blocks
|
||||
uint32 max_blocks = 5;
|
||||
}
|
||||
|
||||
message CachedBatch {
|
||||
|
|
|
@ -153,6 +153,9 @@ impl Client {
|
|||
}),
|
||||
// We truncate the input on the server side to be sure that it has the correct size
|
||||
truncate,
|
||||
// Blocks and slots will be set on the server side if we use paged attention
|
||||
blocks: vec![],
|
||||
slots: vec![],
|
||||
// Set sampling parameters to also take these ops into account in the max memory
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 0.9,
|
||||
|
@ -187,7 +190,8 @@ impl Client {
|
|||
id: 0,
|
||||
size: requests.len() as u32,
|
||||
requests,
|
||||
max_tokens: 0,
|
||||
max_tokens: max_input_length,
|
||||
max_blocks: 0,
|
||||
};
|
||||
|
||||
let request = tonic::Request::new(WarmupRequest {
|
||||
|
|
|
@ -241,12 +241,16 @@ impl Health for ShardedClient {
|
|||
ignore_eos_token: false,
|
||||
}),
|
||||
top_n_tokens: 0,
|
||||
// Block 0 is reserved for health checks
|
||||
blocks: vec![0],
|
||||
slots: (0..16).collect(),
|
||||
};
|
||||
let batch = Batch {
|
||||
id: u64::MAX,
|
||||
requests: vec![liveness_request],
|
||||
size: 1,
|
||||
max_tokens: 2,
|
||||
max_blocks: 1,
|
||||
};
|
||||
self.clone().prefill(batch).await?;
|
||||
Ok(())
|
||||
|
|
|
@ -0,0 +1,136 @@
|
|||
use std::cmp::min;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct BlockAllocation {
|
||||
pub blocks: Vec<u32>,
|
||||
pub slots: Vec<u32>,
|
||||
block_allocator: BlockAllocator,
|
||||
}
|
||||
|
||||
impl Drop for BlockAllocation {
|
||||
fn drop(&mut self) {
|
||||
self.block_allocator.free(self.blocks.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct BlockAllocator {
|
||||
/// Channel to communicate with the background task
|
||||
block_allocator: mpsc::UnboundedSender<BlockAllocatorCommand>,
|
||||
}
|
||||
|
||||
impl BlockAllocator {
|
||||
pub(crate) fn new(
|
||||
max_batch_total_tokens: u32,
|
||||
block_size: u32,
|
||||
window_size: Option<u32>,
|
||||
) -> Self {
|
||||
// Create channel
|
||||
let (sender, receiver) = mpsc::unbounded_channel();
|
||||
|
||||
// Launch background queue task
|
||||
tokio::spawn(block_allocator_task(
|
||||
max_batch_total_tokens / block_size,
|
||||
block_size,
|
||||
window_size,
|
||||
receiver,
|
||||
));
|
||||
|
||||
Self {
|
||||
block_allocator: sender,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
|
||||
let (response_sender, response_receiver) = oneshot::channel();
|
||||
self.block_allocator
|
||||
.send(BlockAllocatorCommand::Allocate {
|
||||
tokens,
|
||||
response_sender,
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
response_receiver
|
||||
.await
|
||||
.unwrap()
|
||||
.map(|(blocks, slots)| BlockAllocation {
|
||||
blocks,
|
||||
slots,
|
||||
block_allocator: self.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn free(&self, blocks: Vec<u32>) {
|
||||
self.block_allocator
|
||||
.send(BlockAllocatorCommand::Free { blocks })
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
async fn block_allocator_task(
|
||||
blocks: u32,
|
||||
block_size: u32,
|
||||
window_size: Option<u32>,
|
||||
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
|
||||
) {
|
||||
// Block 0 is reserved for health checks
|
||||
let mut free_blocks: Vec<u32> = (1..blocks).collect();
|
||||
while let Some(cmd) = receiver.recv().await {
|
||||
match cmd {
|
||||
BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
|
||||
BlockAllocatorCommand::Allocate {
|
||||
tokens,
|
||||
response_sender,
|
||||
} => {
|
||||
// Apply window size
|
||||
let (required_blocks, repeats) = {
|
||||
let (tokens, repeats) = match window_size {
|
||||
None => (tokens, 1),
|
||||
Some(window_size) => {
|
||||
let repeats = (tokens + window_size - 1) / window_size;
|
||||
let tokens = min(tokens, window_size);
|
||||
(tokens, repeats as usize)
|
||||
}
|
||||
};
|
||||
// Pad to a multiple of block size
|
||||
let required_blocks = (tokens + block_size - 1) / block_size;
|
||||
(required_blocks, repeats)
|
||||
};
|
||||
|
||||
let tokens = tokens as usize;
|
||||
let allocation = if required_blocks > free_blocks.len() as u32 {
|
||||
None
|
||||
} else {
|
||||
let blocks =
|
||||
free_blocks.split_off(free_blocks.len() - required_blocks as usize);
|
||||
let mut slots = Vec::with_capacity(
|
||||
(required_blocks * block_size * repeats as u32) as usize,
|
||||
);
|
||||
|
||||
'slots: for block_id in blocks.repeat(repeats).iter() {
|
||||
for s in (block_id * block_size)..((block_id + 1) * block_size) {
|
||||
slots.push(s);
|
||||
if slots.len() == tokens {
|
||||
break 'slots;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some((blocks, slots))
|
||||
};
|
||||
response_sender.send(allocation).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum BlockAllocatorCommand {
|
||||
Free {
|
||||
blocks: Vec<u32>,
|
||||
},
|
||||
Allocate {
|
||||
tokens: u32,
|
||||
response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
|
||||
},
|
||||
}
|
|
@ -1,3 +1,4 @@
|
|||
mod block_allocator;
|
||||
mod queue;
|
||||
mod scheduler;
|
||||
|
||||
|
|
|
@ -1,17 +1,20 @@
|
|||
use crate::infer::{InferError, InferStreamResponse};
|
||||
use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator};
|
||||
use crate::infer::InferError;
|
||||
use crate::infer::InferStreamResponse;
|
||||
use crate::validation::{
|
||||
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
|
||||
};
|
||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||
use std::cmp::min;
|
||||
use std::cmp::{max, min};
|
||||
use std::collections::VecDeque;
|
||||
use text_generation_client::v3::{
|
||||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
use text_generation_client::{ChunksToString, Input};
|
||||
use text_generation_client::ChunksToString;
|
||||
use text_generation_client::Input;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::time::Instant;
|
||||
use tracing::{info_span, instrument, Span};
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
|
||||
/// Queue entry
|
||||
#[derive(Debug)]
|
||||
|
@ -28,6 +31,8 @@ pub(crate) struct Entry {
|
|||
pub queue_time: Instant,
|
||||
/// Instant when this entry was added to a batch
|
||||
pub batch_time: Option<Instant>,
|
||||
/// Block Allocation
|
||||
pub block_allocation: Option<BlockAllocation>,
|
||||
}
|
||||
|
||||
/// Request Queue
|
||||
|
@ -43,6 +48,7 @@ impl Queue {
|
|||
block_size: u32,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
) -> Self {
|
||||
// Create channel
|
||||
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
|
||||
|
@ -53,12 +59,14 @@ impl Queue {
|
|||
block_size,
|
||||
window_size,
|
||||
speculate,
|
||||
max_batch_total_tokens,
|
||||
queue_receiver,
|
||||
));
|
||||
|
||||
Self { queue_sender }
|
||||
}
|
||||
|
||||
/// Append an entry to the queue
|
||||
#[instrument(skip_all)]
|
||||
pub(crate) fn append(&self, entry: Entry) {
|
||||
// Send append command to the background task managing the state
|
||||
|
@ -103,9 +111,16 @@ async fn queue_task(
|
|||
block_size: u32,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
|
||||
) {
|
||||
let mut state = State::new(requires_padding, block_size, window_size, speculate);
|
||||
let mut state = State::new(
|
||||
requires_padding,
|
||||
block_size,
|
||||
window_size,
|
||||
speculate,
|
||||
max_batch_total_tokens,
|
||||
);
|
||||
|
||||
while let Some(cmd) = receiver.recv().await {
|
||||
match cmd {
|
||||
|
@ -120,12 +135,14 @@ async fn queue_task(
|
|||
token_budget,
|
||||
response_sender,
|
||||
span,
|
||||
} => span.in_scope(|| {
|
||||
let next_batch =
|
||||
state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
|
||||
} => {
|
||||
let next_batch = state
|
||||
.next_batch(min_size, max_size, prefill_token_budget, token_budget)
|
||||
.instrument(span)
|
||||
.await;
|
||||
response_sender.send(next_batch).unwrap();
|
||||
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -142,9 +159,6 @@ struct State {
|
|||
/// Id of the next batch
|
||||
next_batch_id: u64,
|
||||
|
||||
/// Whether the model is using padding
|
||||
requires_padding: bool,
|
||||
|
||||
/// Paged Attention block size
|
||||
block_size: u32,
|
||||
|
||||
|
@ -153,6 +167,9 @@ struct State {
|
|||
|
||||
/// Speculation amount
|
||||
speculate: u32,
|
||||
|
||||
/// Paged Attention Block Allocation
|
||||
block_allocator: Option<BlockAllocator>,
|
||||
}
|
||||
|
||||
impl State {
|
||||
|
@ -161,15 +178,19 @@ impl State {
|
|||
block_size: u32,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
) -> Self {
|
||||
let block_allocator = (!requires_padding)
|
||||
.then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size));
|
||||
|
||||
Self {
|
||||
entries: VecDeque::with_capacity(128),
|
||||
next_id: 0,
|
||||
next_batch_id: 0,
|
||||
requires_padding,
|
||||
block_size,
|
||||
window_size,
|
||||
speculate,
|
||||
block_allocator,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -185,7 +206,7 @@ impl State {
|
|||
}
|
||||
|
||||
// Get the next batch
|
||||
fn next_batch(
|
||||
async fn next_batch(
|
||||
&mut self,
|
||||
min_size: Option<usize>,
|
||||
max_size: Option<usize>,
|
||||
|
@ -220,9 +241,10 @@ impl State {
|
|||
let mut max_input_length = 0;
|
||||
let mut prefill_tokens: u32 = 0;
|
||||
let mut decode_tokens: u32 = 0;
|
||||
let mut max_blocks = 0;
|
||||
|
||||
// Pop entries starting from the front of the queue
|
||||
while let Some((id, mut entry)) = self.entries.pop_front() {
|
||||
'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() {
|
||||
// Filter entries where the response receiver was dropped (== entries where the request
|
||||
// was dropped by the client)
|
||||
if entry.response_tx.is_closed() {
|
||||
|
@ -231,21 +253,27 @@ impl State {
|
|||
continue;
|
||||
}
|
||||
|
||||
if self.requires_padding {
|
||||
let block_allocation = match &self.block_allocator {
|
||||
None => {
|
||||
// We pad to max input length in the Python shards
|
||||
// We need to take these padding tokens into the equation
|
||||
max_input_length = max_input_length.max(entry.request.input_length);
|
||||
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
|
||||
} else {
|
||||
// pad to block size
|
||||
prefill_tokens += ((entry.request.input_length + self.block_size - 1)
|
||||
/ self.block_size)
|
||||
* self.block_size;
|
||||
}
|
||||
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length;
|
||||
|
||||
if self.requires_padding {
|
||||
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
|
||||
} else {
|
||||
let total_tokens = prefill_tokens + decode_tokens + self.speculate;
|
||||
|
||||
if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
|
||||
self.entries.push_front((id, entry));
|
||||
break 'entry_loop;
|
||||
}
|
||||
None
|
||||
}
|
||||
Some(block_allocator) => {
|
||||
prefill_tokens += entry.request.input_length;
|
||||
let max_new_tokens = match self.window_size {
|
||||
None => entry.request.stopping_parameters.max_new_tokens,
|
||||
Some(window_size) => min(
|
||||
|
@ -253,11 +281,7 @@ impl State {
|
|||
entry.request.stopping_parameters.max_new_tokens,
|
||||
),
|
||||
};
|
||||
|
||||
// pad to block size
|
||||
decode_tokens +=
|
||||
((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size;
|
||||
}
|
||||
decode_tokens += max_new_tokens;
|
||||
|
||||
if prefill_tokens > prefill_token_budget
|
||||
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
|
||||
|
@ -269,6 +293,28 @@ impl State {
|
|||
break;
|
||||
}
|
||||
|
||||
let tokens = entry.request.input_length
|
||||
+ entry.request.stopping_parameters.max_new_tokens
|
||||
+ self.speculate
|
||||
- 1;
|
||||
|
||||
match block_allocator.allocate(tokens).await {
|
||||
None => {
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
tracing::debug!("Over budget: not enough free blocks");
|
||||
self.entries.push_front((id, entry));
|
||||
break 'entry_loop;
|
||||
}
|
||||
Some(block_allocation) => {
|
||||
tracing::debug!("Allocation: {block_allocation:?}");
|
||||
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
|
||||
Some(block_allocation)
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
tracing::debug!("Accepting entry");
|
||||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||
|
@ -278,13 +324,23 @@ impl State {
|
|||
// Update entry
|
||||
entry.temp_span = Some(entry_batch_span);
|
||||
|
||||
let (blocks, slots) = match &block_allocation {
|
||||
None => (Vec::new(), Vec::new()),
|
||||
Some(block_allocation) => (
|
||||
block_allocation.blocks.clone(),
|
||||
block_allocation.slots.clone(),
|
||||
),
|
||||
};
|
||||
|
||||
entry.block_allocation = block_allocation;
|
||||
|
||||
batch_requests.push(Request {
|
||||
id,
|
||||
prefill_logprobs: entry.request.decoder_input_details,
|
||||
inputs: entry.request.inputs.chunks_to_string(),
|
||||
input_chunks: Some(Input {
|
||||
chunks: entry.request.inputs.clone(),
|
||||
}),
|
||||
inputs: entry.request.inputs.chunks_to_string(),
|
||||
truncate: entry.request.truncate,
|
||||
parameters: Some(NextTokenChooserParameters::from(
|
||||
entry.request.parameters.clone(),
|
||||
|
@ -293,6 +349,8 @@ impl State {
|
|||
entry.request.stopping_parameters.clone(),
|
||||
)),
|
||||
top_n_tokens: entry.request.top_n_tokens,
|
||||
blocks,
|
||||
slots,
|
||||
});
|
||||
// Set batch_time
|
||||
entry.batch_time = Some(Instant::now());
|
||||
|
@ -335,6 +393,7 @@ impl State {
|
|||
requests: batch_requests,
|
||||
size,
|
||||
max_tokens: (prefill_tokens + decode_tokens),
|
||||
max_blocks,
|
||||
};
|
||||
// Increment batch id
|
||||
self.next_batch_id += 1;
|
||||
|
@ -438,13 +497,14 @@ mod tests {
|
|||
temp_span: None,
|
||||
queue_time: Instant::now(),
|
||||
batch_time: None,
|
||||
block_allocation: None,
|
||||
};
|
||||
(entry, receiver_tx)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_append() {
|
||||
let mut state = State::new(false, 1, None, 0);
|
||||
#[tokio::test]
|
||||
async fn test_append() {
|
||||
let mut state = State::new(false, 1, None, 0, 16);
|
||||
let (entry, _guard) = default_entry();
|
||||
|
||||
assert_eq!(state.next_id, 0);
|
||||
|
@ -458,23 +518,23 @@ mod tests {
|
|||
assert_eq!(id, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_empty() {
|
||||
let mut state = State::new(false, 1, None, 0);
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_empty() {
|
||||
let mut state = State::new(false, 1, None, 0, 16);
|
||||
|
||||
assert!(state.next_batch(None, None, 1, 1).is_none());
|
||||
assert!(state.next_batch(Some(1), None, 1, 1).is_none());
|
||||
assert!(state.next_batch(None, None, 1, 1).await.is_none());
|
||||
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_min_size() {
|
||||
let mut state = State::new(false, 1, None, 0);
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_min_size() {
|
||||
let mut state = State::new(false, 1, None, 0, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
state.append(entry2);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap();
|
||||
let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.contains_key(&1));
|
||||
|
@ -490,7 +550,7 @@ mod tests {
|
|||
let (entry3, _guard3) = default_entry();
|
||||
state.append(entry3);
|
||||
|
||||
assert!(state.next_batch(Some(2), None, 2, 2).is_none());
|
||||
assert!(state.next_batch(Some(2), None, 2, 2).await.is_none());
|
||||
|
||||
assert_eq!(state.next_id, 3);
|
||||
assert_eq!(state.entries.len(), 1);
|
||||
|
@ -498,15 +558,15 @@ mod tests {
|
|||
assert_eq!(id, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_max_size() {
|
||||
let mut state = State::new(false, 1, None, 0);
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_max_size() {
|
||||
let mut state = State::new(false, 1, None, 0, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
state.append(entry2);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap();
|
||||
let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.get(&0).unwrap().batch_time.is_some());
|
||||
|
@ -518,15 +578,15 @@ mod tests {
|
|||
assert_eq!(state.next_batch_id, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_token_budget() {
|
||||
let mut state = State::new(false, 1, None, 0);
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_token_budget() {
|
||||
let mut state = State::new(false, 1, None, 0, 2);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
state.append(entry2);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap();
|
||||
let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert_eq!(batch.id, 0);
|
||||
|
@ -539,7 +599,7 @@ mod tests {
|
|||
let (entry3, _guard3) = default_entry();
|
||||
state.append(entry3);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap();
|
||||
let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&1));
|
||||
assert!(entries.contains_key(&2));
|
||||
|
@ -553,14 +613,14 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_append() {
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
let queue = Queue::new(false, 1, None, 0, 16);
|
||||
let (entry, _guard) = default_entry();
|
||||
queue.append(entry);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_empty() {
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
let queue = Queue::new(false, 1, None, 0, 16);
|
||||
|
||||
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||
|
@ -568,7 +628,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_min_size() {
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
let queue = Queue::new(false, 1, None, 0, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
|
@ -601,7 +661,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_max_size() {
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
let queue = Queue::new(false, 1, None, 0, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
|
@ -617,7 +677,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_token_budget() {
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
let queue = Queue::new(false, 1, None, 0, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
|
@ -642,7 +702,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_token_speculate() {
|
||||
let queue = Queue::new(false, 1, None, 2);
|
||||
let queue = Queue::new(false, 1, None, 2, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
|
@ -661,7 +721,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_dropped_receiver() {
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
let queue = Queue::new(false, 1, None, 0, 16);
|
||||
let (entry, _) = default_entry();
|
||||
queue.append(entry);
|
||||
|
||||
|
|
|
@ -39,7 +39,13 @@ impl SchedulerV3 {
|
|||
speculate: u32,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
) -> Self {
|
||||
let queue = Queue::new(requires_padding, 16, window_size, speculate);
|
||||
let queue = Queue::new(
|
||||
requires_padding,
|
||||
16,
|
||||
window_size,
|
||||
speculate,
|
||||
max_batch_total_tokens,
|
||||
);
|
||||
let batching_task_notifier = Arc::new(Notify::new());
|
||||
|
||||
// Spawn batching background task that contains all the inference logic
|
||||
|
@ -81,6 +87,7 @@ impl Scheduler for SchedulerV3 {
|
|||
temp_span: None,
|
||||
queue_time: Instant::now(),
|
||||
batch_time: None,
|
||||
block_allocation: None,
|
||||
});
|
||||
|
||||
// Notify the background task that we have a new entry in the queue that needs
|
||||
|
|
|
@ -1,140 +0,0 @@
|
|||
import math
|
||||
import torch
|
||||
|
||||
from typing import Optional, List, Tuple
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
BLOCK_SIZE: int = 16
|
||||
# Will be set in warmup
|
||||
CACHE_MANAGER: Optional["CacheManager"] = None
|
||||
|
||||
|
||||
class CacheManager:
|
||||
def __init__(
|
||||
self,
|
||||
num_blocks: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
repeat_slots: bool,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
):
|
||||
self.block_size = BLOCK_SIZE
|
||||
self.num_blocks = num_blocks
|
||||
self.repeat_slots = repeat_slots
|
||||
|
||||
element_size = torch.tensor([], dtype=dtype).element_size()
|
||||
if SYSTEM == "xpu":
|
||||
x = 1
|
||||
else:
|
||||
x = self.block_size // element_size
|
||||
|
||||
self.kv_cache = [
|
||||
(
|
||||
torch.empty(
|
||||
(num_blocks, num_heads, head_size // x, self.block_size, x),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
torch.empty(
|
||||
(num_blocks, num_heads, head_size, self.block_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
|
||||
self.slots = torch.arange(
|
||||
0, num_blocks * self.block_size, dtype=torch.int64
|
||||
).view(num_blocks, self.block_size)
|
||||
|
||||
def allocate(
|
||||
self,
|
||||
needed_blocks_slots: List[Tuple[int, int]],
|
||||
blocks: int,
|
||||
max_blocks: int,
|
||||
device: torch.device,
|
||||
):
|
||||
# Get free blocks indices by finding values in mask that are not set to 0
|
||||
free_block_indices = self.free_block_mask.nonzero()
|
||||
if blocks > len(free_block_indices):
|
||||
raise RuntimeError(
|
||||
f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
|
||||
)
|
||||
|
||||
# Slice by the number of required blocks
|
||||
block_indices = free_block_indices[:blocks]
|
||||
block_indices = block_indices.flatten()
|
||||
|
||||
# Padded block tables
|
||||
block_tables_tensor = torch.zeros(
|
||||
(len(needed_blocks_slots), max_blocks), dtype=torch.int32
|
||||
)
|
||||
|
||||
# Allocate paged attention blocks
|
||||
cumulative_blocks = 0
|
||||
slots = []
|
||||
block_tables = []
|
||||
for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots):
|
||||
# Get allocated blocks for this sequence
|
||||
allocated_blocks = block_indices[
|
||||
cumulative_blocks : cumulative_blocks + needed_blocks
|
||||
]
|
||||
# Get slots for the allocated blocks
|
||||
all_slots = self.slots[allocated_blocks].flatten()
|
||||
|
||||
# Repeat slots in the case of context sliding window
|
||||
if needed_slots > len(all_slots) and self.repeat_slots:
|
||||
repeats = math.ceil(needed_slots / len(all_slots))
|
||||
all_slots = all_slots.repeat(repeats)
|
||||
|
||||
allocated_slots = all_slots[:needed_slots]
|
||||
|
||||
slots.append(allocated_slots)
|
||||
block_tables.append(allocated_blocks.tolist())
|
||||
block_tables_tensor[i, :needed_blocks] = allocated_blocks
|
||||
cumulative_blocks += needed_blocks
|
||||
|
||||
block_tables = block_tables
|
||||
block_tables_tensor = block_tables_tensor.to(device)
|
||||
slots = torch.concat(slots).to(device)
|
||||
|
||||
# Allocate the required number of blocks by setting the mask to 0
|
||||
self.free_block_mask[block_indices] = 0
|
||||
|
||||
return block_tables, block_tables_tensor, slots
|
||||
|
||||
def free(self, block_indices: Optional[List[int]]):
|
||||
if block_indices is not None and block_indices:
|
||||
# Reset mask
|
||||
self.free_block_mask[block_indices] = 1
|
||||
|
||||
|
||||
def set_cache_manager(
|
||||
num_blocks: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
repeat_slots: bool,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> CacheManager:
|
||||
global CACHE_MANAGER
|
||||
if CACHE_MANAGER is not None:
|
||||
del CACHE_MANAGER
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
CACHE_MANAGER = CacheManager(
|
||||
num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device
|
||||
)
|
||||
return CACHE_MANAGER
|
||||
|
||||
|
||||
def get_cache_manager() -> CacheManager:
|
||||
global CACHE_MANAGER
|
||||
if CACHE_MANAGER is None:
|
||||
raise RuntimeError("cache manager was not initialized")
|
||||
|
||||
return CACHE_MANAGER
|
|
@ -512,6 +512,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
|
|||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
hidden_states = self.model(
|
||||
|
|
|
@ -834,6 +834,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
|
|||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
hidden_states = self.model(
|
||||
|
|
|
@ -458,6 +458,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
|
|||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_embeds = self.embed_tokens(input_ids)
|
||||
|
|
|
@ -388,6 +388,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
|||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.gpt_neox(
|
||||
|
|
|
@ -398,6 +398,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
|
|||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(
|
||||
|
|
|
@ -670,6 +670,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
|
|||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(
|
||||
|
|
|
@ -482,6 +482,7 @@ class FlashSantacoderForCausalLM(nn.Module):
|
|||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(
|
||||
|
|
|
@ -25,11 +25,6 @@ from text_generation_server.models.types import (
|
|||
Generation,
|
||||
GeneratedText,
|
||||
)
|
||||
from text_generation_server.models.cache_manager import (
|
||||
get_cache_manager,
|
||||
set_cache_manager,
|
||||
BLOCK_SIZE,
|
||||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
|
||||
import text_generation_server.models.globals as tgi_globals
|
||||
|
@ -44,6 +39,21 @@ from text_generation_server.utils.import_utils import (
|
|||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
BLOCK_SIZE: int = 16
|
||||
|
||||
# Will be set in init
|
||||
SLIDING_WINDOW: Optional[int] = None
|
||||
|
||||
|
||||
def set_sliding_window(sliding_window: int):
|
||||
global SLIDING_WINDOW
|
||||
SLIDING_WINDOW = sliding_window
|
||||
|
||||
|
||||
def get_sliding_windows() -> int:
|
||||
global SLIDING_WINDOW
|
||||
return SLIDING_WINDOW
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashCausalLMBatch(Batch):
|
||||
|
@ -55,12 +65,15 @@ class FlashCausalLMBatch(Batch):
|
|||
# Decoder values
|
||||
input_ids: torch.Tensor
|
||||
position_ids: torch.Tensor
|
||||
speculative_ids: torch.Tensor
|
||||
speculative_ids: Optional[torch.Tensor]
|
||||
|
||||
# Flash Attention values
|
||||
|
||||
# tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
|
||||
cu_seqlen_prefill: Optional[torch.Tensor]
|
||||
# Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
|
||||
# as we only keep SLIDING_WINDOW values instead of the whole tensor
|
||||
prefill_cache_indices: Optional[torch.Tensor]
|
||||
|
||||
# Paged Attention values
|
||||
|
||||
|
@ -69,16 +82,13 @@ class FlashCausalLMBatch(Batch):
|
|||
start_slots: torch.Tensor
|
||||
# tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
|
||||
slot_indices: torch.Tensor
|
||||
# List of tuple of ints representing the number of blocks and slots needed by each sequence
|
||||
needed_blocks_slots: Optional[List[Tuple[int, int]]]
|
||||
|
||||
# Set in prefill by the CacheManager
|
||||
# list of length b of list of length s_i // block_size
|
||||
block_tables: Optional[List[List[int]]]
|
||||
block_tables: List[List[int]]
|
||||
# tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
|
||||
block_tables_tensor: Optional[torch.Tensor]
|
||||
block_tables_tensor: torch.Tensor
|
||||
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
|
||||
slots: Optional[torch.Tensor]
|
||||
slots: torch.Tensor
|
||||
|
||||
max_seqlen: int
|
||||
|
||||
|
@ -104,7 +114,7 @@ class FlashCausalLMBatch(Batch):
|
|||
top_n_tokens_tensor: torch.Tensor
|
||||
|
||||
# Number of blocks in this batch
|
||||
blocks: int
|
||||
num_blocks: int
|
||||
# Maximum number of blocks
|
||||
max_blocks: int
|
||||
|
||||
|
@ -113,7 +123,7 @@ class FlashCausalLMBatch(Batch):
|
|||
id=self.batch_id,
|
||||
request_ids=[r.id for r in self.requests],
|
||||
size=len(self),
|
||||
max_tokens=self.blocks * BLOCK_SIZE,
|
||||
max_tokens=self.num_blocks * BLOCK_SIZE,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -129,17 +139,6 @@ class FlashCausalLMBatch(Batch):
|
|||
)["input_ids"]
|
||||
return batch_tokenized_inputs
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashCausalLMBatch":
|
||||
batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
|
||||
return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
|
||||
|
||||
@classmethod
|
||||
def from_tokenized(
|
||||
cls,
|
||||
|
@ -149,12 +148,12 @@ class FlashCausalLMBatch(Batch):
|
|||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashCausalLMBatch":
|
||||
sliding_window = get_sliding_windows()
|
||||
position_ids = []
|
||||
speculative_ids = []
|
||||
cu_seqlen_prefill = [0]
|
||||
needed_blocks_slots = []
|
||||
start_slots = []
|
||||
slot_indices = []
|
||||
prefill_cache_indices = []
|
||||
|
||||
input_lengths = []
|
||||
prefix_offsets = []
|
||||
|
@ -177,11 +176,14 @@ class FlashCausalLMBatch(Batch):
|
|||
cumulative_max_length = 0
|
||||
prefill_out_cumulative_length = 0
|
||||
|
||||
blocks = 0
|
||||
num_blocks = 0
|
||||
max_seqlen = 0
|
||||
max_length = 0
|
||||
max_blocks = 0
|
||||
|
||||
block_tables = []
|
||||
slots = []
|
||||
|
||||
# Parse batch
|
||||
for i, (r, tokenized_input) in enumerate(
|
||||
zip(pb.requests, batch_tokenized_inputs)
|
||||
|
@ -225,9 +227,25 @@ class FlashCausalLMBatch(Batch):
|
|||
speculative_length = get_speculate()
|
||||
speculative_length = 0 if speculative_length is None else speculative_length
|
||||
total_tokens = input_length + max_new_tokens - 1 + speculative_length
|
||||
|
||||
# blocks and slots can be empty (for example in warmup)
|
||||
if not r.blocks:
|
||||
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
|
||||
blocks += needed_blocks
|
||||
needed_blocks_slots.append((needed_blocks, total_tokens))
|
||||
request_blocks = [
|
||||
b for b in range(num_blocks, num_blocks + needed_blocks)
|
||||
]
|
||||
request_slots = [
|
||||
s
|
||||
for b in request_blocks
|
||||
for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
|
||||
]
|
||||
else:
|
||||
request_blocks = r.blocks
|
||||
request_slots = r.slots
|
||||
|
||||
block_tables.append(request_blocks)
|
||||
slots.extend(request_slots[:total_tokens])
|
||||
num_blocks += len(request_blocks)
|
||||
start_slots.append(cumulative_max_length)
|
||||
|
||||
request_slot_indices = torch.arange(
|
||||
|
@ -237,6 +255,15 @@ class FlashCausalLMBatch(Batch):
|
|||
)
|
||||
slot_indices.append(request_slot_indices)
|
||||
|
||||
# Create tensor to slice into the kv tensor in prefill
|
||||
if sliding_window is not None:
|
||||
request_prefill_cache_indices = torch.arange(
|
||||
cumulative_length + max(0, input_length - sliding_window),
|
||||
cumulative_length + input_length,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
prefill_cache_indices.append(request_prefill_cache_indices)
|
||||
|
||||
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
|
||||
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
|
||||
|
||||
|
@ -261,7 +288,7 @@ class FlashCausalLMBatch(Batch):
|
|||
cumulative_length += input_length
|
||||
cumulative_max_length += total_tokens
|
||||
max_seqlen = max(max_seqlen, input_length)
|
||||
max_blocks = max(max_blocks, needed_blocks)
|
||||
max_blocks = max(max_blocks, len(request_blocks))
|
||||
max_length = max(
|
||||
max_length, input_length + max_new_tokens + speculative_length
|
||||
)
|
||||
|
@ -287,16 +314,23 @@ class FlashCausalLMBatch(Batch):
|
|||
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
|
||||
position_ids = torch.cat(position_ids)
|
||||
slot_indices = torch.cat(slot_indices)
|
||||
if sliding_window is not None:
|
||||
prefill_cache_indices = torch.cat(prefill_cache_indices)
|
||||
else:
|
||||
input_ids = all_input_ids[0]
|
||||
position_ids = position_ids[0]
|
||||
slot_indices = slot_indices[0]
|
||||
if sliding_window is not None:
|
||||
prefill_cache_indices = prefill_cache_indices[0]
|
||||
|
||||
cu_seqlen_prefill = torch.tensor(
|
||||
cu_seqlen_prefill, device=device, dtype=torch.int32
|
||||
)
|
||||
position_ids = position_ids.to(device)
|
||||
slot_indices = slot_indices.to(device)
|
||||
prefill_cache_indices = (
|
||||
prefill_cache_indices.to(device) if sliding_window is not None else None
|
||||
)
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||
input_lengths_tensor = torch.tensor(
|
||||
input_lengths, dtype=torch.int32, device=device
|
||||
|
@ -319,6 +353,14 @@ class FlashCausalLMBatch(Batch):
|
|||
top_n_tokens, device=device, dtype=torch.int64
|
||||
)
|
||||
|
||||
slots = torch.tensor(slots, dtype=torch.int64, device=device)
|
||||
block_tables_tensor = torch.zeros(
|
||||
(len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
|
||||
)
|
||||
for i, request_blocks in enumerate(block_tables):
|
||||
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
|
||||
block_tables_tensor = block_tables_tensor.to(device)
|
||||
|
||||
return cls(
|
||||
batch_id=pb.id,
|
||||
requests=pb.requests,
|
||||
|
@ -326,12 +368,12 @@ class FlashCausalLMBatch(Batch):
|
|||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
prefill_cache_indices=prefill_cache_indices,
|
||||
start_slots=start_slots,
|
||||
slot_indices=slot_indices,
|
||||
needed_blocks_slots=needed_blocks_slots,
|
||||
block_tables=None,
|
||||
block_tables_tensor=None,
|
||||
slots=None,
|
||||
block_tables=block_tables,
|
||||
block_tables_tensor=block_tables_tensor,
|
||||
slots=slots,
|
||||
max_seqlen=max_seqlen,
|
||||
prefill_head_indices=prefill_head_indices,
|
||||
prefill_next_token_indices=prefill_next_token_indices,
|
||||
|
@ -346,11 +388,22 @@ class FlashCausalLMBatch(Batch):
|
|||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||
blocks=blocks,
|
||||
num_blocks=num_blocks,
|
||||
max_blocks=max_blocks,
|
||||
speculative_ids=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashCausalLMBatch":
|
||||
batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
|
||||
return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
|
||||
|
||||
@tracer.start_as_current_span("filter")
|
||||
def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
|
||||
if len(request_ids) == 0:
|
||||
|
@ -388,7 +441,7 @@ class FlashCausalLMBatch(Batch):
|
|||
stopping_criterias = []
|
||||
top_n_tokens = []
|
||||
|
||||
blocks = 0
|
||||
num_blocks = 0
|
||||
max_blocks = 0
|
||||
# Cumulative length
|
||||
cumulative_max_length = 0
|
||||
|
@ -420,7 +473,7 @@ class FlashCausalLMBatch(Batch):
|
|||
)
|
||||
|
||||
request_block_table = self.block_tables[idx]
|
||||
blocks += len(request_block_table)
|
||||
num_blocks += len(request_block_table)
|
||||
block_tables.append(request_block_table)
|
||||
start_slots.append(cumulative_max_length)
|
||||
|
||||
|
@ -439,17 +492,6 @@ class FlashCausalLMBatch(Batch):
|
|||
|
||||
max_blocks = max(max_blocks, len(request_block_table))
|
||||
|
||||
block_indices_to_free = []
|
||||
# Iterate on all requests
|
||||
for i, r in enumerate(self.requests):
|
||||
# Filter requests that are not part of the new batch
|
||||
if r.id not in requests_idx_mapping.keys():
|
||||
block_indices_to_free.extend(self.block_tables[i])
|
||||
# Free blocks
|
||||
get_cache_manager().free(block_indices_to_free)
|
||||
# Needed to avoid dropping blocks when the batches will go out of scope
|
||||
self.block_tables = None
|
||||
|
||||
# Index into tensors
|
||||
input_ids = self.input_ids[indices]
|
||||
position_ids = self.position_ids[indices]
|
||||
|
@ -475,9 +517,9 @@ class FlashCausalLMBatch(Batch):
|
|||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=None,
|
||||
prefill_cache_indices=None,
|
||||
start_slots=start_slots,
|
||||
slot_indices=slot_indices,
|
||||
needed_blocks_slots=None,
|
||||
block_tables=block_tables,
|
||||
block_tables_tensor=block_tables_tensor,
|
||||
slots=slots,
|
||||
|
@ -495,7 +537,7 @@ class FlashCausalLMBatch(Batch):
|
|||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||
blocks=blocks,
|
||||
num_blocks=num_blocks,
|
||||
max_blocks=max_blocks,
|
||||
speculative_ids=speculative_ids,
|
||||
)
|
||||
|
@ -507,7 +549,7 @@ class FlashCausalLMBatch(Batch):
|
|||
requests = []
|
||||
requests_idx_mapping = {}
|
||||
|
||||
blocks = 0
|
||||
num_blocks = 0
|
||||
total_batch_size = 0
|
||||
total_slots = 0
|
||||
max_blocks = 0
|
||||
|
@ -516,7 +558,7 @@ class FlashCausalLMBatch(Batch):
|
|||
for b in batches:
|
||||
total_batch_size += len(b)
|
||||
total_slots += len(b.slots)
|
||||
blocks += b.blocks
|
||||
num_blocks += b.num_blocks
|
||||
speculative_length = (
|
||||
b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
|
||||
)
|
||||
|
@ -635,11 +677,6 @@ class FlashCausalLMBatch(Batch):
|
|||
else None
|
||||
)
|
||||
|
||||
# Needed to avoid dropping blocks when the batches will go out of scope
|
||||
for b in batches:
|
||||
b.block_tables = None
|
||||
del b
|
||||
|
||||
return cls(
|
||||
batch_id=batches[0].batch_id,
|
||||
requests=requests,
|
||||
|
@ -647,9 +684,9 @@ class FlashCausalLMBatch(Batch):
|
|||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=None,
|
||||
prefill_cache_indices=None,
|
||||
start_slots=start_slots,
|
||||
slot_indices=slot_indices,
|
||||
needed_blocks_slots=None,
|
||||
block_tables=block_tables,
|
||||
block_tables_tensor=block_tables_tensor,
|
||||
slots=slots,
|
||||
|
@ -667,18 +704,11 @@ class FlashCausalLMBatch(Batch):
|
|||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||
blocks=blocks,
|
||||
num_blocks=num_blocks,
|
||||
max_blocks=max_blocks,
|
||||
speculative_ids=speculative_ids,
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
if self.block_tables is not None and self.block_tables:
|
||||
# Free blocks
|
||||
get_cache_manager().free(
|
||||
list(itertools.chain.from_iterable(self.block_tables))
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.requests)
|
||||
|
||||
|
@ -702,6 +732,7 @@ class FlashCausalLM(Model):
|
|||
self.head_size = head_size
|
||||
|
||||
self.cuda_graphs = {}
|
||||
self.kv_cache = []
|
||||
|
||||
super(FlashCausalLM, self).__init__(
|
||||
model=model,
|
||||
|
@ -718,6 +749,43 @@ class FlashCausalLM(Model):
|
|||
def batch_type(self) -> Type[FlashCausalLMBatch]:
|
||||
return FlashCausalLMBatch
|
||||
|
||||
def max_past(self) -> int:
|
||||
return getattr(self.model, "max_past", None)
|
||||
|
||||
def init_kv_cache(
|
||||
self,
|
||||
num_blocks: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
):
|
||||
self.kv_cache = []
|
||||
empty_cache()
|
||||
|
||||
element_size = torch.tensor([], dtype=dtype).element_size()
|
||||
if SYSTEM == "xpu":
|
||||
x = 1
|
||||
else:
|
||||
x = BLOCK_SIZE // element_size
|
||||
|
||||
self.kv_cache = [
|
||||
(
|
||||
torch.empty(
|
||||
(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
torch.empty(
|
||||
(num_blocks, num_heads, head_size, BLOCK_SIZE),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
|
||||
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
|
||||
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
|
||||
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
|
||||
|
@ -728,12 +796,11 @@ class FlashCausalLM(Model):
|
|||
.repeat(bs)
|
||||
.reshape((bs, max_bt))
|
||||
)
|
||||
kv_cache = get_cache_manager().kv_cache
|
||||
|
||||
self.cuda_graphs[bs] = {
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"kv_cache": kv_cache,
|
||||
"kv_cache": self.kv_cache,
|
||||
"block_tables": block_tables,
|
||||
"slots": slots,
|
||||
"input_lengths": input_lengths,
|
||||
|
@ -747,11 +814,12 @@ class FlashCausalLM(Model):
|
|||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=None,
|
||||
kv_cache=kv_cache,
|
||||
kv_cache=self.kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
input_lengths=input_lengths,
|
||||
max_s=max_s,
|
||||
prefill_cache_indices=None,
|
||||
lm_head_indices=None,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
@ -761,11 +829,12 @@ class FlashCausalLM(Model):
|
|||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=None,
|
||||
kv_cache=kv_cache,
|
||||
kv_cache=self.kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
input_lengths=input_lengths,
|
||||
max_s=max_s,
|
||||
prefill_cache_indices=None,
|
||||
lm_head_indices=None,
|
||||
)
|
||||
self.cuda_graphs[bs]["logits"] = logits
|
||||
|
@ -777,17 +846,16 @@ class FlashCausalLM(Model):
|
|||
empty_cache()
|
||||
|
||||
try:
|
||||
cache_manager = set_cache_manager(
|
||||
batch.blocks,
|
||||
self.init_kv_cache(
|
||||
batch.num_blocks,
|
||||
self.num_layers,
|
||||
self.num_kv_heads,
|
||||
self.head_size,
|
||||
self.sliding_window is not None,
|
||||
self.dtype,
|
||||
self.device,
|
||||
)
|
||||
max_bt = batch.max_blocks
|
||||
max_s = max_bt * get_cache_manager().block_size
|
||||
max_s = max_bt * BLOCK_SIZE
|
||||
|
||||
if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
|
||||
torch.cuda.tunable.tuning_enable(False)
|
||||
|
@ -811,19 +879,17 @@ class FlashCausalLM(Model):
|
|||
num_blocks = (
|
||||
# Leave 5% for some wiggle room
|
||||
int((free_memory * 0.95) // total_cache_size)
|
||||
# Add batch.blocks as we allocated it above, so it is included in the peak memory.
|
||||
+ cache_manager.num_blocks
|
||||
# Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
|
||||
+ batch.num_blocks
|
||||
)
|
||||
|
||||
del batch
|
||||
del cache_manager
|
||||
|
||||
set_cache_manager(
|
||||
self.init_kv_cache(
|
||||
num_blocks,
|
||||
self.num_layers,
|
||||
self.num_kv_heads,
|
||||
self.head_size,
|
||||
self.sliding_window is not None,
|
||||
self.dtype,
|
||||
self.device,
|
||||
)
|
||||
|
@ -889,7 +955,6 @@ class FlashCausalLM(Model):
|
|||
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
|
||||
position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
|
||||
slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
|
||||
kv_cache = get_cache_manager().kv_cache
|
||||
|
||||
# Dummy value, some models (starcoder2) don't accept `None`.
|
||||
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
|
||||
|
@ -901,12 +966,13 @@ class FlashCausalLM(Model):
|
|||
cu_seqlen_prefill=torch.tensor(
|
||||
[0, seqlen], device=self.device, dtype=torch.int32
|
||||
),
|
||||
kv_cache=get_cache_manager().kv_cache,
|
||||
kv_cache=self.kv_cache,
|
||||
block_tables=None,
|
||||
input_lengths=input_lengths,
|
||||
slots=slots,
|
||||
max_s=seqlen,
|
||||
lm_head_indices=None,
|
||||
prefill_cache_indices=None,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
@ -917,7 +983,7 @@ class FlashCausalLM(Model):
|
|||
input_ids = batch.input_ids
|
||||
position_ids = batch.position_ids
|
||||
cu_seqlen_prefill = batch.cu_seqlen_prefill
|
||||
kv_cache = get_cache_manager().kv_cache
|
||||
kv_cache = self.kv_cache
|
||||
block_tables = batch.block_tables_tensor
|
||||
slots = batch.slots[batch.slot_indices]
|
||||
input_lengths = batch.input_lengths_tensor
|
||||
|
@ -956,13 +1022,19 @@ class FlashCausalLM(Model):
|
|||
input_ids = batch.input_ids
|
||||
position_ids = batch.position_ids
|
||||
cu_seqlen_prefill = batch.cu_seqlen_prefill
|
||||
kv_cache = get_cache_manager().kv_cache
|
||||
kv_cache = self.kv_cache
|
||||
block_tables = batch.block_tables_tensor
|
||||
slots = batch.slots[batch.slot_indices]
|
||||
input_lengths = batch.input_lengths_tensor
|
||||
max_s = batch.max_seqlen
|
||||
lm_head_indices = batch.prefill_head_indices
|
||||
|
||||
if cu_seqlen_prefill is None and self.max_past() is not None:
|
||||
# In decode, not prefill, we're actually overwriting the KV-cache
|
||||
# in a circular buffer mode.
|
||||
# This makes sure the max_s for the decode pass is correct.
|
||||
max_s = min(self.max_past(), max_s)
|
||||
|
||||
bs = input_ids.shape[0]
|
||||
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
|
||||
if sorted_padded_bs:
|
||||
|
@ -972,7 +1044,7 @@ class FlashCausalLM(Model):
|
|||
cuda_graph = None
|
||||
|
||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
return self.model.forward(
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
|
@ -981,8 +1053,12 @@ class FlashCausalLM(Model):
|
|||
slots=slots,
|
||||
input_lengths=input_lengths,
|
||||
max_s=max_s,
|
||||
prefill_cache_indices=batch.prefill_cache_indices,
|
||||
lm_head_indices=lm_head_indices,
|
||||
)
|
||||
if batch.prefill_cache_indices is not None:
|
||||
batch.prefill_cache_indices = None
|
||||
return logits, speculative_logits
|
||||
|
||||
# Copy inputs to the static inputs of the cuda graph
|
||||
# Static inputs are potentially padded
|
||||
|
@ -1015,24 +1091,7 @@ class FlashCausalLM(Model):
|
|||
prefill = batch.cu_seqlen_prefill is not None
|
||||
prefill_logprobs = batch.prefill_next_token_indices is not None
|
||||
|
||||
if batch.needed_blocks_slots:
|
||||
# Allocate blocks to this batch
|
||||
block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
|
||||
batch.needed_blocks_slots,
|
||||
batch.blocks,
|
||||
batch.max_blocks,
|
||||
batch.input_ids.device,
|
||||
)
|
||||
batch.needed_blocks_slots = None
|
||||
batch.block_tables = block_tables
|
||||
batch.block_tables_tensor = block_tables_tensor
|
||||
batch.slots = slots
|
||||
|
||||
try:
|
||||
out, speculative_logits = self.forward(batch)
|
||||
except Exception as e:
|
||||
del batch
|
||||
raise e
|
||||
|
||||
if prefill:
|
||||
next_token_logits = (
|
||||
|
@ -1327,7 +1386,6 @@ class FlashCausalLM(Model):
|
|||
batch.all_input_ids[i] = all_input_ids
|
||||
|
||||
if stopped:
|
||||
del batch
|
||||
# No need to return a batch if we know that all requests stopped
|
||||
forward_ns = start_decode - start
|
||||
decode_ns = time.time_ns() - start_decode
|
||||
|
|
|
@ -1,308 +1,24 @@
|
|||
import math
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
import numpy as np
|
||||
|
||||
from dataclasses import dataclass
|
||||
from opentelemetry import trace
|
||||
from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig
|
||||
from typing import Optional, Tuple, Type
|
||||
from transformers import AutoTokenizer, AutoConfig
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE
|
||||
from text_generation_server.models.cache_manager import (
|
||||
get_cache_manager,
|
||||
)
|
||||
from text_generation_server.models.flash_causal_lm import set_sliding_window
|
||||
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
||||
FlashMistralForCausalLM,
|
||||
MistralConfig,
|
||||
)
|
||||
from text_generation_server.utils.speculate import get_speculate
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
HeterogeneousNextTokenChooser,
|
||||
StoppingCriteria,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
# Will be set in init
|
||||
SLIDING_WINDOW: Optional[int] = None
|
||||
SLIDING_WINDOW_BLOCKS: Optional[int] = None
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||
|
||||
|
||||
def set_sliding_window(sliding_window: int, sliding_window_blocks: int):
|
||||
global SLIDING_WINDOW
|
||||
global SLIDING_WINDOW_BLOCKS
|
||||
SLIDING_WINDOW = sliding_window
|
||||
SLIDING_WINDOW_BLOCKS = sliding_window_blocks
|
||||
|
||||
|
||||
def get_sliding_windows() -> Tuple[int, int]:
|
||||
global SLIDING_WINDOW
|
||||
global SLIDING_WINDOW_BLOCKS
|
||||
return SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS
|
||||
|
||||
|
||||
# Adds windowing logic to FlashCausalLMBatch
|
||||
@dataclass
|
||||
class FlashMistralBatch(FlashCausalLMBatch):
|
||||
# Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
|
||||
# as we only keep SLIDING_WINDOW values instead of the whole tensor
|
||||
prefill_cache_indices: Optional[torch.Tensor] = None
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashCausalLMBatch":
|
||||
batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
|
||||
return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
|
||||
|
||||
@classmethod
|
||||
def from_tokenized(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
batch_tokenized_inputs,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashCausalLMBatch":
|
||||
sliding_window, sliding_window_blocks = get_sliding_windows()
|
||||
|
||||
position_ids = []
|
||||
cu_seqlen_prefill = [0]
|
||||
needed_blocks_slots = []
|
||||
start_slots = []
|
||||
slot_indices = []
|
||||
prefill_cache_indices = []
|
||||
|
||||
input_lengths = []
|
||||
prefix_offsets = []
|
||||
read_offsets = []
|
||||
all_input_ids = []
|
||||
requests_idx_mapping = {}
|
||||
|
||||
all_prefill_logprobs = True
|
||||
no_prefill_logprobs = True
|
||||
prefill_head_indices = []
|
||||
prefill_next_token_indices = []
|
||||
prefill_cu_outlens = [0]
|
||||
|
||||
next_token_chooser_parameters = []
|
||||
stopping_criterias = []
|
||||
top_n_tokens = []
|
||||
|
||||
# Cumulative length
|
||||
cumulative_length = 0
|
||||
cumulative_max_length = 0
|
||||
prefill_out_cumulative_length = 0
|
||||
|
||||
blocks = 0
|
||||
max_seqlen = 0
|
||||
max_length = 0
|
||||
max_blocks = 0
|
||||
|
||||
# Parse batch
|
||||
for i, (r, tokenized_input) in enumerate(
|
||||
zip(pb.requests, batch_tokenized_inputs)
|
||||
):
|
||||
# request id -> idx in list mapping
|
||||
requests_idx_mapping[r.id] = i
|
||||
|
||||
tokenized_input = tokenized_input[-r.truncate :]
|
||||
if (
|
||||
tokenized_input[0] == tokenizer.bos_token_id
|
||||
and tokenized_input[1] == tokenizer.bos_token_id
|
||||
):
|
||||
tokenized_input = tokenized_input[1:]
|
||||
|
||||
input_length = len(tokenized_input)
|
||||
input_lengths.append(input_length)
|
||||
|
||||
prefix_offsets.append(input_length - 5)
|
||||
read_offsets.append(input_length)
|
||||
|
||||
all_input_ids.append(tokenized_input)
|
||||
|
||||
# Position ids
|
||||
request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
|
||||
position_ids.append(request_position_ids)
|
||||
|
||||
# Add cumulative lengths of all previous inputs
|
||||
cu_seqlen_prefill.append(cumulative_length + input_length)
|
||||
|
||||
next_token_chooser_parameters.append(r.parameters)
|
||||
|
||||
stopping_criteria = StoppingCriteria.from_pb(
|
||||
r.stopping_parameters, tokenizer
|
||||
)
|
||||
max_new_tokens = stopping_criteria.max_new_tokens
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
top_n_tokens.append(r.top_n_tokens)
|
||||
|
||||
# Paged attention
|
||||
# Remove one as the first token des not have a past
|
||||
speculative_length = get_speculate()
|
||||
total_tokens = input_length + max_new_tokens - 1 + speculative_length
|
||||
|
||||
# Needed blocks can not go over SLIDING_WINDOW_BLOCKS
|
||||
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
|
||||
if sliding_window_blocks is not None:
|
||||
needed_blocks = min(needed_blocks, sliding_window_blocks)
|
||||
blocks += needed_blocks
|
||||
|
||||
needed_blocks_slots.append((needed_blocks, total_tokens))
|
||||
start_slots.append(cumulative_max_length)
|
||||
|
||||
request_slot_indices = torch.arange(
|
||||
cumulative_max_length,
|
||||
cumulative_max_length + input_length,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
slot_indices.append(request_slot_indices)
|
||||
|
||||
# Create tensor to slice into the kv tensor in prefill
|
||||
if sliding_window is not None:
|
||||
request_prefill_cache_indices = torch.arange(
|
||||
cumulative_length + max(0, input_length - sliding_window),
|
||||
cumulative_length + input_length,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
prefill_cache_indices.append(request_prefill_cache_indices)
|
||||
|
||||
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
|
||||
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
|
||||
|
||||
if r.prefill_logprobs:
|
||||
prefill_head_indices.append(request_position_ids + cumulative_length)
|
||||
prefill_next_token_indices.append(
|
||||
prefill_out_cumulative_length + input_length - 1
|
||||
)
|
||||
prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
|
||||
prefill_out_cumulative_length += input_length
|
||||
else:
|
||||
prefill_head_indices.append(
|
||||
torch.tensor(
|
||||
[cumulative_length + input_length - 1], dtype=torch.int32
|
||||
)
|
||||
)
|
||||
prefill_next_token_indices.append(prefill_out_cumulative_length)
|
||||
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
|
||||
prefill_out_cumulative_length += 1
|
||||
|
||||
# Update
|
||||
cumulative_length += input_length
|
||||
cumulative_max_length += total_tokens
|
||||
max_seqlen = max(max_seqlen, input_length)
|
||||
max_blocks = max(max_blocks, needed_blocks)
|
||||
max_length = max(
|
||||
max_length, input_length + max_new_tokens + speculative_length
|
||||
)
|
||||
|
||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||
next_token_chooser_parameters, dtype, device, tokenizer
|
||||
)
|
||||
start_slots = torch.tensor(start_slots, dtype=torch.int64)
|
||||
|
||||
# Padded all_input_ids_tensor
|
||||
all_input_ids_tensor = np.zeros(
|
||||
(len(all_input_ids), max_length), dtype=np.int64
|
||||
)
|
||||
for i, input_ids in enumerate(all_input_ids):
|
||||
all_input_ids_tensor[i, : len(input_ids)] = input_ids
|
||||
|
||||
# Create tensors on device
|
||||
all_input_ids_tensor = torch.tensor(
|
||||
all_input_ids_tensor, dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
if len(pb.requests) > 1:
|
||||
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
|
||||
position_ids = torch.cat(position_ids)
|
||||
slot_indices = torch.cat(slot_indices)
|
||||
if sliding_window is not None:
|
||||
prefill_cache_indices = torch.cat(prefill_cache_indices)
|
||||
else:
|
||||
input_ids = all_input_ids[0]
|
||||
position_ids = position_ids[0]
|
||||
slot_indices = slot_indices[0]
|
||||
if sliding_window is not None:
|
||||
prefill_cache_indices = prefill_cache_indices[0]
|
||||
|
||||
cu_seqlen_prefill = torch.tensor(
|
||||
cu_seqlen_prefill, device=device, dtype=torch.int32
|
||||
)
|
||||
|
||||
position_ids = position_ids.to(device)
|
||||
slot_indices = slot_indices.to(device)
|
||||
prefill_cache_indices = (
|
||||
prefill_cache_indices.to(device) if sliding_window is not None else None
|
||||
)
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||
input_lengths_tensor = torch.tensor(
|
||||
input_lengths, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
if all_prefill_logprobs:
|
||||
prefill_head_indices = None
|
||||
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
|
||||
elif no_prefill_logprobs:
|
||||
prefill_head_indices = cu_seqlen_prefill[1:] - 1
|
||||
prefill_next_token_indices = None
|
||||
else:
|
||||
prefill_head_indices = torch.tensor(
|
||||
torch.cat(prefill_head_indices), dtype=torch.int64, device=device
|
||||
)
|
||||
prefill_next_token_indices = torch.tensor(
|
||||
prefill_next_token_indices, dtype=torch.int64, device=device
|
||||
)
|
||||
top_n_tokens_tensor = torch.tensor(
|
||||
top_n_tokens, device=device, dtype=torch.int64
|
||||
)
|
||||
|
||||
return cls(
|
||||
batch_id=pb.id,
|
||||
requests=pb.requests,
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
start_slots=start_slots,
|
||||
slot_indices=slot_indices,
|
||||
needed_blocks_slots=needed_blocks_slots,
|
||||
block_tables=None,
|
||||
block_tables_tensor=None,
|
||||
slots=None,
|
||||
max_seqlen=max_seqlen,
|
||||
prefill_head_indices=prefill_head_indices,
|
||||
prefill_next_token_indices=prefill_next_token_indices,
|
||||
prefill_cu_outlens=prefill_cu_outlens,
|
||||
input_lengths=input_lengths,
|
||||
input_lengths_tensor=input_lengths_tensor,
|
||||
prefix_offsets=prefix_offsets,
|
||||
read_offsets=read_offsets,
|
||||
all_input_ids=all_input_ids,
|
||||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
next_token_chooser=next_token_chooser,
|
||||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||
blocks=blocks,
|
||||
max_blocks=max_blocks,
|
||||
prefill_cache_indices=prefill_cache_indices,
|
||||
speculative_ids=None,
|
||||
)
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class BaseFlashMistral(FlashCausalLM):
|
||||
|
@ -344,9 +60,7 @@ class BaseFlashMistral(FlashCausalLM):
|
|||
|
||||
# Set context windows
|
||||
if getattr(config, "sliding_window", None) is not None:
|
||||
set_sliding_window(
|
||||
config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
|
||||
)
|
||||
set_sliding_window(config.sliding_window)
|
||||
else:
|
||||
config.sliding_window = None
|
||||
|
||||
|
@ -384,207 +98,6 @@ class BaseFlashMistral(FlashCausalLM):
|
|||
model.model.head_size,
|
||||
)
|
||||
|
||||
def max_past(self) -> int:
|
||||
return self.model.max_past
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[FlashMistralBatch]:
|
||||
return FlashMistralBatch
|
||||
|
||||
def tunableop_warmup(self, seqlen: int):
|
||||
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
|
||||
position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
|
||||
slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
|
||||
kv_cache = get_cache_manager().kv_cache
|
||||
|
||||
# Dummy value, some models (starcoder2) don't accept `None`.
|
||||
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
|
||||
|
||||
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
|
||||
self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=torch.tensor(
|
||||
[0, seqlen], device=self.device, dtype=torch.int32
|
||||
),
|
||||
kv_cache=get_cache_manager().kv_cache,
|
||||
block_tables=None,
|
||||
input_lengths=input_lengths,
|
||||
slots=slots,
|
||||
max_s=seqlen,
|
||||
lm_head_indices=None,
|
||||
prefill_cache_indices=None,
|
||||
)
|
||||
|
||||
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
|
||||
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
|
||||
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
|
||||
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
|
||||
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
|
||||
block_tables = (
|
||||
torch.arange(max_bt, dtype=torch.int32, device=self.device)
|
||||
.repeat(bs)
|
||||
.reshape((bs, max_bt))
|
||||
)
|
||||
kv_cache = get_cache_manager().kv_cache
|
||||
|
||||
self.cuda_graphs[bs] = {
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"kv_cache": kv_cache,
|
||||
"block_tables": block_tables,
|
||||
"slots": slots,
|
||||
"input_lengths": input_lengths,
|
||||
}
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
self.cuda_graphs[bs]["graph"] = graph
|
||||
|
||||
torch.cuda.synchronize()
|
||||
# Run once outside to warmup
|
||||
self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=None,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
input_lengths=input_lengths,
|
||||
max_s=max_s,
|
||||
prefill_cache_indices=None,
|
||||
lm_head_indices=None,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
with torch.cuda.graph(graph, pool=MEM_POOL):
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=None,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
input_lengths=input_lengths,
|
||||
max_s=max_s,
|
||||
prefill_cache_indices=None,
|
||||
lm_head_indices=None,
|
||||
)
|
||||
self.cuda_graphs[bs]["logits"] = logits
|
||||
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def forward(
|
||||
self, batch: FlashMistralBatch
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
# Model Forward
|
||||
if batch.speculative_ids is not None:
|
||||
input_ids = batch.input_ids
|
||||
position_ids = batch.position_ids
|
||||
cu_seqlen_prefill = batch.cu_seqlen_prefill
|
||||
kv_cache = get_cache_manager().kv_cache
|
||||
block_tables = batch.block_tables_tensor
|
||||
slots = batch.slots[batch.slot_indices]
|
||||
input_lengths = batch.input_lengths_tensor
|
||||
max_s = batch.max_seqlen
|
||||
lm_head_indices = batch.prefill_head_indices
|
||||
|
||||
speculative_ids = batch.speculative_ids
|
||||
|
||||
B, speculative_length = speculative_ids.shape
|
||||
new_length = speculative_length + 1
|
||||
new_input_ids = torch.cat(
|
||||
[input_ids.unsqueeze(-1), speculative_ids], dim=1
|
||||
).reshape(-1)
|
||||
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
|
||||
arange_int = arange.to(dtype=torch.int32)
|
||||
new_position_ids = (
|
||||
position_ids.unsqueeze(-1).expand(B, new_length) + arange
|
||||
).view(-1)
|
||||
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
|
||||
input_lengths = (
|
||||
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
|
||||
).view(-1)
|
||||
|
||||
# Add Copy the block tables for all members
|
||||
block_tables = (
|
||||
block_tables.unsqueeze(1)
|
||||
.expand(B, new_length, -1)
|
||||
.reshape(B * new_length, -1)
|
||||
.contiguous()
|
||||
)
|
||||
max_s = max_s + speculative_length
|
||||
|
||||
input_ids = new_input_ids
|
||||
position_ids = new_position_ids
|
||||
else:
|
||||
input_ids = batch.input_ids
|
||||
position_ids = batch.position_ids
|
||||
cu_seqlen_prefill = batch.cu_seqlen_prefill
|
||||
kv_cache = get_cache_manager().kv_cache
|
||||
block_tables = batch.block_tables_tensor
|
||||
slots = batch.slots[batch.slot_indices]
|
||||
input_lengths = batch.input_lengths_tensor
|
||||
max_s = batch.max_seqlen
|
||||
lm_head_indices = batch.prefill_head_indices
|
||||
|
||||
if cu_seqlen_prefill is None and self.max_past() is not None:
|
||||
# In decode, not prefill, we're actually overwriting the KV-cache
|
||||
# in a circular buffer mode.
|
||||
# This makes sure the max_s for the decode pass is correct.
|
||||
max_s = min(self.max_past(), max_s)
|
||||
|
||||
bs = input_ids.shape[0]
|
||||
padded_bs = bs
|
||||
if bs == 3:
|
||||
padded_bs = 4
|
||||
elif 3 < bs <= 8:
|
||||
padded_bs = 8
|
||||
elif bs > 8:
|
||||
padded_bs = (bs + 7) // 8 * 8
|
||||
|
||||
# Try to find an associated cuda graph
|
||||
cuda_graph = self.cuda_graphs.get(padded_bs, None)
|
||||
|
||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
input_lengths=input_lengths,
|
||||
max_s=max_s,
|
||||
prefill_cache_indices=batch.prefill_cache_indices,
|
||||
lm_head_indices=lm_head_indices,
|
||||
)
|
||||
if batch.prefill_cache_indices is not None:
|
||||
batch.prefill_cache_indices = None
|
||||
return logits, speculative_logits
|
||||
|
||||
# Copy inputs to the static inputs of the cuda graph
|
||||
# Static inputs are potentially padded
|
||||
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
|
||||
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
|
||||
cuda_graph["block_tables"][
|
||||
: block_tables.shape[0], : block_tables.shape[1]
|
||||
] = block_tables
|
||||
cuda_graph["slots"].fill_(-1)
|
||||
cuda_graph["slots"][: slots.shape[0]] = slots
|
||||
cuda_graph["input_lengths"].zero_()
|
||||
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
|
||||
|
||||
# Replay the graph
|
||||
cuda_graph["graph"].replay()
|
||||
|
||||
# Slice output to the correct shape
|
||||
speculative_logits = (
|
||||
cuda_graph["speculative_logits"][:bs]
|
||||
if cuda_graph["speculative_logits"] is not None
|
||||
else None
|
||||
)
|
||||
logits = cuda_graph["logits"][:bs]
|
||||
return logits, speculative_logits
|
||||
|
||||
|
||||
class FlashMistral(BaseFlashMistral):
|
||||
def __init__(
|
||||
|
|
|
@ -7,7 +7,6 @@ from opentelemetry import trace
|
|||
from transformers import AutoTokenizer, AutoConfig
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.models.cache_manager import BLOCK_SIZE
|
||||
from text_generation_server.models.flash_mistral import (
|
||||
BaseFlashMistral,
|
||||
set_sliding_window,
|
||||
|
@ -57,9 +56,7 @@ class FlashQwen2(BaseFlashMistral):
|
|||
|
||||
# Set context windows
|
||||
if config.sliding_window is not None:
|
||||
set_sliding_window(
|
||||
config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
|
||||
)
|
||||
set_sliding_window(config.sliding_window)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
|
|
|
@ -6,7 +6,6 @@ from typing import Optional
|
|||
|
||||
from transformers.models.gpt2 import GPT2TokenizerFast
|
||||
|
||||
from text_generation_server.models.cache_manager import BLOCK_SIZE
|
||||
from text_generation_server.models.flash_mistral import (
|
||||
BaseFlashMistral,
|
||||
set_sliding_window,
|
||||
|
@ -56,9 +55,7 @@ class FlashStarcoder2(BaseFlashMistral):
|
|||
|
||||
# Set context windows
|
||||
if config.sliding_window is not None:
|
||||
set_sliding_window(
|
||||
config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
|
||||
)
|
||||
set_sliding_window(config.sliding_window)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
|
|
|
@ -11,13 +11,9 @@ from typing import Optional, Tuple, List, Type, Dict
|
|||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.image_processing_utils import select_best_resolution
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
|
||||
from text_generation_server.models.flash_mistral import (
|
||||
BaseFlashMistral,
|
||||
FlashMistralBatch,
|
||||
)
|
||||
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
|
||||
from text_generation_server.models.cache_manager import (
|
||||
get_cache_manager,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
@ -140,7 +136,7 @@ def load_data_uri(image_uri: str) -> Image.Image:
|
|||
return image
|
||||
|
||||
|
||||
class VlmCausalLMBatch(FlashMistralBatch):
|
||||
class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
pixel_values: Optional[List[torch.Tensor]]
|
||||
pixel_attention_mask: Optional[List[torch.Tensor]]
|
||||
image_sizes: Optional[List[Tuple[int, int]]]
|
||||
|
@ -268,7 +264,7 @@ class VlmCausalLM(BaseFlashMistral):
|
|||
input_ids = batch.input_ids
|
||||
position_ids = batch.position_ids
|
||||
cu_seqlen_prefill = batch.cu_seqlen_prefill
|
||||
kv_cache = get_cache_manager().kv_cache
|
||||
kv_cache = self.kv_cache
|
||||
block_tables = batch.block_tables_tensor
|
||||
slots = batch.slots[batch.slot_indices]
|
||||
input_lengths = batch.input_lengths_tensor
|
||||
|
@ -307,7 +303,7 @@ class VlmCausalLM(BaseFlashMistral):
|
|||
input_ids = batch.input_ids
|
||||
position_ids = batch.position_ids
|
||||
cu_seqlen_prefill = batch.cu_seqlen_prefill
|
||||
kv_cache = get_cache_manager().kv_cache
|
||||
kv_cache = self.kv_cache
|
||||
block_tables = batch.block_tables_tensor
|
||||
slots = batch.slots[batch.slot_indices]
|
||||
input_lengths = batch.input_lengths_tensor
|
||||
|
|
Loading…
Reference in New Issue