This seems to be working.
This commit is contained in:
parent
f5182c188c
commit
26e5037de4
|
@ -158,34 +158,40 @@ impl Allocator for RadixAllocator {
|
|||
if let Some(prefill_tokens) = allocation.prefill_tokens {
|
||||
let prefill_tokens = prefill_tokens.as_slice();
|
||||
|
||||
assert_eq!(prefill_tokens.len() % self.block_size as usize, 0);
|
||||
// If there are prefill tokens that did not come from the cache,
|
||||
// add them to the cache.
|
||||
if prefill_tokens.len() > allocation.cached_prefix_len {
|
||||
let prefix_len = self
|
||||
.cache_blocks
|
||||
.insert(
|
||||
prefill_tokens,
|
||||
&blocks[..prefill_tokens.len() / self.block_size as usize],
|
||||
)
|
||||
// Unwrap, failing is a programming error.
|
||||
.expect("Failed to store prefill tokens");
|
||||
let aligned =
|
||||
(prefill_tokens.len() / self.block_size as usize) * self.block_size as usize;
|
||||
if aligned > 0 {
|
||||
let prefix_len = self
|
||||
.cache_blocks
|
||||
.insert(
|
||||
&prefill_tokens[..aligned],
|
||||
&blocks[..aligned / self.block_size as usize],
|
||||
)
|
||||
// Unwrap, failing is a programming error.
|
||||
.expect("Failed to store prefill tokens");
|
||||
|
||||
// We can have a prefill with the following structure:
|
||||
//
|
||||
// |---| From the prefix cache.
|
||||
// A B C D E F G
|
||||
//|--------| Found in the trie during insertion.
|
||||
//
|
||||
// This means that while processing this request there was a
|
||||
// partially overlapping request that had A..=E in its
|
||||
// prefill. In this case we need to free the blocks D E.
|
||||
self.free_blocks
|
||||
.extend(&blocks[allocation.cached_prefix_len..prefix_len]);
|
||||
// We can have a prefill with the following structure:
|
||||
//
|
||||
// |---| From the prefix cache.
|
||||
// A B C D E F G
|
||||
//|--------| Found in the trie during insertion.
|
||||
//
|
||||
// This means that while processing this request there was a
|
||||
// partially overlapping request that had A..=E in its
|
||||
// prefill. In this case we need to free the blocks D E.
|
||||
self.free_blocks.extend(
|
||||
&blocks[allocation.cached_prefix_len / self.block_size as usize
|
||||
..prefix_len / self.block_size as usize],
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Free non-prefill blocks.
|
||||
self.free_blocks.extend(&blocks[prefill_tokens.len()..]);
|
||||
self.free_blocks
|
||||
.extend(&blocks[prefill_tokens.len() / self.block_size as usize..]);
|
||||
} else {
|
||||
self.free_blocks.extend(blocks);
|
||||
}
|
||||
|
@ -605,6 +611,24 @@ mod tests {
|
|||
assert_eq!(allocation.prefix_len, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allocator_block_size_non_aligned() {
|
||||
let mut cache = RadixAllocator::new(2, 12, None);
|
||||
let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
||||
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
|
||||
assert_eq!(allocation.prefix_len, 0);
|
||||
cache.free(
|
||||
allocation.blocks[..allocation.blocks.len() - 1].to_vec(),
|
||||
allocation.allocation_id,
|
||||
);
|
||||
|
||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![8, 9, 6, 7]);
|
||||
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 12, 13, 14, 15]);
|
||||
assert_eq!(allocation.prefix_len, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allocator_reuses_prefixes() {
|
||||
let mut cache = RadixAllocator::new(1, 12, None);
|
||||
|
|
|
@ -83,7 +83,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
|||
"Forcing flash decoding because model {} requires it",
|
||||
config.model_type.as_ref().unwrap()
|
||||
);
|
||||
prefix_caching = Some("0".to_string());
|
||||
prefix_caching = Some("1".to_string());
|
||||
attention = Some("flashdecoding".to_string());
|
||||
}
|
||||
}
|
||||
|
@ -93,7 +93,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
|||
_ => {
|
||||
if prefix_caching.is_none() {
|
||||
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
|
||||
prefix_caching = Some("0".to_string());
|
||||
prefix_caching = Some("1".to_string());
|
||||
attention = Some("flashdecoding".to_string());
|
||||
}
|
||||
}
|
||||
|
@ -1000,7 +1000,7 @@ impl TryFrom<&[u8]> for PythonLogMessage {
|
|||
}
|
||||
|
||||
fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
|
||||
let mut buffer = vec![0u8; 4096];
|
||||
let mut buffer = vec![0u8; 8 * 4096];
|
||||
let mut stdout = std::io::stdout();
|
||||
loop {
|
||||
let n = bufread.read(&mut buffer);
|
||||
|
|
|
@ -265,13 +265,10 @@ class FlashCausalLMBatch(Batch):
|
|||
|
||||
orig_input_length = len(tokenized_input)
|
||||
|
||||
if ATTENTION == "flashinfer":
|
||||
prefix_len = r.prefix_len
|
||||
if prefix_len == orig_input_length:
|
||||
assert prefix_len > 0
|
||||
prefix_len -= 1
|
||||
else:
|
||||
prefix_len = 0
|
||||
prefix_len = r.prefix_len
|
||||
if prefix_len == orig_input_length:
|
||||
assert prefix_len > 0
|
||||
prefix_len -= 1
|
||||
|
||||
prefix_ids.append(tokenized_input[:prefix_len])
|
||||
tokenized_input = tokenized_input[prefix_len:]
|
||||
|
|
|
@ -14,8 +14,8 @@ assert (
|
|||
), f"Attention is not valid {ATTENTION}, expected {_expected}"
|
||||
log_master(logger.info, f"Using Attention = {ATTENTION}")
|
||||
|
||||
if PREFIX_CACHING and ATTENTION != "flashinfer":
|
||||
raise RuntimeError("Prefix caching is only supported with flashinfer")
|
||||
# if PREFIX_CACHING and ATTENTION != "flashinfer":
|
||||
# raise RuntimeError("Prefix caching is only supported with flashinfer")
|
||||
|
||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||
|
||||
|
|
Loading…
Reference in New Issue