Updating logic + non flash.

Nicolas Patry 2024-10-24 09:58:05 +02:00
parent 10534511ea
commit 6994fa12f8
No known key found for this signature in database
GPG Key ID: D2920555C90F704C
3 changed files with 43 additions and 13 deletions

launcher/src/main.rs

@@ -137,7 +137,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
#[derive(Deserialize)]
struct RawConfig {
max_position_embeddings: Option<usize>,
n_positions: Option<usize>,
model_type: Option<String>,
max_seq_len: Option<usize>,
quantization_config: Option<QuantizationConfig>,
n_embd: Option<usize>,
hidden_size: Option<usize>,
@@ -157,6 +160,7 @@ struct VisionConfig {}
#[derive(Deserialize)]
struct Config {
max_position_embeddings: Option<usize>,
quantize: Option<Quantization>,
head_dim: Option<usize>,
model_type: Option<String>,
@@ -166,6 +170,10 @@ struct Config {
impl From<RawConfig> for Config {
fn from(other: RawConfig) -> Self {
let max_position_embeddings = other
.max_position_embeddings
.or(other.max_seq_len)
.or(other.n_positions);
let quantize = other.quantization_config.and_then(|q| q.quant_method);
let head_dim = other.head_dim.or_else(|| {
match (other.hidden_size, other.n_embd, other.num_attention_heads) {
@@ -187,6 +195,7 @@ impl From<RawConfig> for Config {
let vision_config = other.vision_config;
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
Config {
max_position_embeddings,
quantize,
head_dim,
model_type,
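
For reference, a minimal, self-contained Rust sketch of the fallback chain introduced in the hunks above (not part of the commit; it assumes serde with the derive feature plus serde_json as dependencies, and trims RawConfig down to the fields involved):

use serde::Deserialize;

#[derive(Deserialize)]
struct RawConfig {
    max_position_embeddings: Option<usize>,
    max_seq_len: Option<usize>,
    n_positions: Option<usize>,
}

// Same fallback order as the launcher: prefer `max_position_embeddings`,
// then `max_seq_len`, then `n_positions`.
fn resolve_max_position_embeddings(raw: &RawConfig) -> Option<usize> {
    raw.max_position_embeddings
        .or(raw.max_seq_len)
        .or(raw.n_positions)
}

fn main() {
    // A GPT-2 style config only exposes `n_positions`.
    let raw: RawConfig = serde_json::from_str(r#"{ "n_positions": 1024 }"#).unwrap();
    assert_eq!(resolve_max_position_embeddings(&raw), Some(1024));
}

The extra fallbacks matter because some architectures name their context limit max_seq_len or n_positions in config.json rather than max_position_embeddings.
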
@@ -479,7 +488,7 @@ struct Args {
/// `1511` max_new_tokens.
/// The larger this value, the larger amount each request will be in your RAM
/// and the less effective batching can be.
/// Default to min(max_allocatable, max_position_embeddings)
/// Default to min(max_position_embeddings, 4096)
#[clap(long, env)]
max_total_tokens: Option<usize>,
@@ -1667,6 +1676,28 @@ fn main() -> Result<(), LauncherError> {
let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
let quantize = config.as_ref().and_then(|c| c.quantize);
// Quantization usually means you're even more RAM constrained.
let max_default = 4096;
let max_position_embeddings = if let Some(config) = &config {
if let Some(max_position_embeddings) = config.max_position_embeddings {
if max_position_embeddings > max_default {
let max = max_position_embeddings;
if args.max_input_tokens.is_none()
&& args.max_total_tokens.is_none()
&& args.max_batch_prefill_tokens.is_none()
{
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
}
max_default
} else {
max_position_embeddings
}
} else {
max_default
}
} else {
max_default
};
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
std::env::set_var("PREFIX_CACHING", prefix_caching);
@@ -1690,15 +1721,8 @@ fn main() -> Result<(), LauncherError> {
match args.max_batch_prefill_tokens {
Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
None => {
// let value: u32 = if let Some(max_batch_size) = args.max_batch_size {
// max_batch_size * max_input_tokens
// } else {
// // Adding some edge in order to account for potential block_size alignement
// // issue.
// max_input_tokens + 50
// } as u32;
// TODO figure out hardware optimal value
let value = 4096;
let value = 4096.min(max_position_embeddings as u32);
tracing::info!("Default `max_batch_prefill_tokens` to {value}");
value
}
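
Taken together, the two hunks above amount to the following policy, condensed here into a standalone Rust sketch (the function names are illustrative, not the launcher's): the default context is clamped to 4096 unless the model's own limit is smaller, the info log about raising the limit only fires when none of --max-input-tokens, --max-total-tokens and --max-batch-prefill-tokens were passed, and the max_batch_prefill_tokens default follows the clamped value.

// Clamp the working context to 4096 tokens unless the model's own limit is smaller.
const MAX_DEFAULT: usize = 4096;

fn default_context(model_max_position_embeddings: Option<usize>) -> usize {
    model_max_position_embeddings
        .map(|m| m.min(MAX_DEFAULT))
        .unwrap_or(MAX_DEFAULT)
}

// The prefill default is derived from the clamped context rather than a flat 4096.
fn default_max_batch_prefill_tokens(context: usize) -> u32 {
    4096u32.min(context as u32)
}

fn main() {
    // A 128k-context model now defaults to 4096 tokens instead of 128k.
    assert_eq!(default_context(Some(131_072)), 4096);
    // A small model keeps its own limit, and so does the prefill default.
    assert_eq!(default_context(Some(2048)), 2048);
    assert_eq!(default_max_batch_prefill_tokens(2048), 2048);
    // No config at all also falls back to 4096.
    assert_eq!(default_context(None), 4096);
}
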

server/text_generation_server/models/flash_causal_lm.py

@@ -1412,7 +1412,7 @@ class FlashCausalLM(Model):
if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
torch.cuda.tunable.tuning_enable(False)
_, batch, _ = self.generate_token(batch)
_, _batch, _ = self.generate_token(batch)
except torch.cuda.OutOfMemoryError as e:
raise RuntimeError(
f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. "
@@ -1442,14 +1442,14 @@ class FlashCausalLM(Model):
max_total_tokens = num_blocks * BLOCK_SIZE
else:
max_total_tokens = sum(len(input_ids) for input_ids in batch.input_ids)
max_total_tokens = sum(batch.cache_lengths)
max_input_tokens = (
max_total_tokens - 1
if max_input_tokens is None
else max_input_tokens
)
del batch
del _batch, batch
self.kv_cache = []
empty_cache()

server/text_generation_server/models/model.py

@@ -132,7 +132,13 @@ class Model(ABC):
self, batch: B, max_input_tokens: Optional[int], max_total_tokens: Optional[int]
) -> Tuple[Optional[int], int, int]:
self.generate_token(batch)
return None, 0, 0
total = sum(len(i) for i in batch.input_ids)
if max_total_tokens is None:
max_total_tokens = total
if max_input_tokens is None:
max_input_tokens = max_total_tokens - 1
return None, max_input_tokens, max_total_tokens
def decode_token(
self,