Sync allocator interfaces
This commit is contained in:
parent
48b21eab7a
commit
27ef5aa029
|
@ -2,6 +2,7 @@ use std::{
|
||||||
cmp::min,
|
cmp::min,
|
||||||
collections::{hash_map::Entry, BTreeSet, HashMap},
|
collections::{hash_map::Entry, BTreeSet, HashMap},
|
||||||
hash::{DefaultHasher, Hash, Hasher},
|
hash::{DefaultHasher, Hash, Hasher},
|
||||||
|
sync::Arc,
|
||||||
};
|
};
|
||||||
use tokio::sync::{mpsc, oneshot};
|
use tokio::sync::{mpsc, oneshot};
|
||||||
|
|
||||||
|
@ -9,12 +10,14 @@ use tokio::sync::{mpsc, oneshot};
|
||||||
pub(crate) struct BlockAllocation {
|
pub(crate) struct BlockAllocation {
|
||||||
pub blocks: Vec<u32>,
|
pub blocks: Vec<u32>,
|
||||||
pub slots: Vec<u32>,
|
pub slots: Vec<u32>,
|
||||||
|
pub allocation_id: u64,
|
||||||
block_allocator: BlockAllocator,
|
block_allocator: BlockAllocator,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for BlockAllocation {
|
impl Drop for BlockAllocation {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
self.block_allocator.free(self.blocks.clone())
|
self.block_allocator
|
||||||
|
.free(self.blocks.clone(), self.allocation_id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -46,11 +49,16 @@ impl BlockAllocator {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
|
pub(crate) async fn allocate(
|
||||||
|
&self,
|
||||||
|
tokens: u32,
|
||||||
|
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||||
|
) -> Option<BlockAllocation> {
|
||||||
let (response_sender, response_receiver) = oneshot::channel();
|
let (response_sender, response_receiver) = oneshot::channel();
|
||||||
self.block_allocator
|
self.block_allocator
|
||||||
.send(BlockAllocatorCommand::Allocate {
|
.send(BlockAllocatorCommand::Allocate {
|
||||||
tokens,
|
tokens,
|
||||||
|
prefill_tokens,
|
||||||
response_sender,
|
response_sender,
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
@ -58,16 +66,20 @@ impl BlockAllocator {
|
||||||
response_receiver
|
response_receiver
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.map(|(blocks, slots)| BlockAllocation {
|
.map(|(blocks, slots, allocation_id)| BlockAllocation {
|
||||||
blocks,
|
blocks,
|
||||||
slots,
|
slots,
|
||||||
|
allocation_id,
|
||||||
block_allocator: self.clone(),
|
block_allocator: self.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn free(&self, blocks: Vec<u32>) {
|
pub(crate) fn free(&self, blocks: Vec<u32>, allocation_id: u64) {
|
||||||
self.block_allocator
|
self.block_allocator
|
||||||
.send(BlockAllocatorCommand::Free { blocks })
|
.send(BlockAllocatorCommand::Free {
|
||||||
|
allocation_id,
|
||||||
|
blocks,
|
||||||
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -81,21 +93,32 @@ async fn block_allocator_task(
|
||||||
let mut allocator = SimpleAllocator::new(blocks, block_size, window_size);
|
let mut allocator = SimpleAllocator::new(blocks, block_size, window_size);
|
||||||
while let Some(cmd) = receiver.recv().await {
|
while let Some(cmd) = receiver.recv().await {
|
||||||
match cmd {
|
match cmd {
|
||||||
BlockAllocatorCommand::Free { blocks } => allocator.free(blocks),
|
BlockAllocatorCommand::Free {
|
||||||
|
blocks,
|
||||||
|
allocation_id,
|
||||||
|
} => allocator.free(blocks, allocation_id),
|
||||||
BlockAllocatorCommand::Allocate {
|
BlockAllocatorCommand::Allocate {
|
||||||
tokens,
|
tokens,
|
||||||
|
prefill_tokens,
|
||||||
response_sender,
|
response_sender,
|
||||||
} => {
|
} => {
|
||||||
response_sender.send(allocator.allocate(tokens)).unwrap();
|
let prefill_tokens_slice = prefill_tokens.as_ref().map(|p| p.as_slice());
|
||||||
|
response_sender
|
||||||
|
.send(allocator.allocate(tokens, prefill_tokens_slice))
|
||||||
|
.unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait Allocator {
|
pub trait Allocator {
|
||||||
fn allocate(&mut self, tokens: u32) -> Option<(Vec<u32>, Vec<u32>)>;
|
fn allocate(
|
||||||
|
&mut self,
|
||||||
|
tokens: u32,
|
||||||
|
prefill_tokens: Option<&[u32]>,
|
||||||
|
) -> Option<(Vec<u32>, Vec<u32>, u64)>;
|
||||||
|
|
||||||
fn free(&mut self, blocks: Vec<u32>);
|
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct SimpleAllocator {
|
pub struct SimpleAllocator {
|
||||||
|
@ -116,7 +139,11 @@ impl SimpleAllocator {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Allocator for SimpleAllocator {
|
impl Allocator for SimpleAllocator {
|
||||||
fn allocate(&mut self, tokens: u32) -> Option<(Vec<u32>, Vec<u32>)> {
|
fn allocate(
|
||||||
|
&mut self,
|
||||||
|
tokens: u32,
|
||||||
|
_prefill_tokens: Option<&[u32]>,
|
||||||
|
) -> Option<(Vec<u32>, Vec<u32>, u64)> {
|
||||||
// Apply window size
|
// Apply window size
|
||||||
let (required_blocks, repeats) = {
|
let (required_blocks, repeats) = {
|
||||||
let (tokens, repeats) = match self.window_size {
|
let (tokens, repeats) = match self.window_size {
|
||||||
|
@ -150,11 +177,11 @@ impl Allocator for SimpleAllocator {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Some((blocks, slots))
|
Some((blocks, slots, 0))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn free(&mut self, blocks: Vec<u32>) {
|
fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
|
||||||
self.free_blocks.extend(blocks)
|
self.free_blocks.extend(blocks)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -163,20 +190,15 @@ impl Allocator for SimpleAllocator {
|
||||||
enum BlockAllocatorCommand {
|
enum BlockAllocatorCommand {
|
||||||
Free {
|
Free {
|
||||||
blocks: Vec<u32>,
|
blocks: Vec<u32>,
|
||||||
|
allocation_id: u64,
|
||||||
},
|
},
|
||||||
Allocate {
|
Allocate {
|
||||||
tokens: u32,
|
tokens: u32,
|
||||||
response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
|
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||||
|
response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>, u64)>>,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
|
||||||
pub struct BlockAllocationWithCache {
|
|
||||||
pub blocks: Vec<u32>,
|
|
||||||
pub slots: Vec<u32>,
|
|
||||||
pub allocation: u64,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct Allocation {
|
struct Allocation {
|
||||||
cache_prefixes: Vec<u64>,
|
cache_prefixes: Vec<u64>,
|
||||||
|
@ -235,87 +257,6 @@ impl PrefixCacheAllocator {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn alloc(
|
|
||||||
&mut self,
|
|
||||||
n_tokens: usize,
|
|
||||||
prefill_tokens: &[u32],
|
|
||||||
) -> Option<BlockAllocationWithCache> {
|
|
||||||
let mut hasher = DefaultHasher::new();
|
|
||||||
let mut prefix_cache_blocks = Vec::new();
|
|
||||||
|
|
||||||
// Find hashes for all block_sized prefill chunks.
|
|
||||||
let mut prefix_hashes = Vec::new();
|
|
||||||
for prefill_chunk in prefill_tokens.chunks(self.block_size) {
|
|
||||||
if prefill_chunk.len() < self.block_size {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
prefill_chunk.hash(&mut hasher);
|
|
||||||
prefix_hashes.push(hasher.finish());
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut n_from_cache = 0;
|
|
||||||
for prefix_hash in prefix_hashes.iter() {
|
|
||||||
let block_id = match self.cache_blocks.get(prefix_hash) {
|
|
||||||
Some(state) => state.block_id,
|
|
||||||
None => break,
|
|
||||||
};
|
|
||||||
|
|
||||||
// We have to acquire the prefixes blocks, even if the allocation fails
|
|
||||||
// later, otherwise the allocation below could garbage collect the
|
|
||||||
// prefix blocks.
|
|
||||||
self.incref_prefix(*prefix_hash);
|
|
||||||
prefix_cache_blocks.push(block_id);
|
|
||||||
n_from_cache += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
let new_prefixes = prefix_hashes.split_off(n_from_cache);
|
|
||||||
let cache_prefixes = prefix_hashes;
|
|
||||||
|
|
||||||
// Get tokens for the remaining prefill and decode.
|
|
||||||
let blocks = match self.alloc_or_reclaim(n_tokens - (n_from_cache * self.block_size)) {
|
|
||||||
Some(blocks) => blocks,
|
|
||||||
None => {
|
|
||||||
// If the allocation fails, we have relinquish our use of the
|
|
||||||
// prefix cache blocks. Maybe we can do this using `Drop`?
|
|
||||||
for prefix_hash in cache_prefixes {
|
|
||||||
self.decref_prefix(prefix_hash);
|
|
||||||
}
|
|
||||||
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
prefix_cache_blocks.extend(blocks);
|
|
||||||
|
|
||||||
let mut slots = Vec::with_capacity(n_tokens);
|
|
||||||
for block_id in prefix_cache_blocks.iter() {
|
|
||||||
for s in
|
|
||||||
(*block_id * self.block_size as u32)..((*block_id + 1) * self.block_size as u32)
|
|
||||||
{
|
|
||||||
slots.push(s);
|
|
||||||
if slots.len() == n_tokens {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let allocation = Allocation {
|
|
||||||
cache_prefixes,
|
|
||||||
new_prefixes,
|
|
||||||
};
|
|
||||||
|
|
||||||
let allocation_id = self.time;
|
|
||||||
self.time += 1;
|
|
||||||
self.allocations.insert(allocation_id, allocation);
|
|
||||||
|
|
||||||
Some(BlockAllocationWithCache {
|
|
||||||
blocks: prefix_cache_blocks,
|
|
||||||
slots,
|
|
||||||
allocation: allocation_id,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn free_prefix_block(&mut self, prefix_hash: u64) {
|
fn free_prefix_block(&mut self, prefix_hash: u64) {
|
||||||
let state = self
|
let state = self
|
||||||
.cache_blocks
|
.cache_blocks
|
||||||
|
@ -368,8 +309,92 @@ impl PrefixCacheAllocator {
|
||||||
.split_off(self.free_blocks.len() - n_blocks_needed),
|
.split_off(self.free_blocks.len() - n_blocks_needed),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn free(&mut self, blocks: &[u32], allocation: u64) {
|
impl Allocator for PrefixCacheAllocator {
|
||||||
|
fn allocate(
|
||||||
|
&mut self,
|
||||||
|
n_tokens: u32,
|
||||||
|
prefill_tokens: Option<&[u32]>,
|
||||||
|
) -> Option<(Vec<u32>, Vec<u32>, u64)> {
|
||||||
|
let mut hasher = DefaultHasher::new();
|
||||||
|
let mut prefix_cache_blocks = Vec::new();
|
||||||
|
|
||||||
|
// Find hashes for all block_sized prefill chunks.
|
||||||
|
let mut prefix_hashes = Vec::new();
|
||||||
|
let mut n_from_cache = 0;
|
||||||
|
|
||||||
|
// If we don't have a fast tokenizer, can't do prefix caching.
|
||||||
|
if let Some(prefill_tokens) = prefill_tokens {
|
||||||
|
for prefill_chunk in prefill_tokens.chunks(self.block_size) {
|
||||||
|
if prefill_chunk.len() < self.block_size {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
prefill_chunk.hash(&mut hasher);
|
||||||
|
prefix_hashes.push(hasher.finish());
|
||||||
|
}
|
||||||
|
|
||||||
|
for prefix_hash in prefix_hashes.iter() {
|
||||||
|
let block_id = match self.cache_blocks.get(prefix_hash) {
|
||||||
|
Some(state) => state.block_id,
|
||||||
|
None => break,
|
||||||
|
};
|
||||||
|
|
||||||
|
// We have to acquire the prefixes blocks, even if the allocation fails
|
||||||
|
// later, otherwise the allocation below could garbage collect the
|
||||||
|
// prefix blocks.
|
||||||
|
self.incref_prefix(*prefix_hash);
|
||||||
|
prefix_cache_blocks.push(block_id);
|
||||||
|
n_from_cache += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let new_prefixes = prefix_hashes.split_off(n_from_cache);
|
||||||
|
let cache_prefixes = prefix_hashes;
|
||||||
|
|
||||||
|
// Get tokens for the remaining prefill and decode.
|
||||||
|
let blocks =
|
||||||
|
match self.alloc_or_reclaim(n_tokens as usize - (n_from_cache * self.block_size)) {
|
||||||
|
Some(blocks) => blocks,
|
||||||
|
None => {
|
||||||
|
// If the allocation fails, we have relinquish our use of the
|
||||||
|
// prefix cache blocks. Maybe we can do this using `Drop`?
|
||||||
|
for prefix_hash in cache_prefixes {
|
||||||
|
self.decref_prefix(prefix_hash);
|
||||||
|
}
|
||||||
|
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
prefix_cache_blocks.extend(blocks);
|
||||||
|
|
||||||
|
let mut slots = Vec::with_capacity(n_tokens as usize);
|
||||||
|
for block_id in prefix_cache_blocks.iter() {
|
||||||
|
for s in
|
||||||
|
(*block_id * self.block_size as u32)..((*block_id + 1) * self.block_size as u32)
|
||||||
|
{
|
||||||
|
slots.push(s);
|
||||||
|
if slots.len() == n_tokens as usize {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let allocation = Allocation {
|
||||||
|
cache_prefixes,
|
||||||
|
new_prefixes,
|
||||||
|
};
|
||||||
|
|
||||||
|
let allocation_id = self.time;
|
||||||
|
self.time += 1;
|
||||||
|
self.allocations.insert(allocation_id, allocation);
|
||||||
|
|
||||||
|
Some((prefix_cache_blocks, slots, allocation_id))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn free(&mut self, blocks: Vec<u32>, allocation: u64) {
|
||||||
let allocation = match self.allocations.remove(&allocation) {
|
let allocation = match self.allocations.remove(&allocation) {
|
||||||
Some(allocation) => allocation,
|
Some(allocation) => allocation,
|
||||||
None => unreachable!("Tried to free an unknown allocation."),
|
None => unreachable!("Tried to free an unknown allocation."),
|
||||||
|
@ -433,8 +458,6 @@ impl PrefixCacheAllocator {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::infer::v3::block_allocator::BlockAllocationWithCache;
|
|
||||||
|
|
||||||
use super::PrefixCacheAllocator;
|
use super::PrefixCacheAllocator;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
@ -298,7 +298,10 @@ impl State {
|
||||||
+ self.speculate
|
+ self.speculate
|
||||||
- 1;
|
- 1;
|
||||||
|
|
||||||
match block_allocator.allocate(tokens).await {
|
match block_allocator
|
||||||
|
.allocate(tokens, entry.request.input_ids.clone())
|
||||||
|
.await
|
||||||
|
{
|
||||||
None => {
|
None => {
|
||||||
// Entry is over budget
|
// Entry is over budget
|
||||||
// Add it back to the front
|
// Add it back to the front
|
||||||
|
|
|
@ -11,6 +11,7 @@ use rand::{thread_rng, Rng};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::io::Cursor;
|
use std::io::Cursor;
|
||||||
use std::iter;
|
use std::iter;
|
||||||
|
use std::sync::Arc;
|
||||||
use text_generation_client::{Chunk, Image, InputChunk};
|
use text_generation_client::{Chunk, Image, InputChunk};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokenizers::tokenizer::Tokenizer;
|
use tokenizers::tokenizer::Tokenizer;
|
||||||
|
@ -122,7 +123,7 @@ impl Validation {
|
||||||
inputs: String,
|
inputs: String,
|
||||||
truncate: Option<usize>,
|
truncate: Option<usize>,
|
||||||
max_new_tokens: Option<u32>,
|
max_new_tokens: Option<u32>,
|
||||||
) -> Result<(Vec<InputChunk>, usize, u32), ValidationError> {
|
) -> Result<(Vec<InputChunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
|
||||||
// If we have a fast tokenizer
|
// If we have a fast tokenizer
|
||||||
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
|
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
|
||||||
// Create response channel
|
// Create response channel
|
||||||
|
@ -157,8 +158,10 @@ impl Validation {
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let input_ids = encoding.get_ids()[..input_length].to_owned();
|
||||||
|
|
||||||
metrics::histogram!("tgi_request_input_length").record(input_length as f64);
|
metrics::histogram!("tgi_request_input_length").record(input_length as f64);
|
||||||
Ok((inputs, input_length, max_new_tokens))
|
Ok((inputs, Some(input_ids), input_length, max_new_tokens))
|
||||||
}
|
}
|
||||||
// Return inputs without validation
|
// Return inputs without validation
|
||||||
else {
|
else {
|
||||||
|
@ -183,6 +186,7 @@ impl Validation {
|
||||||
|
|
||||||
Ok((
|
Ok((
|
||||||
vec![Chunk::Text(inputs).into()],
|
vec![Chunk::Text(inputs).into()],
|
||||||
|
None,
|
||||||
input_length,
|
input_length,
|
||||||
max_new_tokens,
|
max_new_tokens,
|
||||||
))
|
))
|
||||||
|
@ -319,7 +323,7 @@ impl Validation {
|
||||||
.unwrap_or(Ok(None))?;
|
.unwrap_or(Ok(None))?;
|
||||||
|
|
||||||
// Validate inputs
|
// Validate inputs
|
||||||
let (inputs, input_length, max_new_tokens) = self
|
let (inputs, input_ids, input_length, max_new_tokens) = self
|
||||||
.validate_input(request.inputs, truncate, max_new_tokens)
|
.validate_input(request.inputs, truncate, max_new_tokens)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
@ -388,6 +392,7 @@ impl Validation {
|
||||||
|
|
||||||
Ok(ValidGenerateRequest {
|
Ok(ValidGenerateRequest {
|
||||||
inputs,
|
inputs,
|
||||||
|
input_ids: input_ids.map(Arc::new),
|
||||||
decoder_input_details,
|
decoder_input_details,
|
||||||
input_length: input_length as u32,
|
input_length: input_length as u32,
|
||||||
truncate: truncate.unwrap_or(self.max_input_length) as u32,
|
truncate: truncate.unwrap_or(self.max_input_length) as u32,
|
||||||
|
@ -671,6 +676,7 @@ pub(crate) struct ValidStoppingParameters {
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub(crate) struct ValidGenerateRequest {
|
pub(crate) struct ValidGenerateRequest {
|
||||||
pub inputs: Vec<InputChunk>,
|
pub inputs: Vec<InputChunk>,
|
||||||
|
pub input_ids: Option<Arc<Vec<u32>>>,
|
||||||
pub input_length: u32,
|
pub input_length: u32,
|
||||||
pub truncate: u32,
|
pub truncate: u32,
|
||||||
pub decoder_input_details: bool,
|
pub decoder_input_details: bool,
|
||||||
|
|
Loading…
Reference in New Issue