Logprobs cost too much.
parent a928b5eb06
commit 61f5d187ab
@@ -33,7 +33,7 @@ fn compute_optimal(config: Option<&Config>, compute: Option<&ComputeType>) -> Op
     if let (Some(config), Some(compute)) = (config, compute) {
         if let (Some(f16_max_compute), Some(model_compute)) = (compute.f16_flop(), config.flop()) {
             tracing::debug!("MAx compute {f16_max_compute} model compute {model_compute}");
-            let optimal_size = (f16_max_compute / model_compute / 2) as usize;
+            let optimal_size = (f16_max_compute / model_compute) as usize;
             if optimal_size > 100 {
                 // Ignore calculations that's too low
                 // Most likely an error
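In effect, this heuristic divides the hardware's peak f16 throughput by the model's per-token compute cost to size a token budget, and the change drops the extra halving, doubling the computed budget. A minimal Python sketch of that reading (the function shape follows the diff; the FLOP figures are illustrative assumptions, not values from the commit):

def compute_optimal(f16_max_compute: float, model_compute: float) -> int | None:
    """Token budget = peak f16 FLOP/s divided by per-token model FLOPs."""
    optimal_size = int(f16_max_compute / model_compute)  # the "/ 2" factor is gone
    # Ignore implausibly low results; most likely an error in the inputs.
    return optimal_size if optimal_size > 100 else None

# Illustrative numbers: ~989 TFLOP/s peak f16 (H100-class GPU),
# ~14 GFLOPs per token for a 7B-parameter model (roughly 2 * params).
print(compute_optimal(989e12, 14e9))  # -> 70642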
@@ -57,6 +57,7 @@ from text_generation_server.models.globals import (
     ATTENTION,
     BLOCK_SIZE,
     CUDA_GRAPHS,
+    REQUEST_LOGPROBS,
     TGI_WIGGLE_ROOM,
     get_adapter_to_index,
 )
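This import pulls the new REQUEST_LOGPROBS flag (defined in the globals module, last hunk below) into the flash-attention batching code, where the next hunk uses it to gate prefill logprobs.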
@@ -292,6 +293,10 @@ class FlashCausalLMBatch(Batch):
         for i, (r, tokenized_input) in enumerate(
             zip(pb.requests, batch_tokenized_inputs)
         ):
+            ### XXX: This consumes so much memory on long requests
+            ### Deactivating it by default seems like the best course.
+            if not REQUEST_LOGPROBS:
+                r.prefill_logprobs = False
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
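The cost being avoided: returning prefill logprobs forces the server to materialize a full-vocabulary logit row for every prompt token instead of only the last one. A back-of-envelope sketch of that footprint (the prompt length, vocabulary size, and fp32 dtype below are assumptions for illustration, not figures from the commit):

# Rough memory needed just for the prompt logits when prefill logprobs are on.
prompt_len = 32_000       # hypothetical long request
vocab_size = 128_256      # e.g. a Llama-3-style vocabulary
bytes_per_logit = 4       # assuming fp32 logits

gib = prompt_len * vocab_size * bytes_per_logit / 2**30
print(f"{gib:.1f} GiB")   # -> 15.3 GiB of logits for a single request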
@@ -5,6 +5,7 @@ from typing import Dict, Optional

 from text_generation_server.utils.log import log_master

+REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
 ATTENTION = os.environ["ATTENTION"]
 # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
 PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
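With the default now off, callers who still need prompt logprobs opt back in through the environment. A small check of the parsing rule above (the helper name is mine, not from the repo):

import os

def request_logprobs_enabled() -> bool:
    # Mirrors the new globals.py line: only "1" or "true"
    # (case-insensitive) turn prefill logprobs back on.
    return os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}

os.environ["REQUEST_LOGPROBS"] = "TRUE"
assert request_logprobs_enabled()        # "TRUE".lower() == "true"
os.environ["REQUEST_LOGPROBS"] = "yes"
assert not request_logprobs_enabled()    # anything else keeps logprobs off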