diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index f2f5a99b..d74fca64 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -21,10 +21,28 @@ use tracing_subscriber::EnvFilter;
 
 mod env_runtime;
 
+#[derive(Deserialize)]
+struct RawConfig {
+    max_position_embeddings: Option<usize>,
+    n_positions: Option<usize>,
+    max_seq_len: Option<usize>,
+}
+
 #[derive(Deserialize)]
 struct Config {
     max_position_embeddings: Option<usize>,
-    max_seq_len: Option<usize>,
+}
+
+impl From<RawConfig> for Config {
+    fn from(other: RawConfig) -> Self {
+        let max_position_embeddings = other
+            .max_position_embeddings
+            .or(other.max_seq_len)
+            .or(other.n_positions);
+        Config {
+            max_position_embeddings,
+        }
+    }
 }
 
 #[derive(Clone, Copy, Debug, ValueEnum)]
@@ -1309,33 +1327,30 @@ fn main() -> Result<(), LauncherError> {
         };
         let content = std::fs::read_to_string(filename)?;
-        let config: Config = serde_json::from_str(&content)?;
+        let config: RawConfig = serde_json::from_str(&content)?;
+        let config: Config = config.into();
 
         // Quantization usually means you're even more RAM constrained.
         let max_default = 4096;
 
-        let max_position_embeddings = match (config.max_position_embeddings, config.max_seq_len) {
-            (Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => {
-                if max_position_embeddings > max_default {
-                    let max = max_position_embeddings;
-                    if args.max_input_tokens.is_none()
-                        && args.max_total_tokens.is_none()
-                        && args.max_batch_prefill_tokens.is_none()
-                    {
-                        tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
-                    }
-                    max_default
-                } else {
-                    max_position_embeddings
+        if let Some(max_position_embeddings) = config.max_position_embeddings {
+            if max_position_embeddings > max_default {
+                let max = max_position_embeddings;
+                if args.max_input_tokens.is_none()
+                    && args.max_total_tokens.is_none()
+                    && args.max_batch_prefill_tokens.is_none()
+                {
+                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
                 }
+                Ok(max_default)
+            } else {
+                Ok(max_position_embeddings)
             }
-            _ => {
-                return Err(Box::new(LauncherError::ArgumentValidation(
-                    "no max defined".to_string(),
-                )));
-            }
-        };
-        Ok(max_position_embeddings)
+        } else {
+            Err(Box::new(LauncherError::ArgumentValidation(
+                "no max defined".to_string(),
+            )))
+        }
     };
     let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index b319ab5d..d4a325a9 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -472,14 +472,26 @@ def get_model(
         )
     elif model_type == GPT2:
         if FLASH_ATTENTION:
-            return FlashGPT2(
-                model_id,
-                revision,
-                quantize=quantize,
-                speculator=speculator,
-                dtype=dtype,
-                trust_remote_code=trust_remote_code,
-            )
+            try:
+                return FlashGPT2(
+                    model_id,
+                    revision,
+                    quantize=quantize,
+                    speculator=speculator,
+                    dtype=dtype,
+                    trust_remote_code=trust_remote_code,
+                )
+            except RuntimeError as e:
+                # Lots of legacy models with various weight names.
+                logger.warning(f"Couldn't load flash gpt2 variant: {e}")
+                return CausalLM(
+                    model_id,
+                    revision,
+                    quantize=quantize,
+                    speculator=speculator,
+                    dtype=dtype,
+                    trust_remote_code=trust_remote_code,
+                )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
         else:
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index e549b7cb..37c46032 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -14,13 +14,21 @@ from typing import List, Optional
 
 from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model
-from text_generation_server.models.pali_gemma import PaliGemmaBatch
-from text_generation_server.models.vlm_causal_lm import (
-    VlmCausalLMBatch,
-)
+
+try:
+    from text_generation_server.models.pali_gemma import PaliGemmaBatch
+    from text_generation_server.models.vlm_causal_lm import (
+        VlmCausalLMBatch,
+    )
+    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
+
+    VLM_BATCH_TYPES = {PaliGemmaBatch, VlmCausalLMBatch, IdeficsCausalLMBatch}
+except (ImportError, NotImplementedError):
+    # These imports can fail on CPU/Non flash.
+    VLM_BATCH_TYPES = set()
+
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
-from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
 from text_generation_server.models.globals import set_model_id
@@ -96,11 +104,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             except ImportError:
                 pass
 
-        if self.model.batch_type in {
-            IdeficsCausalLMBatch,
-            VlmCausalLMBatch,
-            PaliGemmaBatch,
-        }:  # Hack, i would rather use kwargs in the `from_pb` call
+        if (
+            self.model.batch_type in VLM_BATCH_TYPES
+        ):  # Hack, i would rather use kwargs in the `from_pb` call
             batch = self.model.batch_type.from_pb_processor(
                 request.batch,
                 self.model.tokenizer,
@@ -121,11 +127,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
     async def Prefill(self, request, context):
         start = time.time_ns()
-        if self.model.batch_type in {
-            IdeficsCausalLMBatch,
-            VlmCausalLMBatch,
-            PaliGemmaBatch,
-        }:  # Hack, i would rather use kwargs in the `from_pb` call
+        if (
+            self.model.batch_type in VLM_BATCH_TYPES
+        ):  # Hack, i would rather use kwargs in the `from_pb` call
             batch = self.model.batch_type.from_pb_processor(
                 request.batch,
                 self.model.tokenizer,
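
A minimal standalone sketch (not part of the patch) of the RawConfig -> Config fallback added in launcher/src/main.rs, assuming serde (with the derive feature) and serde_json as dependencies. It shows how a GPT-2-style config.json that only defines `n_positions` now resolves to a maximum length instead of failing with "no max defined":

// Sketch only: mirrors the conversion introduced in launcher/src/main.rs.
use serde::Deserialize;

#[derive(Deserialize)]
struct RawConfig {
    max_position_embeddings: Option<usize>,
    n_positions: Option<usize>,
    max_seq_len: Option<usize>,
}

struct Config {
    max_position_embeddings: Option<usize>,
}

impl From<RawConfig> for Config {
    fn from(other: RawConfig) -> Self {
        // Prefer `max_position_embeddings`, then `max_seq_len`, then `n_positions`.
        let max_position_embeddings = other
            .max_position_embeddings
            .or(other.max_seq_len)
            .or(other.n_positions);
        Config {
            max_position_embeddings,
        }
    }
}

fn main() {
    // GPT-2 style config: only `n_positions` is present; serde leaves the
    // missing Option fields as None.
    let raw: RawConfig = serde_json::from_str(r#"{"n_positions": 1024}"#).unwrap();
    let config: Config = raw.into();
    assert_eq!(config.max_position_embeddings, Some(1024));
}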