Logprobs cost too much.

Nicolas Patry 2024-11-10 07:00:22 +01:00
parent a928b5eb06
commit 61f5d187ab
3 changed files with 7 additions and 1 deletion

@@ -33,7 +33,7 @@ fn compute_optimal(config: Option<&Config>, compute: Option<&ComputeType>) -> Op
     if let (Some(config), Some(compute)) = (config, compute) {
         if let (Some(f16_max_compute), Some(model_compute)) = (compute.f16_flop(), config.flop()) {
             tracing::debug!("MAx compute {f16_max_compute} model compute {model_compute}");
-            let optimal_size = (f16_max_compute / model_compute / 2) as usize;
+            let optimal_size = (f16_max_compute / model_compute) as usize;
             if optimal_size > 100 {
                 // Ignore calculations that's too low
                 // Most likely an error
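
The dropped "/ 2" doubles the computed optimal prefill size. A minimal Python sketch of the heuristic, under the assumption that compute.f16_flop() returns the accelerator's peak FP16 throughput in FLOP/s and config.flop() the model's per-token cost in FLOPs (units inferred from the code, not confirmed by the commit):

    def compute_optimal(f16_max_compute, model_compute):
        # Peak accelerator throughput divided by per-token model cost;
        # the commit removes the extra "/ 2" safety factor, doubling the result.
        optimal_size = int(f16_max_compute / model_compute)
        if optimal_size > 100:
            return optimal_size
        # Ignore suspiciously small results; most likely a bad estimate.
        return None

    # Hypothetical numbers: ~990e12 FLOP/s peak FP16, ~14e9 FLOPs/token
    # (a ~7B-parameter model at roughly 2 * params FLOPs per token).
    print(compute_optimal(990e12, 14e9))  # -> 70714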

@@ -57,6 +57,7 @@ from text_generation_server.models.globals import (
     ATTENTION,
     BLOCK_SIZE,
     CUDA_GRAPHS,
+    REQUEST_LOGPROBS,
     TGI_WIGGLE_ROOM,
     get_adapter_to_index,
 )
@@ -292,6 +293,10 @@ class FlashCausalLMBatch(Batch):
         for i, (r, tokenized_input) in enumerate(
             zip(pb.requests, batch_tokenized_inputs)
         ):
+            ### XXX: This consumes so much memory on long requests
+            ### Deactivating it by default seems like the best course.
+            if not REQUEST_LOGPROBS:
+                r.prefill_logprobs = False
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
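
The "consumes so much memory" comment is about logits: returning prefill logprobs forces the server to materialize one full-vocabulary logit row per prompt token instead of only the last position. A rough back-of-envelope sketch (the vocabulary size and fp16 storage below are assumptions, not taken from the commit):

    def prefill_logits_bytes(prompt_tokens, vocab_size=128_000, dtype_bytes=2):
        # One logit row of vocab_size entries per prompt token.
        return prompt_tokens * vocab_size * dtype_bytes

    # 32k-token prompt, 128k vocab, fp16: 7.8125 GiB of logits,
    # versus ~0.24 MiB when only the final position is kept.
    print(prefill_logits_bytes(32_768) / 2**30)  # -> 7.8125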

@@ -5,6 +5,7 @@ from typing import Dict, Optional
 from text_generation_server.utils.log import log_master
 
+REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
 ATTENTION = os.environ["ATTENTION"]
 # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
 PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
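
With this default in place, prefill logprobs stay disabled unless the server is started with REQUEST_LOGPROBS=1 or REQUEST_LOGPROBS=true (case-insensitive, per the .lower() call). A small sketch of the parsing rule the new line implements:

    def request_logprobs_enabled(raw):
        # Mirrors: os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
        return (raw if raw is not None else "0").lower() in {"1", "true"}

    assert request_logprobs_enabled(None) is False   # unset: logprobs stay off
    assert request_logprobs_enabled("1") is True
    assert request_logprobs_enabled("TRUE") is True  # case-insensitive
    assert request_logprobs_enabled("yes") is False  # unrecognized values stay off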