diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 35c8faae..4cb4ca59 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -25,6 +25,7 @@ DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data") class ResponseComparator(JSONSnapshotExtension): rtol = 0.2 + def serialize( self, data, @@ -69,7 +70,9 @@ class ResponseComparator(JSONSnapshotExtension): prefill_token.id == other.id and prefill_token.text == other.text and ( - math.isclose(prefill_token.logprob, other.logprob, rel_tol=self.rtol) + math.isclose( + prefill_token.logprob, other.logprob, rel_tol=self.rtol + ) if prefill_token.logprob is not None else prefill_token.logprob == other.logprob ) @@ -153,6 +156,7 @@ class GenerousResponseComparator(ResponseComparator): # Needed for GPTQ with exllama which has serious numerical fluctuations. rtol = 0.75 + class LauncherHandle: def __init__(self, port: int): self.client = AsyncClient(f"http://localhost:{port}") @@ -198,6 +202,7 @@ class ProcessLauncherHandle(LauncherHandle): def response_snapshot(snapshot): return snapshot.use_extension(ResponseComparator) + @pytest.fixture def generous_response_snapshot(snapshot): return snapshot.use_extension(GenerousResponseComparator) @@ -219,7 +224,7 @@ def launcher(event_loop): quantize: Optional[str] = None, trust_remote_code: bool = False, use_flash_attention: bool = True, - dtype: Optional[str] = None + dtype: Optional[str] = None, ): port = random.randint(8000, 10_000) master_port = random.randint(10_000, 20_000) @@ -282,7 +287,7 @@ def launcher(event_loop): quantize: Optional[str] = None, trust_remote_code: bool = False, use_flash_attention: bool = True, - dtype: Optional[str] = None + dtype: Optional[str] = None, ): port = random.randint(8000, 10_000) @@ -335,7 +340,7 @@ def launcher(event_loop): ], volumes=volumes, ports={"80/tcp": port}, - shm_size="1G" + shm_size="1G", ) yield ContainerLauncherHandle(client, container.name, port) diff --git a/integration-tests/models/test_flash_medusa.py b/integration-tests/models/test_flash_medusa.py index 003409b0..a0ce0570 100644 --- a/integration-tests/models/test_flash_medusa.py +++ b/integration-tests/models/test_flash_medusa.py @@ -50,10 +50,16 @@ async def test_flash_medusa_all_params(flash_medusa, response_snapshot): @pytest.mark.asyncio @pytest.mark.private async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot): - responses = await generate_load(flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4) + responses = await generate_load( + flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4 + ) assert len(responses) == 4 - assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}" - assert responses[0].generated_text == '\nDeep learning is a subset of machine learning' + assert all( + [r.generated_text == responses[0].generated_text for r in responses] + ), f"{[r.generated_text for r in responses]}" + assert ( + responses[0].generated_text == "\nDeep learning is a subset of machine learning" + ) assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_mistral.py b/integration-tests/models/test_flash_mistral.py index 7d21afd9..ace3328b 100644 --- a/integration-tests/models/test_flash_mistral.py +++ b/integration-tests/models/test_flash_mistral.py @@ -56,7 +56,9 @@ async def test_flash_mistral_load(flash_mistral, generate_load, response_snapsho ) assert len(responses) == 4 - assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}" + assert all( + [r.generated_text == responses[0].generated_text for r in responses] + ), f"{[r.generated_text for r in responses]}" assert responses[0].generated_text == ": Let n = 10 - 1" assert responses == response_snapshot diff --git a/integration-tests/models/test_idefics.py b/integration-tests/models/test_idefics.py index 5a81a4f0..7e1d3e11 100644 --- a/integration-tests/models/test_idefics.py +++ b/integration-tests/models/test_idefics.py @@ -3,7 +3,9 @@ import pytest @pytest.fixture(scope="module") def idefics_handle(launcher): - with launcher("HuggingFaceM4/idefics-9b-instruct", num_shard=2, dtype="float16") as handle: + with launcher( + "HuggingFaceM4/idefics-9b-instruct", num_shard=2, dtype="float16" + ) as handle: yield handle diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 1990ef8b..d9a33795 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -133,8 +133,20 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch): ) assert all([generation.generated_text is None for generation in generations]) assert all([len(generation.prefill_tokens) == 1 for generation in generations]) - assert all([token_id.item() == 10264 for generation in generations for token_id in generation.tokens.token_ids]) - assert all([token_text == "Test" for generation in generations for token_text in generation.tokens.texts]) + assert all( + [ + token_id.item() == 10264 + for generation in generations + for token_id in generation.tokens.token_ids + ] + ) + assert all( + [ + token_text == "Test" + for generation in generations + for token_text in generation.tokens.texts + ] + ) assert generations[0].request_id == 0 diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index f105ce6f..8b45e781 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -129,8 +129,20 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch): ) assert all([generation.generated_text is None for generation in generations]) assert all([len(generation.prefill_tokens) == 1 for generation in generations]) - assert all([token_id.item() == 13 for generation in generations for token_id in generation.tokens.token_ids]) - assert all([token_text == "." for generation in generations for token_text in generation.tokens.texts]) + assert all( + [ + token_id.item() == 13 + for generation in generations + for token_id in generation.tokens.token_ids + ] + ) + assert all( + [ + token_text == "." + for generation in generations + for token_text in generation.tokens.texts + ] + ) assert generations[0].request_id == 0 diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index d553067e..373867c7 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -151,8 +151,20 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch) ) assert all([generation.generated_text is None for generation in generations]) assert all([len(generation.prefill_tokens) == 1 for generation in generations]) - assert all([token_id.item() == 259 for generation in generations for token_id in generation.tokens.token_ids]) - assert all([token_text == " " for generation in generations for token_text in generation.tokens.texts]) + assert all( + [ + token_id.item() == 259 + for generation in generations + for token_id in generation.tokens.token_ids + ] + ) + assert all( + [ + token_text == " " + for generation in generations + for token_text in generation.tokens.texts + ] + ) assert generations[0].request_id == 0 diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index cb151173..1d67d7eb 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -77,12 +77,24 @@ def serve( # Downgrade enum into str for easier management later on quantize = None if quantize is None else quantize.value dtype = None if dtype is None else dtype.value - if dtype is not None and quantize not in {None, "bitsandbytes", "bitsandbytes-nf4", "bitsandbytes-fp4"}: + if dtype is not None and quantize not in { + None, + "bitsandbytes", + "bitsandbytes-nf4", + "bitsandbytes-fp4", + }: raise RuntimeError( "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." ) server.serve( - model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code, uds_path + model_id, + revision, + sharded, + quantize, + speculate, + dtype, + trust_remote_code, + uds_path, ) @@ -140,12 +152,17 @@ def download_weights( try: import json - medusa_head = hf_hub_download(model_id, revision=revision, filename="medusa_lm_head.pt") + + medusa_head = hf_hub_download( + model_id, revision=revision, filename="medusa_lm_head.pt" + ) if auto_convert: - medusa_sf = Path(medusa_head[:-len(".pt")] + ".safetensors") + medusa_sf = Path(medusa_head[: -len(".pt")] + ".safetensors") if not medusa_sf.exists(): utils.convert_files([Path(medusa_head)], [medusa_sf], []) - medusa_config = hf_hub_download(model_id, revision=revision, filename="config.json") + medusa_config = hf_hub_download( + model_id, revision=revision, filename="config.json" + ) with open(medusa_config, "r") as f: config = json.load(f) @@ -153,10 +170,17 @@ def download_weights( revision = "main" try: utils.weight_files(model_id, revision, extension) - logger.info(f"Files for parent {model_id} are already present on the host. " "Skipping download.") + logger.info( + f"Files for parent {model_id} are already present on the host. " + "Skipping download." + ) return # Local files not found - except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError): + except ( + utils.LocalEntryNotFoundError, + FileNotFoundError, + utils.EntryNotFoundError, + ): pass except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): pass diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 0172d32c..f2ef55ce 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -88,7 +88,6 @@ if MIXTRAL: __all__.append(FlashMixtral) - def get_model( model_id: str, revision: Optional[str], @@ -157,7 +156,9 @@ def get_model( speculate_medusa = config_dict["medusa_num_heads"] if speculate is not None: if speculate > speculate_medusa: - raise RuntimeError("Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match") + raise RuntimeError( + "Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match" + ) else: set_speculate(speculate) else: @@ -249,7 +250,7 @@ def get_model( quantize=quantize, dtype=dtype, trust_remote_code=trust_remote_code, - use_medusa=use_medusa + use_medusa=use_medusa, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")) @@ -313,7 +314,9 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - raise NotImplementedError("Mixtral models requires flash attention v2, stk and megablocks") + raise NotImplementedError( + "Mixtral models requires flash attention v2, stk and megablocks" + ) if model_type == "opt": return OPTSharded( @@ -354,7 +357,7 @@ def get_model( raise ValueError("awq quantization is not supported for AutoModel") elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"): raise ValueError("4bit quantization is not supported for AutoModel") - elif (quantize == "eetq"): + elif quantize == "eetq": raise ValueError("Eetq quantization is not supported for AutoModel") if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: return CausalLM( diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 8e8daad3..c3876023 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -74,7 +74,11 @@ class BLOOMSharded(CausalLM): torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group, prefix="transformer", + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + prefix="transformer", ) if config.quantize == "gptq": weights._set_gptq_params(model_id) diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index c571a022..b771264b 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -510,7 +510,11 @@ class CausalLM(Model): load_in_8bit=quantize == "bitsandbytes", trust_remote_code=trust_remote_code, ) - if torch.cuda.is_available() and torch.cuda.device_count() == 1 and quantize != "bitsandbytes": + if ( + torch.cuda.is_available() + and torch.cuda.device_count() == 1 + and quantize != "bitsandbytes" + ): model = model.cuda() if tokenizer.pad_token_id is None: @@ -676,7 +680,10 @@ class CausalLM(Model): skip_special_tokens=False, ) prefill_tokens = Tokens( - prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[] + prefill_token_ids, + prefill_logprobs, + prefill_texts, + is_special=[], ) else: prefill_tokens = None @@ -703,11 +710,11 @@ class CausalLM(Model): request.id, prefill_tokens, Tokens( - [next_token_id_squeezed], - [next_token_logprob], - [next_token_text], - [next_token_id_squeezed.item() in self.all_special_ids], - ), + [next_token_id_squeezed], + [next_token_logprob], + [next_token_text], + [next_token_id_squeezed.item() in self.all_special_ids], + ), generated_text, top_tokens, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index d06b87eb..3b424f80 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -34,9 +34,10 @@ from text_generation_server.utils.layers import ( PositionRotaryEmbedding, TensorParallelHead, get_linear, - FastRMSNorm + FastRMSNorm, ) + class LlamaConfig(PretrainedConfig): def __init__( self, @@ -202,7 +203,7 @@ class FlashLlamaAttention(torch.nn.Module): ) query = query.view(-1, self.num_heads, self.head_size) kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) - + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) paged_attention.reshape_and_cache( @@ -237,7 +238,7 @@ class FlashLlamaAttention(torch.nn.Module): input_lengths, max_s, ) - + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -288,7 +289,9 @@ class FlashLlamaLayer(nn.Module): ) self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) - self.input_layernorm = FastRMSNorm.load(prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps) + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) self.post_attention_layernorm = FastRMSNorm.load( prefix=f"{prefix}.post_attention_layernorm", weights=weights, diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 4e56b188..525bf6bc 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -27,7 +27,11 @@ from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple from text_generation_server.utils import paged_attention, flash_attn -from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2_ROCM, HAS_FLASH_ATTN_V2_CUDA +from text_generation_server.utils.flash_attn import ( + attention, + HAS_FLASH_ATTN_V2_ROCM, + HAS_FLASH_ATTN_V2_CUDA, +) from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -35,7 +39,7 @@ from text_generation_server.utils.layers import ( PositionRotaryEmbedding, TensorParallelHead, get_linear, - FastRMSNorm + FastRMSNorm, ) @@ -96,6 +100,7 @@ class MistralConfig(PretrainedConfig): **kwargs, ) + def load_attention(config, prefix, weights): if config.num_attention_heads != config.num_key_value_heads: return _load_gqa(config, prefix, weights) diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 66753d5a..6f5edca2 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -29,7 +29,10 @@ from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple from text_generation_server.utils import paged_attention, flash_attn -from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_ROCM, HAS_FLASH_ATTN_V2_CUDA +from text_generation_server.utils.flash_attn import ( + HAS_FLASH_ATTN_V2_ROCM, + HAS_FLASH_ATTN_V2_CUDA, +) from text_generation_server.utils.layers import ( FastLinear, FastRMSNorm, @@ -59,28 +62,28 @@ class MixtralConfig(PretrainedConfig): model_type = "mixtral" def __init__( - self, - vocab_size=32000, - hidden_size=4096, - intermediate_size=14336, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=8, - hidden_act="silu", - max_position_embeddings=4096 * 32, - initializer_range=0.02, - rms_norm_eps=1e-05, - use_cache=True, - pad_token_id=None, - bos_token_id=1, - eos_token_id=2, - pretraining_tp=1, - tie_word_embeddings=False, - rope_theta=10000.0, - sliding_window=4096, - num_experts_per_tok=2, - num_local_experts=8, - **kwargs, + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + sliding_window=4096, + num_experts_per_tok=2, + num_local_experts=8, + **kwargs, ): self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings @@ -166,16 +169,18 @@ def _load_experts(config, prefix, mat, weights): rank = weights.process_group.rank() assert ( - config.intermediate_size % world_size == 0 + config.intermediate_size % world_size == 0 ), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards" block_size = config.intermediate_size // world_size start = rank * block_size stop = (rank + 1) * block_size - tensor = torch.empty((config.num_local_experts * block_size, config.hidden_size), - dtype=weights.dtype, - device=weights.device) + tensor = torch.empty( + (config.num_local_experts * block_size, config.hidden_size), + dtype=weights.dtype, + device=weights.device, + ) for i in range(config.num_local_experts): slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight") @@ -184,16 +189,18 @@ def _load_experts(config, prefix, mat, weights): expert_slice = slice_[:, start:stop].t().contiguous() else: expert_slice = slice_[start:stop] - tensor[i * block_size:(i + 1) * block_size] = expert_slice.to(dtype=weights.dtype).to(device=weights.device) + tensor[i * block_size : (i + 1) * block_size] = expert_slice.to( + dtype=weights.dtype + ).to(device=weights.device) return tensor class MixtralAttention(torch.nn.Module): def __init__( - self, - prefix: str, - config, - weights, + self, + prefix: str, + config, + weights, ): super().__init__() self.max_past = ( @@ -210,7 +217,7 @@ class MixtralAttention(torch.nn.Module): device=weights.device, ) - self.softmax_scale = self.head_size ** -0.5 + self.softmax_scale = self.head_size**-0.5 if self.num_heads % weights.process_group.size() != 0: raise ValueError( @@ -219,7 +226,7 @@ class MixtralAttention(torch.nn.Module): ) self.num_heads = self.num_heads // weights.process_group.size() self.num_key_value_heads = ( - config.num_key_value_heads // weights.process_group.size() + config.num_key_value_heads // weights.process_group.size() ) self.query_key_value = load_attention(config, prefix, weights) @@ -236,17 +243,17 @@ class MixtralAttention(torch.nn.Module): ).repeat_interleave(self.num_groups) def forward( - self, - hidden_states, - cos, - sin, - cu_seqlen_prefill, - kv_cache, - block_tables, - slots, - input_lengths, - max_s, - prefill_cache_indices, + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, ): qkv = self.query_key_value(hidden_states) query, kv = qkv.split( @@ -399,8 +406,9 @@ class BlockSparseMoE(nn.Module): # Indices for the sparse matrix. The indices for # the intermediate matrix are dynamic depending # on the mapping of tokens to experts. - column_indices = ops.topology(padded_bins, self.blocking, block_rows, - blocks_per_row) + column_indices = ops.topology( + padded_bins, self.blocking, block_rows, blocks_per_row + ) # For now, use meta init to save the device memory. data = torch.empty( @@ -444,8 +452,7 @@ class BlockSparseMoE(nn.Module): # position of each bin. # List of size num_experts - padded_tokens_per_expert = round_up(tokens_per_expert, - self.blocking) + padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking) # padded_tokens_per_expert => [128, O, 128, ...] # Cumulative selected experts per token @@ -484,8 +491,7 @@ class BlockSparseMoE(nn.Module): # Permute tokens and pad to prepare expert computation # (top_k * sequence_length + padding, model_dim) - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, - self.top_k) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k) # Create the sparse matrix topology with torch.no_grad(): @@ -496,8 +502,8 @@ class BlockSparseMoE(nn.Module): # (top_k * sequence_length + padding, ffn_dim * n_experts) x = stk.Matrix( topo.size(), - self.act(stk.ops.sdd(x, self.w1, topo).data) * - stk.ops.sdd(x, self.w3, topo).data, + self.act(stk.ops.sdd(x, self.w1, topo).data) + * stk.ops.sdd(x, self.w3, topo).data, topo.row_indices, topo.column_indices, topo.offsets, @@ -537,7 +543,9 @@ class MixtralLayer(nn.Module): self.self_attn = MixtralAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights ) - self.block_sparse_moe = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights) + self.block_sparse_moe = BlockSparseMoE( + f"{prefix}.block_sparse_moe", config, weights + ) self.input_layernorm = FastRMSNorm.load( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps @@ -549,18 +557,18 @@ class MixtralLayer(nn.Module): ) def forward( - self, - hidden_states, - residual, - cos, - sin, - cu_seqlen_prefill, - kv_cache, - block_tables, - slots, - input_lengths, - max_s, - prefill_cache_indices, + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -615,16 +623,16 @@ class MixtralModel(torch.nn.Module): self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - cu_seqlen_prefill: Optional[torch.Tensor], - kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, - slots: torch.Tensor, - input_lengths: torch.Tensor, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -670,17 +678,17 @@ class FlashMixtralForCausalLM(torch.nn.Module): raise ValueError("max_past cannot be None") def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - cu_seqlen_prefill: Optional[torch.Tensor], - kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, - slots: torch.Tensor, - input_lengths: torch.Tensor, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], - lm_head_indices: Optional[torch.Tensor] = None, + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: if prefill_cache_indices is not None: # Slots also need to be sliced as it has the same size as the whole kv tensor diff --git a/server/text_generation_server/models/custom_modeling/idefics_image_processing.py b/server/text_generation_server/models/custom_modeling/idefics_image_processing.py index 4760ae6f..e323d365 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_image_processing.py +++ b/server/text_generation_server/models/custom_modeling/idefics_image_processing.py @@ -198,7 +198,9 @@ class IdeficsImageProcessor(BaseImageProcessor): image = image_url_or_urls if image.startswith("http://") or image.startswith("https://"): - response = requests.get(image_url_or_urls, stream=True, headers=headers, timeout=(1, 5)) + response = requests.get( + image_url_or_urls, stream=True, headers=headers, timeout=(1, 5) + ) response.raise_for_status() content = response.content elif image.startswith("data:"): @@ -213,7 +215,7 @@ class IdeficsImageProcessor(BaseImageProcessor): image = Image.open(BytesIO(content)) # image.verify() except Exception: - raise ValueError(f"Could not load image from url {image_url_or_urls}") + raise ValueError(f"Could not load image from url {image_url_or_urls}") return image else: raise ValueError( diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py index 946f7683..555bf5af 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py +++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py @@ -62,6 +62,7 @@ if IS_CUDA_SYSTEM: elif IS_ROCM_SYSTEM: from vllm import layernorm_ops + @dataclass class BaseModelOutputWithPastImage(BaseModelOutputWithPast): image_hidden_states: Optional[torch.FloatTensor] = None @@ -431,7 +432,9 @@ class IdeficsRMSNorm(nn.Module): return out else: - raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.") + raise ValueError( + "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." + ) # this was adapted from LlamaMLP @@ -613,8 +616,13 @@ class IdeficsAttention(nn.Module): query_shape = query_states.shape key_shape = key_states.shape - self.rotary_emb(query_states.view(-1, *query_shape[2:]), key_states.reshape(-1, *key_shape[2:]), cos, sin) - + self.rotary_emb( + query_states.view(-1, *query_shape[2:]), + key_states.reshape(-1, *key_shape[2:]), + cos, + sin, + ) + query_states = query_states.view(query_shape) key_states = key_states.view(key_shape) diff --git a/server/text_generation_server/models/custom_modeling/idefics_processing.py b/server/text_generation_server/models/custom_modeling/idefics_processing.py index 98e43a27..beca864b 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_processing.py +++ b/server/text_generation_server/models/custom_modeling/idefics_processing.py @@ -112,6 +112,7 @@ def is_url(string): result = urlparse(string) return all([result.scheme, result.netloc]) + def is_image(string): """Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately invalidated the url""" @@ -344,7 +345,6 @@ class IdeficsProcessor(ProcessorMixin): image_objects = self.image_processor(image_objects, transform=transform) - text_encoding = self.tokenizer( text=full_text, add_special_tokens=False, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 79344ea1..14d30635 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -11,7 +11,7 @@ from opentelemetry import trace from transformers import PreTrainedTokenizerBase from typing import Optional, Tuple, List, Type, Union, Dict -from text_generation_server.models import Model +from text_generation_server.models import Model from text_generation_server.utils.speculate import get_speculate from text_generation_server.models.types import ( Batch, @@ -165,8 +165,6 @@ class FlashCausalLMBatch(Batch): input_length = len(tokenized_input) input_lengths.append(input_length) - - prefix_offsets.append(input_length - 5) read_offsets.append(input_length) @@ -229,7 +227,9 @@ class FlashCausalLMBatch(Batch): cumulative_max_length += total_tokens max_seqlen = max(max_seqlen, input_length) max_blocks = max(max_blocks, needed_blocks) - max_length = max(max_length, input_length + max_new_tokens + speculative_length) + max_length = max( + max_length, input_length + max_new_tokens + speculative_length + ) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype, device @@ -424,7 +424,9 @@ class FlashCausalLMBatch(Batch): slots = self.slots[slot_filtering_indices] next_token_chooser = self.next_token_chooser.filter(indices) top_n_tokens_tensor = self.top_n_tokens_tensor[indices] - speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None + speculative_ids = ( + self.speculative_ids[indices] if self.speculative_ids is not None else None + ) start_slots = torch.tensor(start_slots, dtype=torch.int64) @@ -480,7 +482,9 @@ class FlashCausalLMBatch(Batch): total_batch_size += len(b) total_slots += len(b.slots) blocks += b.blocks - speculative_length = b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 + speculative_length = ( + b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 + ) max_blocks = max(max_blocks, b.max_blocks) max_seqlen = max(max_seqlen, b.max_seqlen) max_length = max( @@ -586,7 +590,11 @@ class FlashCausalLMBatch(Batch): device=batches[0].next_token_chooser.device, ) - speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) if batches[0].speculative_ids is not None else None + speculative_ids = ( + torch.cat([b.speculative_ids for b in batches], dim=0) + if batches[0].speculative_ids is not None + else None + ) # Needed to avoid dropping blocks when the batches will go out of scope for b in batches: @@ -622,7 +630,7 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor=top_n_tokens_tensor, blocks=blocks, max_blocks=max_blocks, - speculative_ids=speculative_ids + speculative_ids=speculative_ids, ) def __del__(self): @@ -727,43 +735,54 @@ class FlashCausalLM(Model): def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]: # Model Forward if batch.speculative_ids is not None: - input_ids=batch.input_ids - position_ids=batch.position_ids - cu_seqlen_prefill=batch.cu_seqlen_prefill - kv_cache=get_cache_manager().kv_cache - block_tables=batch.block_tables_tensor - slots=batch.slots[batch.slot_indices] - input_lengths=batch.input_lengths_tensor - max_s=batch.max_seqlen - lm_head_indices=batch.prefill_head_indices + input_ids = batch.input_ids + position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill + kv_cache = get_cache_manager().kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths = batch.input_lengths_tensor + max_s = batch.max_seqlen + lm_head_indices = batch.prefill_head_indices speculative_ids = batch.speculative_ids - B, speculative_length = speculative_ids.shape + B, speculative_length = speculative_ids.shape new_length = speculative_length + 1 - new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1) + new_input_ids = torch.cat( + [input_ids.unsqueeze(-1), speculative_ids], dim=1 + ).reshape(-1) arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0) arange_int = arange.to(dtype=torch.int32) - new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1) + new_position_ids = ( + position_ids.unsqueeze(-1).expand(B, new_length) + arange + ).view(-1) slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) - input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + input_lengths = ( + input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int + ).view(-1) # Add Copy the block tables for all members - block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B* new_length, -1).contiguous() + block_tables = ( + block_tables.unsqueeze(1) + .expand(B, new_length, -1) + .reshape(B * new_length, -1) + .contiguous() + ) max_s = max_s + speculative_length input_ids = new_input_ids position_ids = new_position_ids else: - input_ids=batch.input_ids - position_ids=batch.position_ids - cu_seqlen_prefill=batch.cu_seqlen_prefill - kv_cache=get_cache_manager().kv_cache - block_tables=batch.block_tables_tensor - slots=batch.slots[batch.slot_indices] - input_lengths=batch.input_lengths_tensor - max_s=batch.max_seqlen - lm_head_indices=batch.prefill_head_indices + input_ids = batch.input_ids + position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill + kv_cache = get_cache_manager().kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths = batch.input_lengths_tensor + max_s = batch.max_seqlen + lm_head_indices = batch.prefill_head_indices return self.model.forward( input_ids=input_ids, @@ -808,20 +827,31 @@ class FlashCausalLM(Model): else: speculative_logits = None - if prefill: next_token_logits = ( out[batch.prefill_next_token_indices] if prefill_logprobs else out ) if speculative_logits is not None: speculative_logits = ( - speculative_logits[batch.prefill_next_token_indices] if prefill_logprobs else speculative_logits + speculative_logits[batch.prefill_next_token_indices] + if prefill_logprobs + else speculative_logits ) else: next_token_logits = out - next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser( - batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits + ( + next_input_ids, + next_token_logprobs, + logprobs, + accepted_ids, + speculative_ids, + ) = batch.next_token_chooser( + batch.all_input_ids_tensor[:, : batch.max_seqlen], + next_token_logits, + get_speculate(), + batch.speculative_ids, + speculative_logits, ) batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( @@ -851,11 +881,7 @@ class FlashCausalLM(Model): stopped = True # Zipped iterator - iterator = zip( - batch.input_lengths, - batch.all_input_ids, - accepted_ids - ) + iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids) # We do two for loops as the first one can run completely asynchronously from the GPU while for the second # one, we need to first do a GPU <-> CPU sync @@ -863,11 +889,7 @@ class FlashCausalLM(Model): # For each member of the batch index = 0 - for i, ( - input_length, - all_input_ids, - n_accepted_ids - ) in enumerate(iterator): + for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator): # Indexing metadata start_index = cumulative_length end_index = cumulative_length + input_length @@ -901,7 +923,6 @@ class FlashCausalLM(Model): cumulative_length += input_length - batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] batch.speculative_ids = speculative_ids batch.position_ids = next_position_ids + accepted_ids @@ -983,8 +1004,10 @@ class FlashCausalLM(Model): current_stopped = False stopped = stopped and current_stopped - _next_token_ids = next_token_ids[index: index+n_accepted_ids - left] - _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids - left] + _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] + _next_token_logprobs = next_token_logprobs[ + index : index + n_accepted_ids - left + ] index += n_accepted_ids # Shard generations @@ -1027,7 +1050,10 @@ class FlashCausalLM(Model): ) prefill_tokens = Tokens( - prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special = [] + prefill_token_ids, + request_prefill_logprobs, + prefill_texts, + is_special=[], ) else: prefill_tokens = None diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 3a84b1b6..2415a245 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -71,12 +71,19 @@ class FlashLlama(FlashCausalLM): from text_generation_server.utils.medusa import MedusaModel from huggingface_hub import hf_hub_download import json - medusa_config = hf_hub_download(use_medusa, revision=revision, filename="config.json") + + medusa_config = hf_hub_download( + use_medusa, revision=revision, filename="config.json" + ) with open(medusa_config, "r") as f: config = json.load(f) - medusa_head = hf_hub_download(use_medusa, revision=revision, filename="medusa_lm_head.pt") - medusa_sf = medusa_head[:-len(".pt")] + ".safetensors" - weights = Weights([medusa_sf], device, dtype, process_group=self.process_group) + medusa_head = hf_hub_download( + use_medusa, revision=revision, filename="medusa_lm_head.pt" + ) + medusa_sf = medusa_head[: -len(".pt")] + ".safetensors" + weights = Weights( + [medusa_sf], device, dtype, process_group=self.process_group + ) lm_head = model.lm_head model.lm_head = MedusaModel(config, weights, lm_head) diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 5ce37164..0fad5aa8 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -45,11 +45,11 @@ class FlashMistralBatch(FlashCausalLMBatch): @classmethod def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, + device: torch.device, ) -> "FlashCausalLMBatch": global SLIDING_WINDOW global SLIDING_WINDOW_BLOCKS @@ -99,12 +99,12 @@ class FlashMistralBatch(FlashCausalLMBatch): # Parse batch for i, (r, tokenized_input) in enumerate( - zip(pb.requests, batch_tokenized_inputs) + zip(pb.requests, batch_tokenized_inputs) ): # request id -> idx in list mapping requests_idx_mapping[r.id] = i - tokenized_input = tokenized_input[-r.truncate:] + tokenized_input = tokenized_input[-r.truncate :] input_length = len(tokenized_input) input_lengths.append(input_length) @@ -184,7 +184,9 @@ class FlashMistralBatch(FlashCausalLMBatch): cumulative_max_length += total_tokens max_seqlen = max(max_seqlen, input_length) max_blocks = max(max_blocks, needed_blocks) - max_length = max(max_length, input_length + max_new_tokens + speculative_length) + max_length = max( + max_length, input_length + max_new_tokens + speculative_length + ) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype, device @@ -273,20 +275,20 @@ class FlashMistralBatch(FlashCausalLMBatch): blocks=blocks, max_blocks=max_blocks, prefill_cache_indices=prefill_cache_indices, - speculative_ids=None + speculative_ids=None, ) class BaseFlashMistral(FlashCausalLM): def __init__( - self, - config_cls, - model_cls, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, + self, + config_cls, + model_cls, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, ): global SLIDING_WINDOW global SLIDING_WINDOW_BLOCKS @@ -345,43 +347,54 @@ class BaseFlashMistral(FlashCausalLM): def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]: # Model Forward if batch.speculative_ids is not None: - input_ids=batch.input_ids - position_ids=batch.position_ids - cu_seqlen_prefill=batch.cu_seqlen_prefill - kv_cache=get_cache_manager().kv_cache - block_tables=batch.block_tables_tensor - slots=batch.slots[batch.slot_indices] - input_lengths=batch.input_lengths_tensor - max_s=batch.max_seqlen - lm_head_indices=batch.prefill_head_indices + input_ids = batch.input_ids + position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill + kv_cache = get_cache_manager().kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths = batch.input_lengths_tensor + max_s = batch.max_seqlen + lm_head_indices = batch.prefill_head_indices speculative_ids = batch.speculative_ids - B, speculative_length = speculative_ids.shape + B, speculative_length = speculative_ids.shape new_length = speculative_length + 1 - new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1) + new_input_ids = torch.cat( + [input_ids.unsqueeze(-1), speculative_ids], dim=1 + ).reshape(-1) arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0) arange_int = arange.to(dtype=torch.int32) - new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1) + new_position_ids = ( + position_ids.unsqueeze(-1).expand(B, new_length) + arange + ).view(-1) slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) - input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + input_lengths = ( + input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int + ).view(-1) # Add Copy the block tables for all members - block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B* new_length, -1).contiguous() + block_tables = ( + block_tables.unsqueeze(1) + .expand(B, new_length, -1) + .reshape(B * new_length, -1) + .contiguous() + ) max_s = max_s + speculative_length input_ids = new_input_ids position_ids = new_position_ids else: - input_ids=batch.input_ids - position_ids=batch.position_ids - cu_seqlen_prefill=batch.cu_seqlen_prefill - kv_cache=get_cache_manager().kv_cache - block_tables=batch.block_tables_tensor - slots=batch.slots[batch.slot_indices] - input_lengths=batch.input_lengths_tensor - max_s=batch.max_seqlen - lm_head_indices=batch.prefill_head_indices + input_ids = batch.input_ids + position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill + kv_cache = get_cache_manager().kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths = batch.input_lengths_tensor + max_s = batch.max_seqlen + lm_head_indices = batch.prefill_head_indices logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -401,12 +414,12 @@ class BaseFlashMistral(FlashCausalLM): class FlashMistral(BaseFlashMistral): def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, ): super(FlashMistral, self).__init__( config_cls=MistralConfig, @@ -415,5 +428,5 @@ class FlashMistral(BaseFlashMistral): revision=revision, quantize=quantize, dtype=dtype, - trust_remote_code=trust_remote_code + trust_remote_code=trust_remote_code, ) diff --git a/server/text_generation_server/models/flash_mixtral.py b/server/text_generation_server/models/flash_mixtral.py index c45ae50f..6f77a658 100644 --- a/server/text_generation_server/models/flash_mixtral.py +++ b/server/text_generation_server/models/flash_mixtral.py @@ -3,17 +3,20 @@ import torch from typing import Optional from text_generation_server.models.flash_mistral import BaseFlashMistral -from text_generation_server.models.custom_modeling.flash_mixtral_modeling import MixtralConfig, FlashMixtralForCausalLM +from text_generation_server.models.custom_modeling.flash_mixtral_modeling import ( + MixtralConfig, + FlashMixtralForCausalLM, +) class FlashMixtral(BaseFlashMistral): def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, ): super(FlashMixtral, self).__init__( config_cls=MixtralConfig, @@ -22,5 +25,5 @@ class FlashMixtral(BaseFlashMistral): revision=revision, quantize=quantize, dtype=dtype, - trust_remote_code=trust_remote_code + trust_remote_code=trust_remote_code, ) diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index 2f4bb139..86389ad2 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -792,7 +792,10 @@ class IdeficsCausalLM(Model): skip_special_tokens=False, ) prefill_tokens = Tokens( - prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[] + prefill_token_ids, + prefill_logprobs, + prefill_texts, + is_special=[], ) else: prefill_tokens = None @@ -803,10 +806,10 @@ class IdeficsCausalLM(Model): request.id, prefill_tokens, Tokens( - [next_token_id_squeezed], - [next_token_logprob], - [next_token_text], - [next_token_id_squeezed.item() in self.all_special_ids], + [next_token_id_squeezed], + [next_token_logprob], + [next_token_text], + [next_token_id_squeezed.item() in self.all_special_ids], ), generated_text, top_tokens, diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 8552960d..dfb21dcb 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -56,7 +56,7 @@ class Model(ABC): dtype=str(self.dtype), device_type=self.device.type, window_size=self.sliding_window, - speculate=self.speculate + speculate=self.speculate, ) @property diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 279b5505..a85ef58e 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -736,7 +736,7 @@ class Seq2SeqLM(Model): [self.tokenizer.bos_token_id], [float("nan")], [self.tokenizer.bos_token], - [False] + [False], ) else: prefill_tokens = None @@ -763,10 +763,10 @@ class Seq2SeqLM(Model): request.id, prefill_tokens, Tokens( - [next_token_id_squeezed], - [next_token_logprob], - [next_token_text], - [next_token_id_squeezed.item() in self.all_special_ids], + [next_token_id_squeezed], + [next_token_logprob], + [next_token_text], + [next_token_id_squeezed.item() in self.all_special_ids], ), generated_text, top_tokens, diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py index 87c03d63..f85f27e5 100644 --- a/server/text_generation_server/models/types.py +++ b/server/text_generation_server/models/types.py @@ -66,7 +66,10 @@ class Tokens: def to_pb(self) -> generate_pb2.Tokens: return generate_pb2.Tokens( - ids=self.token_ids, logprobs=self.logprobs, texts=self.texts, is_special=self.is_special + ids=self.token_ids, + logprobs=self.logprobs, + texts=self.texts, + is_special=self.is_special, ) def __len__(self): diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index ebe066e3..75dba972 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -159,7 +159,13 @@ def serve( try: model = get_model( - model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code + model_id, + revision, + sharded, + quantize, + speculate, + dtype, + trust_remote_code, ) except Exception: logger.exception("Error when initializing model") @@ -207,5 +213,7 @@ def serve( await server.stop(0) asyncio.run( - serve_inner(model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code) + serve_inner( + model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code + ) ) diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py index aca95e11..3237df82 100644 --- a/server/text_generation_server/utils/flash_attn.py +++ b/server/text_generation_server/utils/flash_attn.py @@ -51,7 +51,9 @@ except ImportError as e: ) from e elif IS_ROCM_SYSTEM: for idx in range(torch.cuda.device_count()): - if "MI210" not in torch.cuda.get_device_name(idx) and "MI250" not in torch.cuda.get_device_name(idx): + if "MI210" not in torch.cuda.get_device_name( + idx + ) and "MI250" not in torch.cuda.get_device_name(idx): raise ImportError( f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" ) @@ -91,8 +93,10 @@ def attention( ) elif HAS_FLASH_ATTN_V2_ROCM: if window_size_left != -1: - raise ValueError(f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left}).") - + raise ValueError( + f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})." + ) + # RoCm flash API does not take the window_size_left and window_size_right arguments. return flash_attn_2_cuda.varlen_fwd( q, diff --git a/server/text_generation_server/utils/gptq/exllamav2.py b/server/text_generation_server/utils/gptq/exllamav2.py index 1945338b..f820f0d9 100644 --- a/server/text_generation_server/utils/gptq/exllamav2.py +++ b/server/text_generation_server/utils/gptq/exllamav2.py @@ -11,40 +11,44 @@ logger = getLogger(__name__) try: from exllamav2_kernels import make_q_matrix, gemm_half_q_half except ImportError: - logger.error('exllamav2_kernels not installed.') + logger.error("exllamav2_kernels not installed.") raise # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension none_tensor = torch.empty((1, 1), device="meta") + def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): """Matrix multiplication, returns x @ q4""" output_shape = x.shape[:-1] + (q4_width,) x = x.view(-1, x.shape[-1]) - output = torch.empty((x.shape[0], q4_width), dtype = torch.half, device = x.device) + output = torch.empty((x.shape[0], q4_width), dtype=torch.half, device=x.device) gemm_half_q_half(x, q_handle, output, force_cuda) return output.view(output_shape) + def ext_make_q_matrix(w: dict, temp_dq, key: str = None): """ - Create Q matrix + Create Q matrix """ # EXL2 - # won't work as the moment because the tensors are not the same. + # won't work as the moment because the tensors are not the same. if "q_weight" in w: w["q_scale_max"] /= 256 w["q_perm"] = w["q_perm"].short() w["q_invperm"] = w["q_invperm"].short() - return make_q_matrix(w["q_weight"], - w["q_perm"], - w["q_invperm"], - w["q_scale"], - w["q_scale_max"], - w["q_groups"], - none_tensor, - none_tensor, - none_tensor, - temp_dq) + return make_q_matrix( + w["q_weight"], + w["q_perm"], + w["q_invperm"], + w["q_scale"], + w["q_scale_max"], + w["q_groups"], + none_tensor, + none_tensor, + none_tensor, + temp_dq, + ) # GPTQ elif "qweight" in w: if w["scales"].dtype == torch.float: @@ -52,31 +56,40 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): # GPTQ with g_idx (act_order) if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item(): - w["q_perm"] = torch.empty((w["qweight"].shape[0] * 8,), dtype = torch.short, device = w["qweight"].device) + w["q_perm"] = torch.empty( + (w["qweight"].shape[0] * 8,), + dtype=torch.short, + device=w["qweight"].device, + ) w["q_invperm"] = torch.empty_like(w["q_perm"]) # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx. - return make_q_matrix(w["qweight"], - w["q_perm"], - w["q_invperm"], - none_tensor, - none_tensor, - none_tensor, - w["qzeros"], - w["scales"], - w["g_idx"].cpu(), - temp_dq) + return make_q_matrix( + w["qweight"], + w["q_perm"], + w["q_invperm"], + none_tensor, + none_tensor, + none_tensor, + w["qzeros"], + w["scales"], + w["g_idx"].cpu(), + temp_dq, + ) # GPTQ without g_idx else: - return make_q_matrix(w["qweight"], - none_tensor, - none_tensor, - none_tensor, - none_tensor, - none_tensor, - w["qzeros"], - w["scales"], - none_tensor, - temp_dq) + return make_q_matrix( + w["qweight"], + none_tensor, + none_tensor, + none_tensor, + none_tensor, + none_tensor, + w["qzeros"], + w["scales"], + none_tensor, + temp_dq, + ) + DEVICE = None FIXED_BYTES = 0 @@ -106,14 +119,15 @@ class QuantLinear(nn.Module): super().__init__() if bits != 4: raise ValueError( - f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization.") + f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization." + ) self.q_handle = None self.q_tensors = None self.bits = bits - self.maxq = 2 ** self.bits - 1 + self.maxq = 2**self.bits - 1 self.infeatures = qweight.shape[0] // self.bits * 32 self.outfeatures = qweight.shape[1] - self.padding = - self.outfeatures % 32 + self.padding = -self.outfeatures % 32 self.outfeatures = self.outfeatures + self.padding self.device = qweight.device @@ -128,9 +142,12 @@ class QuantLinear(nn.Module): outfeatures = self.outfeatures assert qweight.shape == (infeatures // 32 * self.bits, outfeatures) assert infeatures % self.group_size == 0 - assert qzeros.shape == (infeatures // self.group_size, outfeatures // 32 * self.bits) + assert qzeros.shape == ( + infeatures // self.group_size, + outfeatures // 32 * self.bits, + ) assert scales.shape == (infeatures // self.group_size, outfeatures) - assert g_idx.shape == (infeatures, ), f"{g_idx.shape}, {infeatures}" + assert g_idx.shape == (infeatures,), f"{g_idx.shape}, {infeatures}" global FIXED_BYTES, LAYERS FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed()) @@ -140,33 +157,31 @@ class QuantLinear(nn.Module): assert self.qweight.device.type == "cuda" assert self.qweight.device.index is not None self.q_tensors = { - "qweight":self.qweight, - "qzeros":self.qzeros, - "scales":self.scales, - "g_idx":self.g_idx + "qweight": self.qweight, + "qzeros": self.qzeros, + "scales": self.scales, + "g_idx": self.g_idx, } temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size()) - self.q_handle = ext_make_q_matrix( - self.q_tensors, temp_dq - ) - - def forward(self, x, force_cuda = False): + self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq) + + def forward(self, x, force_cuda=False): output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda) if self.bias is not None: output.add_(self.bias) return output - + def temp_dq_size(self): return self.infeatures * self.outfeatures * 2 + 128 - + def temp_fwd_size(self, max_input_len, max_batch_size): return self.outfeatures * max_input_len * max_batch_size * 4 + 128 - + def scratch_space_fixed(self, max_input_len=4096, max_batch_size=16): return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size) - - + + class ExLlamaV2DeviceTensors: device_idx: int @@ -177,13 +192,16 @@ class ExLlamaV2DeviceTensors: def __init__(self, device, scratch_bytes): self.device = device self.scratch_bytes = scratch_bytes - + def prepare(self): - self.scratch = torch.empty((self.scratch_bytes // 2,), dtype = torch.half, device = self.device) + self.scratch = torch.empty( + (self.scratch_bytes // 2,), dtype=torch.half, device=self.device + ) def get_scratch_slice(self, size_bytes): - if self.scratch is None: self.prepare() + if self.scratch is None: + self.prepare() size_bytes = ((size_bytes + 127) // 128) * 128 size_half = size_bytes // 2 diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index d533016d..77e2fdb6 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -35,7 +35,9 @@ HAS_EXLLAMA = False CAN_EXLLAMA = major >= 8 V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1: - logger.warning("Disabling exllama v2 and using v1 instead because there are issues when sharding") + logger.warning( + "Disabling exllama v2 and using v1 instead because there are issues when sharding" + ) V2 = False if os.getenv("DISABLE_EXLLAMA") == "True": @@ -43,17 +45,19 @@ if os.getenv("DISABLE_EXLLAMA") == "True": elif CAN_EXLLAMA: try: if V2: - from text_generation_server.utils.gptq.exllamav2 import (QuantLinear as ExllamaQuantLinear, - create_exllama_buffers, - set_device, - ) + from text_generation_server.utils.gptq.exllamav2 import ( + QuantLinear as ExllamaQuantLinear, + create_exllama_buffers, + set_device, + ) HAS_EXLLAMA = "2" else: - from text_generation_server.utils.gptq.exllama import (Ex4bitLinear as ExllamaQuantLinear, - create_exllama_buffers, - set_device, - ) + from text_generation_server.utils.gptq.exllama import ( + Ex4bitLinear as ExllamaQuantLinear, + create_exllama_buffers, + set_device, + ) HAS_EXLLAMA = "1" @@ -114,7 +118,7 @@ def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, st @classmethod def load_conv2d_no_bias( - cls, prefix, weights, in_channels, out_channels, kernel_size, stride + cls, prefix, weights, in_channels, out_channels, kernel_size, stride ): weight = weights.get_tensor(f"{prefix}.weight") with init_empty_weights(): @@ -138,9 +142,9 @@ torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias class FastLinear(nn.Module): def __init__( - self, - weight, - bias, + self, + weight, + bias, ) -> None: super().__init__() self.weight = nn.Parameter(weight) @@ -164,9 +168,9 @@ class FastLinear(nn.Module): class EETQLinear(nn.Module): def __init__( - self, - weight, - bias, + self, + weight, + bias, ) -> None: super().__init__() device = weight.device @@ -185,13 +189,13 @@ class EETQLinear(nn.Module): class Linear8bitLt(nn.Module): def __init__( - self, - weight, - bias, - has_fp16_weights=True, - memory_efficient_backward=False, - threshold=0.0, - index=None, + self, + weight, + bias, + has_fp16_weights=True, + memory_efficient_backward=False, + threshold=0.0, + index=None, ): super().__init__() assert ( @@ -325,7 +329,9 @@ def get_linear(weight, bias, quantize): ) if use_exllama: - linear = ExllamaQuantLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize) + linear = ExllamaQuantLinear( + qweight, qzeros, scales, g_idx, bias, bits, groupsize + ) else: linear = QuantLinear( qweight, @@ -533,7 +539,6 @@ try: else: dropout_layer_norm = None - class FastLayerNorm(nn.LayerNorm): def forward(self, hidden_states, residual=None): if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM: @@ -569,7 +574,6 @@ try: return normed_hidden_states, residual - class FastRMSNorm(nn.Module): def __init__(self, weight: torch.Tensor, eps: float): super().__init__() @@ -601,7 +605,11 @@ try: return self.weight * hidden_states, residual elif IS_CUDA_SYSTEM: # faster post attention rms norm - normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( + ( + normed_hidden_states, + res, + *rest, + ) = dropout_layer_norm.dropout_add_ln_fwd( hidden_states, residual, self.weight, @@ -638,7 +646,8 @@ try: return out, residual else: raise ValueError( - "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.") + "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." + ) except ImportError: pass @@ -650,14 +659,12 @@ try: elif IS_ROCM_SYSTEM: from vllm import pos_encoding_ops - def _create_inv_freq(dim, base, device): inv_freq = 1.0 / ( - base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) + base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) ) return inv_freq - def _get_rope_config(config): if os.getenv("ROPE_SCALING", None) is not None: rope_scaling = { @@ -667,7 +674,6 @@ try: return rope_scaling return getattr(config, "rope_scaling", None) - class PositionRotaryEmbedding(nn.Module): def __init__(self, inv_freq, scaling_factor): super().__init__() @@ -680,17 +686,23 @@ try: self.scaling_factor = scaling_factor self.dynamic_args = None - def forward(self, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ): # Such controlflows may add some overhead. if IS_CUDA_SYSTEM: rotary_dim = cos.shape[-1] q1 = query[..., :rotary_dim] - q2 = query[..., rotary_dim: 2 * rotary_dim] + q2 = query[..., rotary_dim : 2 * rotary_dim] rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) k1 = key[..., :rotary_dim] - k2 = key[..., rotary_dim: 2 * rotary_dim] + k2 = key[..., rotary_dim : 2 * rotary_dim] rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) elif IS_ROCM_SYSTEM: @@ -700,17 +712,11 @@ try: head_size = query.shape[-1] # Inplace operation, updating query and key. - pos_encoding_ops.rotary_embedding( - query, - key, - head_size, - cos, - sin, - True - ) + pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True) else: raise ValueError( - "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.") + "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." + ) @classmethod def static(cls, config, dim, base, device): @@ -732,15 +738,16 @@ try: elif rope_scaling["type"] == "yarn": return YarnPositionRotaryEmbedding( dim=2 * inv_freq.shape[0], - max_position_embeddings=rope_scaling["original_max_position_embeddings"], + max_position_embeddings=rope_scaling[ + "original_max_position_embeddings" + ], base=10000.0, device=inv_freq.device, scaling_factor=scaling_factor, extrapolation_factor=1, attn_factor=1, beta_fast=32, - beta_slow=1 - + beta_slow=1, ) else: raise NotImplementedError( @@ -773,15 +780,16 @@ try: elif rope_scaling["type"] == "yarn": return YarnPositionRotaryEmbedding( dim=2 * inv_freq.shape[0], - max_position_embeddings=rope_scaling["original_max_position_embeddings"], + max_position_embeddings=rope_scaling[ + "original_max_position_embeddings" + ], base=10000.0, device=inv_freq.device, scaling_factor=scaling_factor, extrapolation_factor=1, attn_factor=1, beta_fast=32, - beta_slow=1 - + beta_slow=1, ) else: raise NotImplementedError( @@ -793,9 +801,9 @@ try: # Reset the tables if the sequence length has changed, # or if we're on a new device (possibly due to tracing for instance) if ( - seqlen > self._seq_len_cached - or self._cos_cached.device != device - or self._cos_cached.dtype != dtype + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype ): self._seq_len_cached = seqlen t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) @@ -809,7 +817,7 @@ try: self._sin_cached = torch.sin(freqs).to(dtype) def get_cos_sin( - self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype + self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype ): """ Return cos and sin for the asked position ids @@ -827,7 +835,6 @@ try: # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow. return cos.unsqueeze(1), sin.unsqueeze(1) - class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): def __init__(self, dim, max_position_embeddings, base, device, scaling_factor): inv_freq = _create_inv_freq(dim, base, device) @@ -840,14 +847,14 @@ try: # Reset the tables if the sequence length has changed, # or if we're on a new device (possibly due to tracing for instance) if ( - seqlen > self._seq_len_cached - or self._cos_cached.device != device - or self._cos_cached.dtype != dtype + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype ): if seqlen > self.max_position_embeddings: newbase = self.base * ( - (self.scaling_factor * seqlen / self.max_position_embeddings) - - (self.scaling_factor - 1) + (self.scaling_factor * seqlen / self.max_position_embeddings) + - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) self.inv_freq = _create_inv_freq( self.dim, newbase, self.inv_freq.device @@ -861,24 +868,28 @@ try: self._cos_cached = torch.cos(freqs).to(dtype) self._sin_cached = torch.sin(freqs).to(dtype) - # Inverse dim formula to find dim based on number of rotations import math - - def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - + def find_correction_dim( + num_rotations, dim, base=10000, max_position_embeddings=2048 + ): + return ( + dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) + ) / (2 * math.log(base)) # Find dim range bounds based on rotations - def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): - low = math.floor(find_correction_dim( - low_rot, dim, base, max_position_embeddings)) - high = math.ceil(find_correction_dim( - high_rot, dim, base, max_position_embeddings)) + def find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 + ): + low = math.floor( + find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) return max(low, 0), min(high, dim - 1) # Clamp values just in case - def linear_ramp_mask(min, max, dim): if min == max: max += 0.001 # Prevent singularity @@ -887,16 +898,25 @@ try: ramp_func = torch.clamp(linear_func, 0, 1) return ramp_func - def get_mscale(scale=1): if scale <= 1: return 1.0 return 0.1 * math.log(scale) + 1.0 - class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): - def __init__(self, dim, max_position_embeddings, base, device, scaling_factor, *, extrapolation_factor, - attn_factor, beta_fast, beta_slow): + def __init__( + self, + dim, + max_position_embeddings, + base, + device, + scaling_factor, + *, + extrapolation_factor, + attn_factor, + beta_fast, + beta_slow, + ): inv_freq = _create_inv_freq(dim, base, device) super().__init__(inv_freq, scaling_factor) self.dim = dim @@ -906,16 +926,17 @@ try: self.attn_factor = attn_factor self.beta_fast = beta_fast self.beta_slow = beta_slow - self.mscale = float(get_mscale( - self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation + self.mscale = float( + get_mscale(self.scaling_factor) * self.attn_factor + ) # Get n-d magnitude scaling corrected for interpolation def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, # or if we're on a new device (possibly due to tracing for instance) if ( - seqlen > self._seq_len_cached - or self._cos_cached.device != device - or self._cos_cached.dtype != dtype + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype ): if seqlen > self.max_position_embeddings: inv_freq_extrapolation = _create_inv_freq( @@ -923,15 +944,26 @@ try: ) freqs = 1.0 / inv_freq_extrapolation inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs) - low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, - self.max_position_embeddings) - inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to( - device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation - inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + low, high = find_correction_range( + self.beta_fast, + self.beta_slow, + self.dim, + self.base, + self.max_position_embeddings, + ) + inv_freq_mask = ( + 1 + - linear_ramp_mask(low, high, self.dim // 2).float().to(device) + ) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) self.inv_freq = inv_freq - self.mscale = float(get_mscale( - self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation + self.mscale = float( + get_mscale(self.scaling_factor) * self.attn_factor + ) # Get n-d magnitude scaling corrected for interpolation self._seq_len_cached = seqlen t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) diff --git a/server/text_generation_server/utils/medusa.py b/server/text_generation_server/utils/medusa.py index 029de122..634119cb 100644 --- a/server/text_generation_server/utils/medusa.py +++ b/server/text_generation_server/utils/medusa.py @@ -2,6 +2,7 @@ import torch from dataclasses import dataclass from text_generation_server.utils.layers import TensorParallelHead, FastLinear + @dataclass class Output: logits: torch.FloatTensor = None @@ -11,7 +12,9 @@ class Output: class ResBlock(torch.nn.Module): def __init__(self, config, prefix, weights): super().__init__() - self.linear = FastLinear.load(config, prefix=f"{prefix}.linear", weights=weights, bias=True) + self.linear = FastLinear.load( + config, prefix=f"{prefix}.linear", weights=weights, bias=True + ) self.act = torch.nn.SiLU() def forward(self, x): @@ -19,15 +22,13 @@ class ResBlock(torch.nn.Module): class MedusaModel(torch.nn.Module): - def __init__( - self, - config, - weights, - lm_head - ): + def __init__(self, config, weights, lm_head): super().__init__() self.heads = torch.nn.ModuleList( - [MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config["medusa_num_heads"])] + [ + MedusaHead(config, prefix=f"{i}", weights=weights) + for i in range(config["medusa_num_heads"]) + ] ) self.lm_head = lm_head @@ -40,9 +41,16 @@ class MedusaModel(torch.nn.Module): class MedusaHead(torch.nn.Module): def __init__(self, config, prefix, weights): super().__init__() - self.blocks = torch.nn.ModuleList([ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) for i in range(config["medusa_num_layers"])]) + self.blocks = torch.nn.ModuleList( + [ + ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) + for i in range(config["medusa_num_layers"]) + ] + ) n = len(self.blocks) - self.out = FastLinear.load(config, prefix=f"{prefix}.{n}", weights=weights, bias=False) + self.out = FastLinear.load( + config, prefix=f"{prefix}.{n}", weights=weights, bias=False + ) def forward(self, x): for block in self.blocks: diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py index 57a59599..4b12744c 100644 --- a/server/text_generation_server/utils/paged_attention.py +++ b/server/text_generation_server/utils/paged_attention.py @@ -7,23 +7,26 @@ from vllm import attention_ops _PARTITION_SIZE = 512 -def reshape_and_cache(key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, - slots: torch.Tensor): - cache_ops.reshape_and_cache( - key, value, key_cache, value_cache, slots - ) +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slots: torch.Tensor, +): + cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots) def attention( - out: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - kv_head_mapping: torch.Tensor, - softmax_scale: float, - block_tables: torch.Tensor, - input_lengths: torch.Tensor, - max_s: int, + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + kv_head_mapping: torch.Tensor, + softmax_scale: float, + block_tables: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py # Copyright 2023 The vLLM team. All rights @@ -45,9 +48,7 @@ def attention( # value_cache => [num_blocks, num_heads, head_size, block_size] block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape - max_num_partitions = ( - (max_s + _PARTITION_SIZE - 1) // - _PARTITION_SIZE) + max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use # V1 to avoid the overhead of reduction. Also, if the number of diff --git a/server/text_generation_server/utils/peft.py b/server/text_generation_server/utils/peft.py index d37e8940..45e23320 100644 --- a/server/text_generation_server/utils/peft.py +++ b/server/text_generation_server/utils/peft.py @@ -38,7 +38,9 @@ def download_and_unload_peft(model_id, revision, trust_remote_code): os.makedirs(model_id, exist_ok=True) cache_dir = model_id logger.info(f"Saving the newly created merged model to {cache_dir}") - tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=trust_remote_code) + tokenizer = AutoTokenizer.from_pretrained( + base_model_id, trust_remote_code=trust_remote_code + ) model.save_pretrained(cache_dir, safe_serialization=True) model.config.save_pretrained(cache_dir) tokenizer.save_pretrained(cache_dir) diff --git a/server/text_generation_server/utils/speculate.py b/server/text_generation_server/utils/speculate.py index 38a91972..a1b37a34 100644 --- a/server/text_generation_server/utils/speculate.py +++ b/server/text_generation_server/utils/speculate.py @@ -1,12 +1,11 @@ - SPECULATE = None + def get_speculate() -> int: global SPECULATE return SPECULATE + def set_speculate(speculate: int): global SPECULATE SPECULATE = speculate - - diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index a34c5afc..0d208104 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -16,6 +16,7 @@ from text_generation_server.utils.logits_process import ( from text_generation_server.utils.watermark import WatermarkLogitsProcessor from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor + class NextTokenChooser: def __init__( self, @@ -145,21 +146,31 @@ class StoppingCriteria: pb.ignore_eos_token, ) -def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int, verbose: bool): + +def create_n_gram_speculation( + input_ids: torch.Tensor, + next_ids: torch.Tensor, + accepted_ids: torch.Tensor, + speculate: int, + verbose: bool, +): # Very trivial approach, find first match in the string. # This is much less refined than actual n-gram but seems to work # relatively OK in grounded mode and is by far much faster with # much less worst case complexity as everything happens on device. B = accepted_ids.shape[0] device = input_ids.device - seeds = next_ids[accepted_ids.cumsum(dim=-1) -1 ] + seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1] indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1 - all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=device) + all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange( + speculate, device=device + ) all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1) speculative_ids = input_ids.gather(dim=-1, index=all_indices) return speculative_ids + class HeterogeneousNextTokenChooser: def __init__( self, @@ -228,7 +239,15 @@ class HeterogeneousNextTokenChooser: self.dtype = dtype self.device = device - def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None, verbose=False): + def __call__( + self, + input_ids: torch.Tensor, + scores: torch.Tensor, + speculate: int, + speculated_ids: Optional[torch.Tensor] = None, + speculative_scores: Optional[torch.Tensor] = None, + verbose=False, + ): if speculated_ids is not None: B = scores.shape[0] // (speculated_ids.shape[1] + 1) S = speculated_ids.shape[1] + 1 @@ -249,12 +268,11 @@ class HeterogeneousNextTokenChooser: for warper in self.warpers: _scores = warper(input_ids, _scores) - _next_ids = self.choice(_scores) scores[:, j] = _scores next_ids[:, j] = _next_ids - next_ids = next_ids.view(B*S) - scores = scores.view( B* S, -1) + next_ids = next_ids.view(B * S) + scores = scores.view(B * S, -1) if speculated_ids is not None: accepted_ids = [] @@ -262,7 +280,7 @@ class HeterogeneousNextTokenChooser: S = speculated_ids.shape[1] + 1 indices = [] for i in range(B): - _next_ids = next_ids[i*S: (i + 1)*S] + _next_ids = next_ids[i * S : (i + 1) * S] _speculated_ids = speculated_ids[i] validate_speculative = _next_ids[:-1] == _speculated_ids index = i * S @@ -278,7 +296,9 @@ class HeterogeneousNextTokenChooser: break accepted_ids.append(accepted) - accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype) + accepted_ids = torch.tensor( + accepted_ids, device=input_ids.device, dtype=input_ids.dtype + ) next_ids = next_ids[indices] scores = scores[indices] indices = torch.arange(B, device=input_ids.device) * S @@ -296,7 +316,9 @@ class HeterogeneousNextTokenChooser: speculative_ids = Greedy()(speculative_scores) else: # n-gram - speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate, verbose) + speculative_ids = create_n_gram_speculation( + input_ids, next_ids, accepted_ids, speculate, verbose + ) else: speculative_ids = None diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index f3344988..802c1a90 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -16,7 +16,7 @@ class Weights: dtype, process_group, aliases: Optional[Dict[str, List[str]]] = None, - prefix: Optional[str] = None + prefix: Optional[str] = None, ): routing = {} for filename in filenames: @@ -213,7 +213,8 @@ class Weights: bits, groupsize = self._get_gptq_params() from text_generation_server.utils.layers import HAS_EXLLAMA - use_exllama = bits==4 and HAS_EXLLAMA and quantize == "gptq" + + use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq" weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) else: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] @@ -283,7 +284,7 @@ class Weights: if use_exllama: qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) scales = self.get_sharded(f"{prefix}.scales", dim=0) - g_idx = self.get_sharded(f"{prefix}.g_idx", dim= 0) + g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) g_idx = g_idx - g_idx[0] else: # The triton kernel reorders the scales/zero points instead of the weight/activation. diff --git a/update_doc.py b/update_doc.py index 6206e211..6127418c 100644 --- a/update_doc.py +++ b/update_doc.py @@ -21,14 +21,14 @@ def main(): block = [] for line in lines: if line.startswith(" -") or line.startswith(" -"): - rendered_block = '\n'.join(block) + rendered_block = "\n".join(block) if header: final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n" else: final_doc += f"```shell\n{rendered_block}\n```\n" block = [] tokens = line.split("<") - if len(tokens)>1: + if len(tokens) > 1: header = tokens[-1][:-1] else: header = line.split("--")[-1] @@ -36,7 +36,7 @@ def main(): block.append(line) - rendered_block = '\n'.join(block) + rendered_block = "\n".join(block) final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n" block = []