Fix disabling prefix caching - Fix windowing checks.

This commit is contained in:
Nicolas Patry 2024-08-29 11:34:13 +02:00
parent bef2f6bdaa
commit fc7ea202c2
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
3 changed files with 11 additions and 9 deletions

View File

@ -104,6 +104,9 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
attention = Some("flashdecoding".to_string());
}
if prefix_caching.is_none() {
prefix_caching = Some("0".to_string());
}
}
}
}

View File

@ -1748,7 +1748,7 @@ pub async fn run(
let mut tokenizer = Tokenizer::from_file(filename).ok();
if let Some(tokenizer) = &mut tokenizer {
if let Some(class) = &tokenizer_config.tokenizer_class {
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" || class == "CohereTokenizerFast"{
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
tokenizer.with_post_processor(post_processor);

View File

@ -497,12 +497,11 @@ def get_model(
else -1
)
should_use_sliding_window = (
sliding_window is not None and sliding_window != -1 and SUPPORTS_WINDOWING
use_sliding_window = sliding_window is not None and sliding_window != -1
needs_sliding_window = (
max_input_tokens is not None and max_input_tokens > sliding_window
)
if should_use_sliding_window:
if max_input_tokens is not None and max_input_tokens > sliding_window:
if use_sliding_window and needs_sliding_window and not SUPPORTS_WINDOWING:
raise ValueError(
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
)