Logprobs cost too much.
commit 61f5d187ab (parent a928b5eb06)
@@ -33,7 +33,7 @@ fn compute_optimal(config: Option<&Config>, compute: Option<&ComputeType>) -> Op
     if let (Some(config), Some(compute)) = (config, compute) {
         if let (Some(f16_max_compute), Some(model_compute)) = (compute.f16_flop(), config.flop()) {
             tracing::debug!("Max compute {f16_max_compute} model compute {model_compute}");
-            let optimal_size = (f16_max_compute / model_compute / 2) as usize;
+            let optimal_size = (f16_max_compute / model_compute) as usize;
             if optimal_size > 100 {
                 // Ignore calculations that are too low
                 // Most likely an error
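The Rust hunk above drops the "/ 2" safety factor, doubling the estimated optimal size. Below is a minimal Python sketch of that arithmetic, assuming (this is not stated in the diff) that compute.f16_flop() is the hardware's peak f16 throughput and config.flop() is the model's per-token cost; estimate_optimal_size and the numbers are illustrative, not launcher code.

from typing import Optional

# Illustrative sketch only, not the launcher's Rust implementation.
# Assumes f16_max_compute ~ peak f16 FLOP/s, model_compute ~ FLOPs per token (assumptions).
def estimate_optimal_size(f16_max_compute: float, model_compute: float) -> Optional[int]:
    optimal_size = int(f16_max_compute / model_compute)  # the "/ 2" factor was removed
    if optimal_size > 100:
        return optimal_size
    return None  # ignore results that are too low, most likely an error

# Made-up numbers: ~1e15 peak f16 FLOP/s vs ~1.4e10 FLOPs per token for a ~7B model.
print(estimate_optimal_size(1e15, 1.4e10))  # 71428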
@@ -57,6 +57,7 @@ from text_generation_server.models.globals import (
     ATTENTION,
     BLOCK_SIZE,
     CUDA_GRAPHS,
+    REQUEST_LOGPROBS,
     TGI_WIGGLE_ROOM,
     get_adapter_to_index,
 )
@@ -292,6 +293,10 @@ class FlashCausalLMBatch(Batch):
         for i, (r, tokenized_input) in enumerate(
             zip(pb.requests, batch_tokenized_inputs)
         ):
+            ### XXX: This consumes so much memory on long requests
+            ### Deactivating it by default seems like the best course.
+            if not REQUEST_LOGPROBS:
+                r.prefill_logprobs = False
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
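For context on the XXX note above: returning prefill logprobs means materializing a logit row for every prompt token rather than only the last position, so memory grows with prompt length times vocabulary size. A rough, illustrative sketch with made-up sizes (actual TGI tensors and dtypes may differ; prefill_logits_bytes is a hypothetical helper):

# Illustrative only: memory needed to hold logits for every prompt token.
def prefill_logits_bytes(prompt_len: int, vocab_size: int, dtype_bytes: int = 2) -> int:
    # one [prompt_len, vocab_size] tensor, fp16 -> 2 bytes per element
    return prompt_len * vocab_size * dtype_bytes

# 100k-token prompt with a 128k vocabulary in fp16: ~25.6 GB for the logits alone.
print(prefill_logits_bytes(100_000, 128_000) / 1e9)  # 25.6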
@@ -5,6 +5,7 @@ from typing import Dict, Optional
 
 from text_generation_server.utils.log import log_master
 
+REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
 ATTENTION = os.environ["ATTENTION"]
 # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
 PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
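Since REQUEST_LOGPROBS now defaults to off, prefill logprobs have to be re-enabled through the environment. A standalone sketch of how that truthiness check behaves (flag_from_env is a hypothetical helper, but the expression mirrors the line added above):

import os

def flag_from_env(name: str, default: str = "0") -> bool:
    # "1" or "true" (any case) enables the flag; anything else leaves it off
    return os.getenv(name, default).lower() in {"1", "true"}

os.environ["REQUEST_LOGPROBS"] = "True"
print(flag_from_env("REQUEST_LOGPROBS"))  # True
os.environ.pop("REQUEST_LOGPROBS")
print(flag_from_env("REQUEST_LOGPROBS"))  # False: logprobs stay disabled by default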