Updating logic + non flash.
This commit is contained in:
parent: 10534511ea
commit: 6994fa12f8
@@ -137,7 +137,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->

#[derive(Deserialize)]
struct RawConfig {
    max_position_embeddings: Option<usize>,
    n_positions: Option<usize>,
    model_type: Option<String>,
    max_seq_len: Option<usize>,
    quantization_config: Option<QuantizationConfig>,
    n_embd: Option<usize>,
    hidden_size: Option<usize>,
@@ -157,6 +160,7 @@ struct VisionConfig {}

#[derive(Deserialize)]
struct Config {
    max_position_embeddings: Option<usize>,
    quantize: Option<Quantization>,
    head_dim: Option<usize>,
    model_type: Option<String>,
@@ -166,6 +170,10 @@ struct Config {

impl From<RawConfig> for Config {
    fn from(other: RawConfig) -> Self {
        let max_position_embeddings = other
            .max_position_embeddings
            .or(other.max_seq_len)
            .or(other.n_positions);
        let quantize = other.quantization_config.and_then(|q| q.quant_method);
        let head_dim = other.head_dim.or_else(|| {
            match (other.hidden_size, other.n_embd, other.num_attention_heads) {
@@ -187,6 +195,7 @@ impl From<RawConfig> for Config {
        let vision_config = other.vision_config;
        let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
        Config {
            max_position_embeddings,
            quantize,
            head_dim,
            model_type,
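
The conversion above normalizes fields that different model configs spell differently, using `Option::or`/`or_else` chains. A minimal standalone sketch of the same pattern follows (illustrative function names and values only; the real `From<RawConfig> for Config` also considers `n_embd`, which is omitted here):

// Simplified sketch of the fallback resolution used in the From<RawConfig> impl.
// Different model configs name the context length differently, so each candidate
// field is tried in priority order; head_dim is derived when not given explicitly.
fn resolve_context_length(
    max_position_embeddings: Option<usize>,
    max_seq_len: Option<usize>,
    n_positions: Option<usize>,
) -> Option<usize> {
    max_position_embeddings.or(max_seq_len).or(n_positions)
}

fn resolve_head_dim(
    head_dim: Option<usize>,
    hidden_size: Option<usize>,
    num_attention_heads: Option<usize>,
) -> Option<usize> {
    // Derive hidden_size / num_attention_heads only when both are present and divide evenly.
    head_dim.or_else(|| match (hidden_size, num_attention_heads) {
        (Some(h), Some(n)) if n != 0 && h % n == 0 => Some(h / n),
        _ => None,
    })
}

fn main() {
    assert_eq!(resolve_context_length(None, Some(2048), Some(1024)), Some(2048));
    assert_eq!(resolve_head_dim(None, Some(4096), Some(32)), Some(128));
}
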
@@ -479,7 +488,7 @@ struct Args {
    /// `1511` max_new_tokens.
    /// The larger this value, the larger amount each request will be in your RAM
    /// and the less effective batching can be.
    /// Default to min(max_allocatable, max_position_embeddings)
    /// Default to min(max_position_embeddings, 4096)
    #[clap(long, env)]
    max_total_tokens: Option<usize>,
@@ -1667,6 +1676,28 @@ fn main() -> Result<(), LauncherError> {
    let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
    let quantize = config.as_ref().and_then(|c| c.quantize);
    // Quantization usually means you're even more RAM constrained.
    let max_default = 4096;

    let max_position_embeddings = if let Some(config) = &config {
        if let Some(max_position_embeddings) = config.max_position_embeddings {
            if max_position_embeddings > max_default {
                let max = max_position_embeddings;
                if args.max_input_tokens.is_none()
                    && args.max_total_tokens.is_none()
                    && args.max_batch_prefill_tokens.is_none()
                {
                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
                }
                max_default
            } else {
                max_position_embeddings
            }
        } else {
            max_default
        }
    } else {
        max_default
    };
    let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
    tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
    std::env::set_var("PREFIX_CACHING", prefix_caching);
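
In effect, the launcher clamps the default context budget to 4096 tokens and only logs the hint when none of the three limits was set explicitly. A hedged, standalone sketch of that decision (illustrative names, not the launcher's actual API):

// Sketch only: the model's advertised context length is clamped to a 4096-token
// default; a hint is logged only when the user has not set --max-input-tokens,
// --max-total-tokens, or --max-batch-prefill-tokens.
fn effective_max_position_embeddings(model_max: Option<usize>, user_set_any_limit: bool) -> usize {
    const MAX_DEFAULT: usize = 4096;
    match model_max {
        Some(max) if max > MAX_DEFAULT => {
            if !user_set_any_limit {
                eprintln!("Model supports up to {max} tokens; defaulting to {MAX_DEFAULT} to save VRAM.");
            }
            MAX_DEFAULT
        }
        Some(max) => max,
        None => MAX_DEFAULT,
    }
}

fn main() {
    assert_eq!(effective_max_position_embeddings(Some(131072), false), 4096);
    assert_eq!(effective_max_position_embeddings(Some(2048), false), 2048);
    assert_eq!(effective_max_position_embeddings(None, false), 4096);
}
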
@@ -1690,15 +1721,8 @@ fn main() -> Result<(), LauncherError> {
    match args.max_batch_prefill_tokens {
        Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
        None => {
            // let value: u32 = if let Some(max_batch_size) = args.max_batch_size {
            //     max_batch_size * max_input_tokens
            // } else {
            //     // Adding some edge in order to account for potential block_size alignement
            //     // issue.
            //     max_input_tokens + 50
            // } as u32;
            // TODO figure out hardware optimal value
            let value = 4096;
            let value = 4096.min(max_position_embeddings as u32);
            tracing::info!("Default `max_batch_prefill_tokens` to {value}");
            value
        }
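
When `--max-batch-prefill-tokens` is not given, the default now follows the same cap: 4096 tokens, but never above the (possibly already clamped) `max_position_embeddings`. A tiny illustrative sketch of that rule (standalone, hypothetical function name):

// Illustrative only: the default prefill budget is 4096 tokens, bounded by the
// model's effective context length.
fn default_max_batch_prefill_tokens(max_position_embeddings: usize) -> u32 {
    4096u32.min(max_position_embeddings as u32)
}

fn main() {
    assert_eq!(default_max_batch_prefill_tokens(8192), 4096);
    assert_eq!(default_max_batch_prefill_tokens(2048), 2048);
}
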
@@ -1412,7 +1412,7 @@ class FlashCausalLM(Model):

            if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                torch.cuda.tunable.tuning_enable(False)
            _, batch, _ = self.generate_token(batch)
            _, _batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. "
@@ -1442,14 +1442,14 @@ class FlashCausalLM(Model):
            max_total_tokens = num_blocks * BLOCK_SIZE

        else:
            max_total_tokens = sum(len(input_ids) for input_ids in batch.input_ids)
            max_total_tokens = sum(batch.cache_lengths)
            max_input_tokens = (
                max_total_tokens - 1
                if max_input_tokens is None
                else max_input_tokens
            )

        del batch
        del _batch, batch
        self.kv_cache = []
        empty_cache()
@@ -132,7 +132,13 @@ class Model(ABC):
        self, batch: B, max_input_tokens: Optional[int], max_total_tokens: Optional[int]
    ) -> Tuple[Optional[int], int, int]:
        self.generate_token(batch)
        return None, 0, 0
        total = sum(len(i) for i in batch.input_ids)
        if max_total_tokens is None:
            max_total_tokens = total

        if max_input_tokens is None:
            max_input_tokens = max_total_tokens - 1
        return None, max_input_tokens, max_total_tokens

    def decode_token(
        self,
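
The non-flash warmup path above now derives its limits from the batch when the caller does not supply them: `max_total_tokens` defaults to the total number of input tokens, and `max_input_tokens` to one token less. A hedged sketch of the same arithmetic, written in Rust for consistency with the earlier examples (hypothetical signature, not the server's Python code):

// Illustrative sketch: derive warmup limits from the batch when the caller did
// not supply them. max_total_tokens defaults to the total number of input
// tokens, and max_input_tokens to one less, so at least one token can be generated.
fn warmup_limits(
    input_lengths: &[usize],
    max_input_tokens: Option<usize>,
    max_total_tokens: Option<usize>,
) -> (usize, usize) {
    let total: usize = input_lengths.iter().sum();
    let max_total_tokens = max_total_tokens.unwrap_or(total);
    let max_input_tokens = max_input_tokens.unwrap_or(max_total_tokens.saturating_sub(1));
    (max_input_tokens, max_total_tokens)
}

fn main() {
    assert_eq!(warmup_limits(&[10, 20, 30], None, None), (59, 60));
    assert_eq!(warmup_limits(&[10, 20, 30], Some(32), Some(64)), (32, 64));
}
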